fixed error in assertion text
[p2pool.git] / p2pool / util / forest.py
1 '''
2 forest data structure
3 '''
4
5 import itertools
6
7 from p2pool.util import skiplist, variable
8
9
class TrackerSkipList(skiplist.SkipList):
    '''
    Skip list whose elements are item hashes of a tracker.
    
    Subscribes to the tracker's `removed` event so that an item leaving the
    tracker is also forgotten here.
    '''
    
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker
        
        # the callback receives the skip list as an argument rather than
        # closing over it, as watch_weakref passes it in
        def _forget_removed(skip_list, removed_item):
            return skip_list.forget_item(removed_item.hash)
        
        tracker.removed.watch_weakref(self, _forget_removed)
    
    def previous(self, element):
        '''
        Return the tail (previous hash) of the tracked item whose hash is `element`.
        '''
        tracker = self.tracker
        make_delta = tracker._delta_type.from_element
        tracked_item = tracker.items[element]
        return make_delta(tracked_item).tail
19
20
class DistanceSkipList(TrackerSkipList):
    '''
    Skip list used to find the hash that lies `n` steps before a given hash.
    
    A delta is a (from_hash, distance, to_hash) triple; a solution is a
    (distance, hash) pair. The extra arguments of a query are the 1-tuple
    (n,). These hooks are presumably driven by skiplist.SkipList - confirm
    against that class.
    
    Tuple parameters are unpacked inside the method bodies rather than in
    the signatures, because signature unpacking is Python-2-only syntax.
    Callers passing the arguments positionally are unaffected.
    '''
    
    def get_delta(self, element):
        '''Return the single-step delta from `element` to its previous hash.'''
        return element, 1, self.previous(element)
    
    def combine_deltas(self, delta1, delta2):
        '''
        Join two adjacent deltas into one spanning both.
        
        Raises AssertionError if delta1 does not end where delta2 starts.
        '''
        from_hash1, dist1, to_hash1 = delta1
        from_hash2, dist2, to_hash2 = delta2
        if to_hash1 != from_hash2:
            raise AssertionError('deltas are not adjacent: %r != %r' % (to_hash1, from_hash2))
        return from_hash1, dist1 + dist2, to_hash2
    
    def initial_solution(self, start, args):
        '''Return the starting solution: distance 0, positioned at `start`.'''
        (n,) = args # unpacked only to check the shape of args, as the original signature did
        return 0, start
    
    def apply_delta(self, solution, delta, args):
        '''
        Advance a (distance, hash) solution by a delta starting at that hash.
        
        Raises AssertionError if the delta does not start at the solution's hash.
        '''
        dist1, to_hash1 = solution
        from_hash2, dist2, to_hash2 = delta
        (n,) = args # shape check only
        if to_hash1 != from_hash2:
            raise AssertionError('delta does not start at current position: %r != %r' % (to_hash1, from_hash2))
        return dist1 + dist2, to_hash2
    
    def judge(self, solution, args):
        '''
        Compare the distance travelled with the target `n`.
        
        Returns 1 if past the target, 0 if exactly on it, -1 if short of it.
        '''
        dist, hash = solution
        (n,) = args
        if dist > n:
            return 1
        elif dist == n:
            return 0
        else:
            return -1
    
    def finalize(self, solution, args):
        '''Return the hash of a solution that is exactly `n` steps away.'''
        dist, hash = solution
        (n,) = args
        assert dist == n
        return hash
