added util.variable.Event.watch_weakref for code reuse
[p2pool.git] / p2pool / util / forest.py
1 '''
2 forest data structure
3 '''
4
5 import itertools
6
7 from p2pool.util import skiplist, variable
8
9
class TrackerSkipList(skiplist.SkipList):
    '''
    Skiplist whose elements are item hashes of a Tracker.

    Stepping from an element goes to its parent hash, as computed by the
    tracker's delta type. Whenever the tracker removes an item,
    forget_item(item.hash) is called on this skiplist. The callback is
    registered via watch_weakref, which presumably keeps only a weak
    reference to the skiplist (NOTE(review): confirm in util.variable).
    '''
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker

        # The callback receives the skiplist as its first argument. It must
        # not capture the skiplist itself, or the weak registration would be
        # defeated.
        def forget_removed(skip_list, removed_item):
            return skip_list.forget_item(removed_item.hash)
        tracker.removed.watch_weakref(self, forget_removed)

    def previous(self, element):
        '''Return the parent hash (delta tail) of the tracked item `element`.'''
        element_delta = self.tracker._delta_type.from_element(self.tracker.items[element])
        return element_delta.tail
19
20
class DistanceSkipList(TrackerSkipList):
    '''
    Skiplist used to find the hash `n` parent-steps above a starting hash.

    - Deltas are (from_hash, distance, to_hash) triples.
    - Partial solutions are (distance_so_far, current_hash) pairs.
    - The query arguments are the 1-tuple (n,).
    '''
    def get_delta(self, element):
        # A single step covers one generation: element -> its parent.
        return element, 1, self.previous(element)

    def combine_deltas(self, upper, lower):
        from_hash1, dist1, to_hash1 = upper
        from_hash2, dist2, to_hash2 = lower
        # Explicit raise (not `assert`) so the check also runs under -O.
        if to_hash1 != from_hash2:
            raise AssertionError()
        return from_hash1, dist1 + dist2, to_hash2

    def initial_solution(self, start, query):
        n, = query # unpacked only to enforce the (n,) shape
        return 0, start

    def apply_delta(self, solution, delta, query):
        dist1, to_hash1 = solution
        from_hash2, dist2, to_hash2 = delta
        n, = query
        if to_hash1 != from_hash2:
            raise AssertionError()
        return dist1 + dist2, to_hash2

    def judge(self, solution, query):
        '''Return 1 if the distance exceeds n, 0 if it equals n, else -1.'''
        dist, current_hash = solution
        n, = query
        return 1 if dist > n else (0 if dist == n else -1)

    def finalize(self, solution, query):
        '''Return the hash reached once exactly n steps have been taken.'''
        dist, current_hash = solution
        n, = query
        assert dist == n
        return current_hash
