c9fe9e734973b184288e939bde48d5177990316d
[p2pool.git] / p2pool / util / forest.py
1 '''
2 forest data structure
3 '''
4
5 import itertools
6 import weakref
7
8 from p2pool.util import skiplist, variable
9
10
class TrackerSkipList(skiplist.SkipList):
    '''
    SkipList bound to a Tracker.

    Whenever the tracker fires its `removed` event, the removed item's hash
    is forgotten by this skip list. The event subscription only holds a weak
    reference to the skip list; when the skip list is garbage collected, the
    subscription is withdrawn.
    '''
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker
        
        # Neither callback below refers to `self` directly, so the tracker's
        # event does not keep this object alive.
        def _on_collected(_dead_ref):
            tracker.removed.unwatch(watch_id)
        
        def _on_item_removed(item):
            self_ref().forget_item(item.hash)
        
        self_ref = weakref.ref(self, _on_collected)
        watch_id = self.tracker.removed.watch(_on_item_removed)
    
    def previous(self, element):
        '''Return the hash one step below `element` (the tail of its own delta).'''
        item = self.tracker.items[element]
        own_delta = self.tracker._delta_type.from_element(item)
        return own_delta.tail
21
22
class DistanceSkipList(TrackerSkipList):
    '''
    Skip list used by Tracker as get_nth_parent_hash(start, n): finds the
    hash reached after walking n parent steps from start.

    A delta is a (from_hash, distance, to_hash) triple; a partial solution is
    a (distance_so_far, current_hash) pair. The extra query arguments arrive
    as a 1-tuple (n,) -- presumably packed by the skiplist base class (not
    shown here).
    '''
    def get_delta(self, element):
        # a single hop: element -> its parent, covering distance 1
        parent_hash = self.previous(element)
        return element, 1, parent_hash
    
    def combine_deltas(self, first, second):
        from_hash1, dist1, to_hash1 = first
        from_hash2, dist2, to_hash2 = second
        # the two hops must join end-to-start
        if to_hash1 != from_hash2:
            raise AssertionError()
        return from_hash1, dist1 + dist2, to_hash2
    
    def initial_solution(self, start, args):
        n, = args
        return 0, start
    
    def apply_delta(self, solution, delta, args):
        dist1, to_hash1 = solution
        from_hash2, dist2, to_hash2 = delta
        n, = args
        if to_hash1 != from_hash2:
            raise AssertionError()
        return dist1 + dist2, to_hash2
    
    def judge(self, solution, args):
        # 0: exactly n steps taken; 1: overshot; -1: still short
        dist, _hash = solution
        n, = args
        if dist == n:
            return 0
        return 1 if dist > n else -1
    
    def finalize(self, solution, args):
        dist, found_hash = solution
        n, = args
        assert dist == n
        return found_hash
51
def get_attributedelta_type(attrs): # attrs: {name: func}
    '''
    Build and return a delta class for the given additive attributes.

    An instance spans the chain from `head` down to `tail` and carries one
    summed value per attribute name in `attrs`, where attrs[name](item) is a
    single item's own contribution (an item exposes .hash and
    .previous_hash). Deltas concatenate with `+` (left.tail must equal
    right.head) and are trimmed with `-` (operands must share their head or
    their tail). The attrs mapping is exposed as the class attribute `attrs`.
    '''
    def _pairwise(left, right, op):
        # combine two deltas attribute-by-attribute, left operand first
        return dict((name, op(getattr(left, name), getattr(right, name))) for name in attrs)
    
    class ProtoAttributeDelta(object):
        __slots__ = ['head', 'tail'] + list(attrs)
        
        @classmethod
        def get_none(cls, element_id):
            # empty delta: starts and ends at element_id, every attribute zero
            return cls(element_id, element_id, **dict.fromkeys(attrs, 0))
        
        @classmethod
        def from_element(cls, item):
            # delta covering exactly one item: item -> its parent
            own = dict((name, func(item)) for name, func in attrs.items())
            return cls(item.hash, item.previous_hash, **own)
        
        def __init__(self, head, tail, **kwargs):
            self.head = head
            self.tail = tail
            for name, value in kwargs.items():
                setattr(self, name, value)
        
        def __add__(self, other):
            assert self.tail == other.head
            summed = _pairwise(self, other, lambda a, b: a + b)
            return self.__class__(self.head, other.tail, **summed)
        
        def __sub__(self, other):
            diffs = lambda: _pairwise(self, other, lambda a, b: a - b)
            if self.head == other.head:
                # shared head: what remains is other.tail -> self.tail
                return self.__class__(other.tail, self.tail, **diffs())
            if self.tail == other.tail:
                # shared tail: what remains is self.head -> other.head
                return self.__class__(self.head, other.head, **diffs())
            raise AssertionError()
        
        def __repr__(self):
            extras = ''.join(', %s=%r' % (name, getattr(self, name)) for name in attrs)
            return '%s(%r, %r%s)' % (self.__class__, self.head, self.tail, extras)
    
    ProtoAttributeDelta.attrs = attrs
    return ProtoAttributeDelta