49
def get_attributedelta_type(attrs): # attrs: {name: func}
    '''
    Build a delta class that accumulates the given attributes along a chain.
    
    attrs maps an attribute name to a function taking an item and returning
    that item's contribution to the attribute. Items must provide `hash` and
    `previous_hash`.
    
    An instance of the returned class describes a span from `head` to `tail`
    together with the summed value of every attribute over that span. The
    mapping itself is exposed as the class attribute `attrs`.
    '''
    class ProtoAttributeDelta(object):
        # list(attrs) rather than attrs.keys(): keys() is not a list on Python 3
        __slots__ = ['head', 'tail'] + list(attrs)
        
        @classmethod
        def get_none(cls, element_id):
            '''Return the empty delta at element_id (head == tail, every attribute 0).'''
            return cls(element_id, element_id, **dict((k, 0) for k in attrs))
        
        @classmethod
        def from_element(cls, item):
            '''Return the delta covering exactly one item.'''
            return cls(item.hash, item.previous_hash, **dict((k, v(item)) for k, v in attrs.items()))
        
        @staticmethod
        def get_head(item):
            '''Return the hash of an item.'''
            return item.hash
        
        @staticmethod
        def get_tail(item):
            '''Return the hash of the item preceding an item.'''
            return item.previous_hash
        
        def __init__(self, head, tail, **kwargs):
            self.head, self.tail = head, tail
            for k, v in kwargs.items():
                setattr(self, k, v)
        
        def __add__(self, other):
            '''
            Concatenate with a delta that starts where this one ends.
            
            Raises AssertionError if the two deltas are not adjacent. This is
            an explicit raise rather than an assert statement so the check
            still runs under python -O, matching __sub__.
            '''
            if self.tail != other.head:
                raise AssertionError('cannot add non-adjacent deltas: tail %r != head %r' % (self.tail, other.head))
            return self.__class__(self.head, other.tail, **dict((k, getattr(self, k) + getattr(other, k)) for k in attrs))
        
        def __sub__(self, other):
            '''
            Remove a prefix or a suffix from this delta.
            
            If both share a head, the result spans other.tail -> self.tail;
            otherwise, if both share a tail, it spans self.head -> other.head.
            Raises AssertionError if they share neither.
            '''
            if self.head == other.head:
                return self.__class__(other.tail, self.tail, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
            elif self.tail == other.tail:
                return self.__class__(self.head, other.head, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
            else:
                raise AssertionError('cannot subtract deltas sharing neither head nor tail: %r - %r' % (self, other))
        
        def __repr__(self):
            return '%s(%r, %r%s)' % (self.__class__, self.head, self.tail, ''.join(', %s=%r' % (k, getattr(self, k)) for k in attrs))
    ProtoAttributeDelta.attrs = attrs
    return ProtoAttributeDelta
91
# Default delta type: the only accumulated attribute is `height`, to which
# every item contributes 1.
AttributeDelta = get_attributedelta_type({'height': lambda item: 1})
95
class TrackerView(object):
    '''
    Cache of accumulated deltas from items of a tracker to the end of their chain.
    
    The "last" of an item is the first hash reached by following tails that
    is not itself present in the tracker's items. Cached deltas are stored in
    two parts: a per-item delta ending at the head of a shared "ref" delta,
    and the ref delta itself, which ends at the last. Several items can share
    one ref, so a ref can be adjusted or dropped in one place when the
    tracker reports a removal.
    '''
    
    def __init__(self, tracker, delta_type):
        '''
        tracker: object providing `items` plus the `remove_special`,
            `remove_special2` and `removed` events
        delta_type: delta class used for all computations in this view
        '''
        self._tracker = tracker
        self._delta_type = delta_type
        
        self._deltas = {} # item_hash -> delta, ref
        self._reverse_deltas = {} # ref -> set of item_hashes
        
        self._ref_generator = itertools.count()
        self._delta_refs = {} # ref -> delta
        self._reverse_delta_refs = {} # delta.tail -> ref
        
        # each callback takes the view as an argument instead of closing over it
        self._tracker.remove_special.watch_weakref(self, lambda self, item: self._handle_remove_special(item))
        self._tracker.remove_special2.watch_weakref(self, lambda self, item: self._handle_remove_special2(item))
        self._tracker.removed.watch_weakref(self, lambda self, item: self._handle_removed(item))
    
    def _handle_remove_special(self, item):
        '''
        Handle the tracker's `remove_special` event for `item`.
        
        Shortens the ref delta that ended at the item's tail so that it ends
        at the item's head instead.
        '''
        delta = self._delta_type.from_element(item)
        
        # nothing is cached down to this item's tail
        if delta.tail not in self._reverse_delta_refs:
            return
        
        # move delta refs referencing children down to this, so they can be moved up in one step
        # (object() is a key that can never be present, so a missing ref yields the empty set)
        for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
            self.get_last(x)
        
        assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
        
        # re-checked because the get_last calls above can modify the ref tables
        if delta.tail not in self._reverse_delta_refs:
            return
        
        # move ref pointing to this up
        
        ref = self._reverse_delta_refs[delta.tail]
        cur_delta = self._delta_refs[ref]
        assert cur_delta.tail == delta.tail
        self._delta_refs[ref] = cur_delta - delta
        assert self._delta_refs[ref].tail == delta.head
        del self._reverse_delta_refs[delta.tail]
        self._reverse_delta_refs[delta.head] = ref
    
    def _handle_remove_special2(self, item):
        '''
        Handle the tracker's `remove_special2` event for `item`.
        
        Discards the ref delta ending at the item's tail together with every
        cached per-item delta that used it.
        '''
        delta = self._delta_type.from_element(item)
        
        if delta.tail not in self._reverse_delta_refs:
            return
        
        ref = self._reverse_delta_refs.pop(delta.tail)
        del self._delta_refs[ref]
        
        for x in self._reverse_deltas.pop(ref):
            del self._deltas[x]
    
    def _handle_removed(self, item):
        '''
        Handle the tracker's `removed` event: drop the cached delta of `item`.
        '''
        delta = self._delta_type.from_element(item)
        
        # delete delta entry and ref if it is empty
        if delta.head in self._deltas:
            delta1, ref = self._deltas.pop(delta.head)
            self._reverse_deltas[ref].remove(delta.head)
            if not self._reverse_deltas[ref]:
                del self._reverse_deltas[ref]
                delta2 = self._delta_refs.pop(ref)
                del self._reverse_delta_refs[delta2.tail]
    
    
    def get_height(self, item_hash):
        '''Return the `height` accumulated from item_hash to its last.'''
        return self.get_delta_to_last(item_hash).height
    
    def get_work(self, item_hash):
        '''
        Return the `work` accumulated from item_hash to its last.
        
        Only usable when the view's delta type defines a `work` attribute.
        '''
        return self.get_delta_to_last(item_hash).work
    
    def get_last(self, item_hash):
        '''Return the last of item_hash: the first hash down its chain that is not tracked.'''
        return self.get_delta_to_last(item_hash).tail
    
    def get_height_and_last(self, item_hash):
        '''Return (height, last) for item_hash using a single delta computation.'''
        delta = self.get_delta_to_last(item_hash)
        return delta.height, delta.tail
    
    def _get_delta(self, item_hash):
        '''
        Return the longest known delta starting at item_hash.
        
        This is the cached per-item delta joined with its ref delta when one
        is cached, otherwise the delta of the single item.
        '''
        if item_hash in self._deltas:
            delta1, ref = self._deltas[item_hash]
            delta2 = self._delta_refs[ref]
            res = delta1 + delta2
        else:
            res = self._delta_type.from_element(self._tracker.items[item_hash])
        assert res.head == item_hash
        return res
    
    def _set_delta(self, item_hash, delta):
        '''
        Cache `delta` as the delta that starts at item_hash.
        
        The delta is stored relative to the ref delta ending at delta.tail,
        which is created if it does not exist yet.
        '''
        other_item_hash = delta.tail
        if other_item_hash not in self._reverse_delta_refs:
            # a new ref starts out as the empty delta at the target hash
            ref = self._ref_generator.next()
            assert ref not in self._delta_refs
            self._delta_refs[ref] = self._delta_type.get_none(other_item_hash)
            self._reverse_delta_refs[other_item_hash] = ref
            del ref
        
        ref = self._reverse_delta_refs[other_item_hash]
        ref_delta = self._delta_refs[ref]
        assert ref_delta.tail == other_item_hash
        
        if item_hash in self._deltas:
            # detach from the ref used previously; drop that ref if nothing uses it any more
            prev_ref = self._deltas[item_hash][1]
            self._reverse_deltas[prev_ref].remove(item_hash)
            if not self._reverse_deltas[prev_ref] and prev_ref != ref:
                self._reverse_deltas.pop(prev_ref)
                x = self._delta_refs.pop(prev_ref)
                self._reverse_delta_refs.pop(x.tail)
        # both deltas share a tail, so the difference spans item_hash -> ref_delta.head
        self._deltas[item_hash] = delta - ref_delta, ref
        self._reverse_deltas.setdefault(ref, set()).add(item_hash)
    
    def get_delta_to_last(self, item_hash):
        '''
        Return the delta from item_hash to its last.
        
        Walks down the chain until reaching a hash that is not in the
        tracker's items, then caches the delta to that hash for every hash
        visited on the way. An untracked item_hash yields the empty delta.
        '''
        assert isinstance(item_hash, (int, long, type(None)))
        delta = self._delta_type.get_none(item_hash)
        updates = []
        while delta.tail in self._tracker.items:
            # remember the delta accumulated before stepping past this hash
            updates.append((delta.tail, delta))
            this_delta = self._get_delta(delta.tail)
            delta += this_delta
        for update_hash, delta_then in updates:
            # both share the head item_hash, so the difference spans update_hash -> last
            self._set_delta(update_hash, delta - delta_then)
        return delta
    
    def get_delta(self, item, ancestor):
        '''
        Return the delta from `item` to `ancestor`.
        
        Asserts that the tracker reports `item` as a child of `ancestor`.
        '''
        assert self._tracker.is_child_of(ancestor, item)
        return self.get_delta_to_last(item) - self.get_delta_to_last(ancestor)
223
class Tracker(object):
    '''
    Forest of items linked by hash, each item naming the hash of its predecessor.
    
    Items are keyed by the head of their delta and linked through its tail.
    A "head" is the top of a branch. A "tail" is the hash a branch ends on,
    which is not tracked itself. Attributes not found on the tracker are
    looked up on a default TrackerView (e.g. get_height, get_last).
    
    Events: `added` fires after an item is added, `removed` after one is
    removed, and `remove_special` / `remove_special2` during removal of an
    item at the bottom of a branch so that views can fix up their caches.
    '''
    
    def __init__(self, items=(), delta_type=AttributeDelta):
        '''
        items: iterable of initial items, added in order
        delta_type: delta class describing how to read hashes and attributes from items
        '''
        self.items = {} # hash -> item
        self.reverse = {} # delta.tail -> set of item_hashes
        
        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes
        
        self.added = variable.Event()
        self.remove_special = variable.Event()
        self.remove_special2 = variable.Event()
        self.removed = variable.Event()
        
        self.get_nth_parent_hash = DistanceSkipList(self)
        
        self._delta_type = delta_type
        self._default_view = TrackerView(self, delta_type)
        
        for item in items:
            self.add(item)
    
    def __getattr__(self, name):
        '''
        Delegate unknown attributes to the default view, caching them on the tracker.
        '''
        # Without this guard, looking up anything before _default_view exists
        # (copying, unpickling, a failure early in __init__) would re-enter
        # __getattr__ for '_default_view' itself and recurse without end.
        if name == '_default_view':
            raise AttributeError(name)
        attr = getattr(self._default_view, name)
        setattr(self, name, attr)
        return attr
    
    def add(self, item):
        '''
        Add an item to the forest and fire `added`.
        
        Raises ValueError if an item with the same hash is already present.
        '''
        assert not isinstance(item, (int, long, type(None)))
        delta = self._delta_type.from_element(item)
        
        if delta.head in self.items:
            raise ValueError('item already present')
        
        # heads above the new item: those of the branches that were waiting
        # on this hash, or just the item itself
        if delta.head in self.tails:
            heads = self.tails.pop(delta.head)
        else:
            heads = set([delta.head])
        
        # tail below the new item: taken over from its parent if that was a
        # head, otherwise computed by walking down from the parent
        if delta.tail in self.heads:
            tail = self.heads.pop(delta.tail)
        else:
            tail = self.get_last(delta.tail)
        
        self.items[delta.head] = item
        self.reverse.setdefault(delta.tail, set()).add(delta.head)
        
        self.tails.setdefault(tail, set()).update(heads)
        if delta.tail in self.tails[tail]:
            self.tails[tail].remove(delta.tail)
        
        for head in heads:
            self.heads[head] = tail
        
        self.added.happened(item)
    
    def remove(self, item_hash):
        '''
        Remove the item with the given hash and fire `removed`.
        
        Only an item at the top or at the bottom of a branch can be removed.
        
        Raises KeyError if the hash is not tracked and NotImplementedError
        if the item lies in the middle of a branch.
        '''
        assert isinstance(item_hash, (int, long, type(None)))
        if item_hash not in self.items:
            raise KeyError(item_hash)
        
        item = self.items[item_hash]
        del item_hash
        
        delta = self._delta_type.from_element(item)
        
        if delta.head in self.heads and delta.tail in self.tails:
            # item is a head whose parent is not tracked
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if not self.tails[delta.tail]:
                self.tails.pop(delta.tail)
        elif delta.head in self.heads:
            # item is a head with a tracked parent
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if self.reverse[delta.tail] != set([delta.head]):
                pass # has sibling
            else:
                # parent has no other child, so it becomes a head
                self.tails[tail].add(delta.tail)
                self.heads[delta.tail] = tail
        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
            # item is the only child of an untracked parent: every head above now ends on this item's hash
            heads = self.tails.pop(delta.tail)
            for head in heads:
                self.heads[head] = delta.head
            self.tails[delta.head] = set(heads)
            
            self.remove_special.happened(item)
        elif delta.tail in self.tails and len(self.reverse[delta.tail]) > 1:
            # untracked parent has other children: only the heads above this item are re-rooted
            heads = [x for x in self.tails[delta.tail] if self.is_child_of(delta.head, x)]
            self.tails[delta.tail] -= set(heads)
            if not self.tails[delta.tail]:
                self.tails.pop(delta.tail)
            for head in heads:
                self.heads[head] = delta.head
            assert delta.head not in self.tails
            self.tails[delta.head] = set(heads)
            
            self.remove_special2.happened(item)
        else:
            raise NotImplementedError()
        
        self.items.pop(delta.head)
        self.reverse[delta.tail].remove(delta.head)
        if not self.reverse[delta.tail]:
            self.reverse.pop(delta.tail)
        
        self.removed.happened(item)
    
    def get_chain(self, start_hash, length):
        '''
        Yield `length` items, starting at start_hash and following tails.
        
        This is a generator, so the height check runs on first iteration.
        '''
        assert length <= self.get_height(start_hash)
        for i in xrange(length):
            item = self.items[start_hash]
            yield item
            start_hash = self._delta_type.get_tail(item)
    
    def is_child_of(self, item_hash, possible_child_hash):
        '''
        Tell whether item_hash lies on the chain below possible_child_hash.
        
        Returns None when the two hashes do not end on the same last, since
        the relation cannot be determined then; otherwise True or False.
        '''
        height, last = self.get_height_and_last(item_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == item_hash
346
class SubsetTracker(Tracker):
    '''
    Tracker whose items are expected to also be present in another tracker.
    
    Parent lookups reuse the other tracker's get_nth_parent_hash instead of
    a skip list of its own.
    '''
    
    def __init__(self, subset_of, **kwargs):
        '''
        subset_of: the tracker this one is a subset of
        kwargs: passed on to Tracker.__init__ (items, delta_type)
        '''
        # Assigned before Tracker.__init__ because that adds any initial
        # items through self.add(), which reads self._subset_of.
        self._subset_of = subset_of
        Tracker.__init__(self, **kwargs)
        self.get_nth_parent_hash = subset_of.get_nth_parent_hash # overwrites Tracker.__init__'s
    
    def add(self, item):
        '''Add an item, asserting that it is present in the superset tracker.'''
        if self._subset_of is not None:
            assert self._delta_type.get_head(item) in self._subset_of.items
        Tracker.add(self, item)
    
    def remove(self, item_hash):
        '''Remove an item, asserting that it is still present in the superset tracker.'''
        if self._subset_of is not None:
            assert item_hash in self._subset_of.items
        Tracker.remove(self, item_hash)