49
def get_attributedelta_type(attrs): # attrs: {name: func}
    '''
    Create a delta class that accumulates per-item attributes along a chain.

    attrs maps an attribute name to a function func(item) that returns that
    item's contribution. The contributions must support + and -.

    An instance describes the path from `head` (an item hash, included) down
    to `tail` (the hash it stops at, excluded). It carries one summed value
    per attribute name.
    '''
    class ProtoAttributeDelta(object):
        __slots__ = ['head', 'tail'] + list(attrs)

        @classmethod
        def get_none(cls, element_id):
            '''Empty path that starts and ends at element_id; every attribute is 0.'''
            return cls(element_id, element_id, **dict((k, 0) for k in attrs))

        @classmethod
        def from_element(cls, item):
            '''One-item path from item.hash down to (excluding) item.previous_hash.'''
            return cls(item.hash, item.previous_hash, **dict((k, v(item)) for k, v in attrs.items()))

        def __init__(self, head, tail, **kwargs):
            self.head, self.tail = head, tail
            for k, v in kwargs.items():
                setattr(self, k, v)

        def __add__(self, other):
            '''
            Concatenate self with `other`, the path directly below it.

            Raises AssertionError unless self.tail == other.head.
            '''
            # Explicit raise instead of `assert`: the check must survive
            # `python -O`, like the one in __sub__.
            if self.tail != other.head:
                raise AssertionError('cannot add %r and %r: tail/head mismatch' % (self, other))
            return self.__class__(self.head, other.tail, **dict((k, getattr(self, k) + getattr(other, k)) for k in attrs))

        def __sub__(self, other):
            '''
            Remove a shared end from self.

            If self and other share a head, the result runs from other.tail
            to self.tail. If they share a tail, it runs from self.head to
            other.head. Otherwise AssertionError is raised.
            '''
            if self.head == other.head:
                return self.__class__(other.tail, self.tail, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
            elif self.tail == other.tail:
                return self.__class__(self.head, other.head, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
            else:
                raise AssertionError()

        def __repr__(self):
            # Use the class name; formatting the class object itself would
            # produce "<class '...'>(...)".
            return '%s(%r, %r%s)' % (type(self).__name__, self.head, self.tail, ''.join(', %s=%r' % (k, getattr(self, k)) for k in attrs))
    ProtoAttributeDelta.attrs = attrs
    return ProtoAttributeDelta
83
# Default delta type used by Tracker: it only tracks chain height, with each
# item contributing 1.
AttributeDelta = get_attributedelta_type({
    'height': lambda item: 1,
})
87
class Tracker(object):
    '''
    Forest of items linked by their delta's tail, with cached cumulative deltas.

    Each item is converted with delta_type.from_element. The delta's `head`
    is the item's hash and its `tail` is the hash it links down to (its
    parent). A hash that is not itself a tracked item ends a chain; it is
    called the chain's "last" below.

    Public state:
        items   -- hash -> item
        reverse -- delta.tail -> set of hashes of items linking to it
        heads   -- hash of an item with no children -> the last its chain
                   reaches
        tails   -- last hash -> set of head hashes whose chains reach it
        added / removed -- events fired with the item after add / remove
        get_nth_parent_hash -- DistanceSkipList over this tracker

    Cache: an item's delta to its last is _deltas[h][0] + _delta_refs[ref].
    Many items share one ref per last hash, so moving a last only requires
    updating the ref.
    '''
    def __init__(self, items=(), delta_type=AttributeDelta):
        # `items` is only iterated, so an immutable default avoids the
        # shared-mutable-default pitfall.
        self.items = {} # hash -> item
        self.reverse = {} # delta.tail -> set of item_hashes

        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes

        self._deltas = {} # item_hash -> delta, ref
        self._reverse_deltas = {} # ref -> set of item_hashes

        self._ref_generator = itertools.count()
        self._delta_refs = {} # ref -> delta
        self._reverse_delta_refs = {} # delta.tail -> ref

        self.added = variable.Event()
        self.removed = variable.Event()

        self.get_nth_parent_hash = DistanceSkipList(self)

        self._delta_type = delta_type

        for item in items:
            self.add(item)

    def add(self, item):
        '''
        Insert `item`, joining it to its parent and to any already-present children.

        Raises ValueError if an item with the same hash is already tracked.
        Fires self.added with the item afterwards.
        '''
        assert not isinstance(item, (int, long, type(None)))
        delta = self._delta_type.from_element(item)

        if delta.head in self.items:
            raise ValueError('item already present')

        # Chains that previously ended at this (then missing) hash now
        # continue through this item; otherwise the item is its own head.
        if delta.head in self.tails:
            heads = self.tails.pop(delta.head)
        else:
            heads = set([delta.head])

        # If the parent was a head it stops being one. Otherwise find the
        # last below the parent (the parent hash itself if it is untracked).
        if delta.tail in self.heads:
            tail = self.heads.pop(delta.tail)
        else:
            tail = self.get_last(delta.tail)

        self.items[delta.head] = item
        self.reverse.setdefault(delta.tail, set()).add(delta.head)

        self.tails.setdefault(tail, set()).update(heads)
        if delta.tail in self.tails[tail]:
            self.tails[tail].remove(delta.tail)

        for head in heads:
            self.heads[head] = tail

        self.added.happened(item)

    def remove(self, item_hash):
        '''
        Remove the item with hash `item_hash` and fire self.removed with it.

        Supported removals:
        - the item is a head (it has no children), or
        - the item is the only one linking to its chain's last hash.
        Anything else raises NotImplementedError.

        Raises KeyError if item_hash is not tracked.
        '''
        assert isinstance(item_hash, (int, long, type(None)))
        if item_hash not in self.items:
            raise KeyError(item_hash)

        item = self.items[item_hash]
        del item_hash # from here on delta.head identifies the item

        delta = self._delta_type.from_element(item)

        if delta.head in self.heads and delta.tail in self.tails:
            # Childless item whose parent is untracked, so its last is
            # delta.tail.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if not self.tails[delta.tail]:
                self.tails.pop(delta.tail)
        elif delta.head in self.heads:
            # Childless item with a tracked parent.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if self.reverse[delta.tail] != set([delta.head]):
                pass # has sibling
            else:
                # The parent loses its only child and becomes a head.
                self.tails[tail].add(delta.tail)
                self.heads[delta.tail] = tail
        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
            # Bottom item of its chain(s): its children's last becomes
            # delta.head.

            # move delta refs referencing children down to this, so they can be moved up in one step
            if delta.tail in self._reverse_delta_refs:
                for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
                    self.get_last(x)
                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, None), set()))

            heads = self.tails.pop(delta.tail)
            for head in heads:
                self.heads[head] = delta.head
            self.tails[delta.head] = set(heads)

            # move ref pointing to this up
            if delta.tail in self._reverse_delta_refs:
                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))

                ref = self._reverse_delta_refs[delta.tail]
                cur_delta = self._delta_refs[ref]
                assert cur_delta.tail == delta.tail
                # Subtracting the removed item's own delta re-ends the shared
                # ref at delta.head, fixing every cached delta using it at once.
                self._delta_refs[ref] = cur_delta - self._delta_type.from_element(item)
                assert self._delta_refs[ref].tail == delta.head
                del self._reverse_delta_refs[delta.tail]
                self._reverse_delta_refs[delta.head] = ref
        else:
            raise NotImplementedError()

        # delete delta entry and ref if it is empty
        if delta.head in self._deltas:
            delta1, ref = self._deltas.pop(delta.head)
            self._reverse_deltas[ref].remove(delta.head)
            if not self._reverse_deltas[ref]:
                del self._reverse_deltas[ref]
                delta2 = self._delta_refs.pop(ref)
                del self._reverse_delta_refs[delta2.tail]

        self.items.pop(delta.head)
        self.reverse[delta.tail].remove(delta.head)
        if not self.reverse[delta.tail]:
            self.reverse.pop(delta.tail)

        self.removed.happened(item)

    def get_height(self, item_hash):
        '''Return the `height` of the delta from item_hash to its last.'''
        return self.get_delta_to_last(item_hash).height

    def get_work(self, item_hash):
        '''
        Return the `work` of the delta from item_hash to its last.

        Requires a delta_type defining `work`; the default AttributeDelta
        defines only `height`.
        '''
        return self.get_delta_to_last(item_hash).work

    def get_last(self, item_hash):
        '''Return the untracked hash that item_hash's chain ends at (item_hash itself if untracked).'''
        return self.get_delta_to_last(item_hash).tail

    def get_height_and_last(self, item_hash):
        '''Return (height, last) for item_hash with a single walk.'''
        delta = self.get_delta_to_last(item_hash)
        return delta.height, delta.tail

    def _get_delta(self, item_hash):
        '''
        Return the best known delta starting at tracked item item_hash.

        This is the cached delta (relative part plus shared ref) if one
        exists, otherwise the item's own one-step delta.
        '''
        if item_hash in self._deltas:
            delta1, ref = self._deltas[item_hash]
            delta2 = self._delta_refs[ref]
            res = delta1 + delta2
        else:
            res = self._delta_type.from_element(self.items[item_hash])
        assert res.head == item_hash
        return res

    def _set_delta(self, item_hash, delta):
        '''
        Cache `delta` (starting at item_hash) against the shared ref for delta.tail.

        The ref is created if missing. The item's previous ref is discarded
        if no other item uses it any more.
        '''
        other_item_hash = delta.tail
        if other_item_hash not in self._reverse_delta_refs:
            ref = self._ref_generator.next()
            assert ref not in self._delta_refs
            self._delta_refs[ref] = self._delta_type.get_none(other_item_hash)
            self._reverse_delta_refs[other_item_hash] = ref
            del ref

        ref = self._reverse_delta_refs[other_item_hash]
        ref_delta = self._delta_refs[ref]
        assert ref_delta.tail == other_item_hash

        if item_hash in self._deltas:
            prev_ref = self._deltas[item_hash][1]
            self._reverse_deltas[prev_ref].remove(item_hash)
            if not self._reverse_deltas[prev_ref] and prev_ref != ref:
                self._reverse_deltas.pop(prev_ref)
                x = self._delta_refs.pop(prev_ref)
                self._reverse_delta_refs.pop(x.tail)
        # Store only the part relative to the ref, so later edits to the ref
        # propagate to this item.
        self._deltas[item_hash] = delta - ref_delta, ref
        self._reverse_deltas.setdefault(ref, set()).add(item_hash)

    def get_delta_to_last(self, item_hash):
        '''
        Return the delta from item_hash down to the untracked hash its chain ends at.

        Every tracked item visited on the way is re-cached to point directly
        at the final last, so later lookups take fewer steps (similar to
        path compression).
        '''
        assert isinstance(item_hash, (int, long, type(None)))
        delta = self._delta_type.get_none(item_hash)
        updates = []
        while delta.tail in self.items:
            updates.append((delta.tail, delta))
            this_delta = self._get_delta(delta.tail)
            delta += this_delta
        for update_hash, delta_then in updates:
            self._set_delta(update_hash, delta - delta_then)
        return delta

    def get_delta(self, item, ancestor):
        '''Return the delta from hash `item` down to hash `ancestor` (asserted to be its ancestor).'''
        assert self.is_child_of(ancestor, item)
        return self.get_delta_to_last(item) - self.get_delta_to_last(ancestor)

    def get_chain(self, start_hash, length):
        '''
        Yield `length` items, starting at start_hash and following delta tails.

        This is a generator, so the height assertion runs on the first
        iteration, not at call time.
        '''
        assert length <= self.get_height(start_hash)
        for i in xrange(length):
            yield self.items[start_hash]
            start_hash = self._delta_type.from_element(self.items[start_hash]).tail

    def is_child_of(self, item_hash, possible_child_hash):
        '''
        Report whether possible_child_hash descends from item_hash (or is it).

        Returns None when the two hashes end at different lasts, because the
        relationship cannot be determined from the tracked items.
        '''
        height, last = self.get_height_and_last(item_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == item_hash
284
class SubsetTracker(Tracker):
    '''
    Tracker whose items are asserted to also be present in `subset_of`.

    When subset_of is given, this tracker reuses subset_of's
    get_nth_parent_hash skiplist instead of the one Tracker.__init__ builds.
    Any keyword arguments (e.g. items, delta_type) are forwarded to
    Tracker.__init__.
    '''
    def __init__(self, subset_of, **kwargs):
        # Must be assigned before Tracker.__init__. That call runs self.add()
        # for any initial `items`, and add() reads self._subset_of; assigning
        # it afterwards raised AttributeError.
        self._subset_of = subset_of
        Tracker.__init__(self, **kwargs)
        if subset_of is not None:
            self.get_nth_parent_hash = subset_of.get_nth_parent_hash # overwrites Tracker.__init__'s

    def add(self, item):
        '''Add `item`, asserting that the superset tracker already contains it.'''
        delta = self._delta_type.from_element(item)
        if self._subset_of is not None:
            assert delta.head in self._subset_of.items
        Tracker.add(self, item)

    def remove(self, item_hash):
        '''Remove `item_hash`, asserting that the superset tracker still contains it.'''
        if self._subset_of is not None:
            assert item_hash in self._subset_of.items
        Tracker.remove(self, item_hash)