85
# Default delta type for Tracker: each item contributes 1 to `height`, so a
# delta's height is the number of items it spans.
AttributeDelta = get_attributedelta_type({
    'height': lambda item: 1,
})
89
class Tracker(object):
    '''
    Forest of items linked child -> parent.

    Each item's own delta (delta_type.from_element(item)) runs from
    delta.head (the item's hash) to delta.tail (its parent's hash).
    Bookkeeping:
      items   -- hash -> item
      reverse -- parent hash (delta.tail) -> set of child hashes
      heads   -- hash of each item with no tracked child -> its chain's tail
      tails   -- hash just below the bottom of a chain (never itself in
                 items) -> set of head hashes descending to it

    Deltas from an item down to its chain's tail are cached in _deltas. Each
    cached entry is stored relative to a shared "ref" delta in _delta_refs
    (item -> ref.head, plus the ref's own ref.head -> ref.tail). When the
    bottom item of a chain is removed, only the shared ref delta needs
    adjusting rather than every cached entry.

    Events: `added` fires with the item at the end of add(); `removed` fires
    with the item at the end of remove().
    '''
    
    def __init__(self, items=(), delta_type=AttributeDelta, subset_of=None):
        '''
        items      -- iterable of initial items, passed to add() in order
                      (immutable default: it is only iterated)
        delta_type -- delta class such as one built by get_attributedelta_type;
                      must provide from_element(item) and get_none(hash)
        subset_of  -- optional Tracker this one is a subset of; items added or
                      removed here are asserted to be present in it, and its
                      get_nth_parent_hash skip list is shared
        '''
        self.items = {} # hash -> item
        self.reverse = {} # delta.tail -> set of item_hashes
        
        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes
        
        self._deltas = {} # item_hash -> delta, ref
        self._reverse_deltas = {} # ref -> set of item_hashes
        
        self._ref_generator = itertools.count()
        self._delta_refs = {} # ref -> delta
        self._reverse_delta_refs = {} # delta.tail -> ref
        
        self.added = variable.Event()
        self.removed = variable.Event()
        
        # a subset tracker reuses its superset's skip list instead of building its own
        if subset_of is None:
            self.get_nth_parent_hash = DistanceSkipList(self)
        else:
            self.get_nth_parent_hash = subset_of.get_nth_parent_hash
        
        self._delta_type = delta_type
        self._subset_of = subset_of
        
        for item in items:
            self.add(item)
    
    def add(self, item):
        '''
        Insert item, merging it into the chains above and/or below it.

        Raises ValueError if an item with the same hash is already tracked.
        Fires self.added with the item on success.
        '''
        assert not isinstance(item, (int, long, type(None)))
        delta = self._delta_type.from_element(item)
        if self._subset_of is not None:
            assert delta.head in self._subset_of.items
        
        if delta.head in self.items:
            raise ValueError('item already present')
        
        # if this item is the missing bottom of existing chains it inherits
        # their heads; otherwise it is a new head itself
        if delta.head in self.tails:
            heads = self.tails.pop(delta.head)
        else:
            heads = set([delta.head])
        
        # if the parent was a head, this item extends that chain; otherwise
        # find the tail below the parent (the parent's hash if it isn't tracked)
        if delta.tail in self.heads:
            tail = self.heads.pop(delta.tail)
        else:
            tail = self.get_last(delta.tail)
        
        self.items[delta.head] = item
        self.reverse.setdefault(delta.tail, set()).add(delta.head)
        
        self.tails.setdefault(tail, set()).update(heads)
        # the parent now has a child, so it can no longer be listed as a head
        if delta.tail in self.tails[tail]:
            self.tails[tail].remove(delta.tail)
        
        for head in heads:
            self.heads[head] = tail
        
        self.added.happened(item)
    
    def remove(self, item_hash):
        '''
        Remove the item whose hash is item_hash.

        Only items on the edge of the forest are supported: a head (an item
        with no tracked children), or the bottom item of a chain that is the
        only child of its parent hash. Any other item raises
        NotImplementedError, and the tracker is left unchanged.

        Raises KeyError(item_hash) if item_hash is not tracked. Fires
        self.removed with the item on success.
        '''
        assert isinstance(item_hash, (int, long, type(None)))
        if item_hash not in self.items:
            raise KeyError(item_hash)
        if self._subset_of is not None:
            assert item_hash in self._subset_of.items
        
        item = self.items[item_hash]
        del item_hash # delta.head is used from here on
        
        delta = self._delta_type.from_element(item)
        
        if delta.head in self.heads and delta.tail in self.tails:
            # item is a head sitting directly on a tail: a one-item chain
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if not self.tails[delta.tail]:
                self.tails.pop(delta.tail)
        elif delta.head in self.heads:
            # item is the tip of a longer chain; its parent is still tracked
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if self.reverse[delta.tail] != set([delta.head]):
                pass # has sibling
            else:
                # item was the parent's only child, so the parent becomes a head
                self.tails[tail].add(delta.tail)
                self.heads[delta.tail] = tail
        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
            # item is the bottom of its chain and the tail's only child: the
            # chain's tail moves up to this item's hash
            
            # move delta refs referencing children down to this, so they can be moved up in one step
            if delta.tail in self._reverse_delta_refs:
                for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
                    self.get_last(x)
                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, None), set()))
            
            heads = self.tails.pop(delta.tail)
            for head in heads:
                self.heads[head] = delta.head
            self.tails[delta.head] = set(heads)
            
            # move ref pointing to this up
            if delta.tail in self._reverse_delta_refs:
                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
                
                ref = self._reverse_delta_refs[delta.tail]
                cur_delta = self._delta_refs[ref]
                assert cur_delta.tail == delta.tail
                # subtracting the removed item's own delta re-targets the shared
                # ref so its tail becomes delta.head
                self._delta_refs[ref] = cur_delta - self._delta_type.from_element(item)
                assert self._delta_refs[ref].tail == delta.head
                del self._reverse_delta_refs[delta.tail]
                self._reverse_delta_refs[delta.head] = ref
        else:
            # interior items, and bottom items whose parent hash has other
            # children, are not supported
            raise NotImplementedError()
        
        # delete delta entry and ref if it is empty
        if delta.head in self._deltas:
            delta1, ref = self._deltas.pop(delta.head)
            self._reverse_deltas[ref].remove(delta.head)
            if not self._reverse_deltas[ref]:
                del self._reverse_deltas[ref]
                delta2 = self._delta_refs.pop(ref)
                del self._reverse_delta_refs[delta2.tail]
        
        self.items.pop(delta.head)
        self.reverse[delta.tail].remove(delta.head)
        if not self.reverse[delta.tail]:
            self.reverse.pop(delta.tail)
        
        self.removed.happened(item)
    
    def get_height(self, item_hash):
        '''Return the `height` attribute of the delta from item_hash to its chain's tail.'''
        return self.get_delta_to_last(item_hash).height
    
    def get_work(self, item_hash):
        '''
        Return the `work` attribute of the delta from item_hash to its chain's tail.

        Needs a delta_type that defines `work`; the default AttributeDelta
        (height only) raises AttributeError here.
        '''
        return self.get_delta_to_last(item_hash).work
    
    def get_last(self, item_hash):
        '''Return the first hash at or below item_hash that is not tracked.'''
        return self.get_delta_to_last(item_hash).tail
    
    def get_height_and_last(self, item_hash):
        '''Return (height, last hash) for item_hash with a single walk.'''
        delta = self.get_delta_to_last(item_hash)
        return delta.height, delta.tail
    
    def _get_delta(self, item_hash):
        '''Return item_hash's delta: cached (relative entry + shared ref) if present, else its own one-item delta.'''
        if item_hash in self._deltas:
            delta1, ref = self._deltas[item_hash]
            delta2 = self._delta_refs[ref]
            res = delta1 + delta2
        else:
            res = self._delta_type.from_element(self.items[item_hash])
        assert res.head == item_hash
        return res
    
    def _set_delta(self, item_hash, delta):
        '''
        Cache `delta` (which starts at item_hash) relative to the shared ref
        for delta.tail, creating that ref if needed. A previous ref used by
        item_hash is discarded once nothing else uses it.
        '''
        other_item_hash = delta.tail
        if other_item_hash not in self._reverse_delta_refs:
            ref = self._ref_generator.next()
            assert ref not in self._delta_refs
            self._delta_refs[ref] = self._delta_type.get_none(other_item_hash)
            self._reverse_delta_refs[other_item_hash] = ref
            del ref
        
        ref = self._reverse_delta_refs[other_item_hash]
        ref_delta = self._delta_refs[ref]
        assert ref_delta.tail == other_item_hash
        
        if item_hash in self._deltas:
            prev_ref = self._deltas[item_hash][1]
            self._reverse_deltas[prev_ref].remove(item_hash)
            if not self._reverse_deltas[prev_ref] and prev_ref != ref:
                self._reverse_deltas.pop(prev_ref)
                x = self._delta_refs.pop(prev_ref)
                self._reverse_delta_refs.pop(x.tail)
        # stored part runs item_hash -> ref_delta.head; _get_delta adds the ref back on
        self._deltas[item_hash] = delta - ref_delta, ref
        self._reverse_deltas.setdefault(ref, set()).add(item_hash)
    
    def get_delta_to_last(self, item_hash):
        '''
        Return the delta from item_hash down to the first untracked hash
        (an empty delta at item_hash if item_hash itself is untracked).

        Every tracked hash visited on the way gets its cached delta to that
        end point refreshed via _set_delta, so later walks from those hashes
        take a single cached step.
        '''
        assert isinstance(item_hash, (int, long, type(None)))
        delta = self._delta_type.get_none(item_hash)
        updates = []
        while delta.tail in self.items:
            updates.append((delta.tail, delta))
            this_delta = self._get_delta(delta.tail)
            delta += this_delta
        for update_hash, delta_then in updates:
            # full delta minus the part above update_hash = update_hash -> end
            self._set_delta(update_hash, delta - delta_then)
        return delta
    
    def get_delta(self, item, ancestor):
        '''Return the delta from hash `item` down to hash `ancestor` (asserted to be its ancestor).'''
        assert self.is_child_of(ancestor, item)
        return self.get_delta_to_last(item) - self.get_delta_to_last(ancestor)
    
    def get_chain(self, start_hash, length):
        '''Yield `length` items starting at start_hash and following parent links.'''
        assert length <= self.get_height(start_hash)
        for i in xrange(length):
            yield self.items[start_hash]
            start_hash = self._delta_type.from_element(self.items[start_hash]).tail
    
    def is_child_of(self, item_hash, possible_child_hash):
        '''
        Return whether possible_child_hash descends from item_hash.

        Returns None (undetermined) when the two hashes end at different
        chain tails, i.e. they are not connected in this tracker.
        '''
        height, last = self.get_height_and_last(item_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == item_hash