32e6fff9232849b417dcfe19fe761e3a8d18d2e0
[p2pool.git] / p2pool / util / forest.py
1 '''
2 forest data structure
3 '''
4
5 import itertools
6 import weakref
7
8 from p2pool.util import skiplist, variable
9
10
class TrackerSkipList(skiplist.SkipList):
    '''
    Skip list whose elements are share hashes held by a Tracker.

    When the tracker fires its ``removed`` event for a share, the skip list
    calls ``forget_item(share.hash)`` (inherited from skiplist.SkipList) so
    it does not keep results for a share that no longer exists.

    The subscription captures only a weak reference to the skip list, so
    the tracker's event does not keep it alive. When the skip list is
    garbage collected, the weakref callback unsubscribes it.
    '''
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker

        def on_removed(share):
            # The weakref can already be dead before its callback below has
            # had a chance to unwatch us; in that case there is nothing to
            # forget, so ignore the event instead of calling None.forget_item.
            obj = self_ref()
            if obj is not None:
                obj.forget_item(share.hash)

        # Neither closure references `self`, only `self_ref`, `tracker` and
        # `watch_id` (which is bound before the callback can ever fire).
        self_ref = weakref.ref(self, lambda _: tracker.removed.unwatch(watch_id))
        watch_id = self.tracker.removed.watch(on_removed)

    def previous(self, element):
        '''Return the tail of *element*'s one-step delta (its parent hash).'''
        return self.tracker.delta_type.from_element(self.tracker.shares[element]).tail
21
22
class DistanceSkipList(TrackerSkipList):
    '''
    Skip list that finds the hash exactly ``n`` steps down from a start hash.

    A delta is a ``(from_hash, distance, to_hash)`` triple and a partial
    solution is a ``(distance_so_far, current_hash)`` pair. The extra
    query arguments are the 1-tuple ``(n,)``. These methods are the hooks
    that skiplist.SkipList presumably calls while searching.
    '''
    def get_delta(self, element):
        '''Single step: from *element* to its parent, distance 1.'''
        return element, 1, self.previous(element)

    def combine_deltas(self, first, second):
        '''Join two adjacent deltas into one spanning both.'''
        from_hash1, dist1, to_hash1 = first
        from_hash2, dist2, to_hash2 = second
        if to_hash1 != from_hash2:
            raise AssertionError()
        return from_hash1, dist1 + dist2, to_hash2

    def initial_solution(self, start, args):
        '''Zero distance travelled, positioned at *start*.'''
        (n,) = args
        return 0, start

    def apply_delta(self, solution, delta, args):
        '''Advance a partial solution by a delta that starts where it ends.'''
        dist1, to_hash1 = solution
        from_hash2, dist2, to_hash2 = delta
        (n,) = args
        if to_hash1 != from_hash2:
            raise AssertionError()
        return dist1 + dist2, to_hash2

    def judge(self, solution, args):
        '''Return 1 if *solution* overshoots n, 0 if exact, -1 if short.'''
        dist, _ = solution
        (n,) = args
        if dist > n:
            return 1
        if dist == n:
            return 0
        return -1

    def finalize(self, solution, args):
        '''Return the hash reached by an exact solution.'''
        dist, reached_hash = solution
        (n,) = args
        assert dist == n
        return reached_hash
51
def get_attributedelta_type(attrs): # attrs: {name: func}
    '''
    Build a delta class that accumulates the attributes in *attrs*.

    *attrs* maps an attribute name to a function that takes a share and
    returns that share's contribution to the attribute. An instance of the
    returned class describes a walk from ``head`` to ``tail`` together with
    the summed attribute values:

    - ``a + b`` concatenates adjacent walks (``a.tail == b.head``);
    - ``a - b`` removes a common prefix (same head) or suffix (same tail).

    The mapping is also exposed as the class attribute ``attrs``.
    '''
    class ProtoAttributeDelta(object):
        # list(attrs) instead of attrs.keys(): identical on Python 2, and
        # still a list where keys() returns a view.
        __slots__ = ['head', 'tail'] + list(attrs)

        @classmethod
        def get_none(cls, element_id):
            '''Return the empty delta at element_id (head == tail, all attributes 0).'''
            return cls(element_id, element_id, **dict((k, 0) for k in attrs))

        @classmethod
        def from_element(cls, share):
            '''Return the one-step delta from share.hash to share.previous_hash.'''
            return cls(share.hash, share.previous_hash, **dict((k, v(share)) for k, v in attrs.items()))

        def __init__(self, head, tail, **kwargs):
            self.head, self.tail = head, tail
            for k, v in kwargs.items():
                setattr(self, k, v)

        def __add__(self, other):
            '''Concatenate: self must end where other begins.'''
            assert self.tail == other.head
            return self.__class__(self.head, other.tail, **dict((k, getattr(self, k) + getattr(other, k)) for k in attrs))

        def __sub__(self, other):
            '''
            Remove the part of self covered by other.

            If both start at the same head, the result runs from other.tail
            to self.tail. If they end at the same tail, it runs from
            self.head to other.head. Otherwise AssertionError is raised.
            '''
            if self.head == other.head:
                return self.__class__(other.tail, self.tail, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
            elif self.tail == other.tail:
                return self.__class__(self.head, other.head, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
            else:
                raise AssertionError()

        def __repr__(self):
            # Use the class *name*; formatting the class object itself
            # produced "<class '...'>(...)".
            return '%s(%r, %r%s)' % (self.__class__.__name__, self.head, self.tail, ''.join(', %s=%r' % (k, getattr(self, k)) for k in attrs))
    ProtoAttributeDelta.attrs = attrs
    return ProtoAttributeDelta
85
# Default delta type: every share contributes 1 to "height", so a delta's
# height counts the shares walked from its head to its tail.
AttributeDelta = get_attributedelta_type({
    'height': lambda item: 1,
})
89
class Tracker(object):
    '''
    Forest of shares linked by delta head -> tail (parent) relations.

    Structures maintained:

    - shares / reverse_shares: hash -> share, and for each tail hash the set
      of share hashes whose delta points at it (its children);
    - heads / tails: each chain tip ("head", a share nobody points at) maps
      to the hash its chain ends at ("tail", a hash not in the tracker), and
      each tail maps back to the set of heads above it;
    - deltas / delta_refs (+ reverse maps): cache used by get_delta. A cached
      entry stores a share's delta up to a shared reference point ``ref``,
      and delta_refs[ref] holds the rest of the walk, so moving a chain's
      tail only requires updating one ref.

    ``added`` fires after a share is inserted; ``removed`` fires after a
    share has been deleted from all structures.
    '''
    def __init__(self, shares=(), delta_type=AttributeDelta):
        '''
        shares: iterable of initial shares, inserted in order via add().
            An immutable default avoids a shared mutable default argument.
        delta_type: delta class, e.g. one built by get_attributedelta_type.
        '''
        self.shares = {} # hash -> share
        self.reverse_shares = {} # delta.tail -> set of share_hashes

        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes

        self.deltas = {} # share_hash -> delta, ref
        self.reverse_deltas = {} # ref -> set of share_hashes

        self.ref_generator = itertools.count()
        self.delta_refs = {} # ref -> delta
        self.reverse_delta_refs = {} # delta.tail -> ref

        self.added = variable.Event()
        self.removed = variable.Event()

        # Set before creating the skip list, whose hooks read delta_type.
        self.delta_type = delta_type

        self.get_nth_parent_hash = DistanceSkipList(self)

        for share in shares:
            self.add(share)

    def add(self, share):
        '''
        Insert *share* and update the head/tail bookkeeping.

        Raises ValueError if a share with the same hash is already present.
        '''
        assert not isinstance(share, (int, long, type(None)))
        delta = self.delta_type.from_element(share)

        if delta.head in self.shares:
            raise ValueError('share already present')

        # Heads that were waiting on this hash (it was their tail) now sit
        # above this share; otherwise this share is itself a new head.
        if delta.head in self.tails:
            heads = self.tails.pop(delta.head)
        else:
            heads = set([delta.head])

        # If the parent was a head it stops being one; its tail is ours.
        if delta.tail in self.heads:
            tail = self.heads.pop(delta.tail)
        else:
            tail = self.get_last(delta.tail)

        self.shares[delta.head] = share
        self.reverse_shares.setdefault(delta.tail, set()).add(delta.head)

        self.tails.setdefault(tail, set()).update(heads)
        if delta.tail in self.tails[tail]:
            self.tails[tail].remove(delta.tail)

        for head in heads:
            self.heads[head] = tail

        self.added.happened(share)

    def remove(self, share_hash):
        '''
        Remove the share with hash *share_hash*.

        Only edge removals are supported: the share must be a head (chain
        tip), or the bottom share of its chain with no siblings.

        Raises KeyError if the share is not present, and
        NotImplementedError for removals from the middle of a chain.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        if share_hash not in self.shares:
            raise KeyError(share_hash)

        share = self.shares[share_hash]
        del share_hash

        delta = self.delta_type.from_element(share)

        if delta.head in self.heads and delta.tail in self.tails:
            # Lone share: both a head and directly on top of a tail.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if not self.tails[delta.tail]:
                self.tails.pop(delta.tail)
        elif delta.head in self.heads:
            # Chain tip: the parent becomes a head unless it has other children.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if self.reverse_shares[delta.tail] != set([delta.head]):
                pass # has sibling
            else:
                self.tails[tail].add(delta.tail)
                self.heads[delta.tail] = tail
        elif delta.tail in self.tails and len(self.reverse_shares[delta.tail]) <= 1:
            # Bottom of a chain: this share's hash becomes the new tail.
            # move delta refs referencing children down to this, so they can be moved up in one step
            if delta.tail in self.reverse_delta_refs:
                for x in list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set())):
                    self.get_last(x)
                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, None), set()))

            heads = self.tails.pop(delta.tail)
            for head in heads:
                self.heads[head] = delta.head
            self.tails[delta.head] = set(heads)

            # move ref pointing to this up
            if delta.tail in self.reverse_delta_refs:
                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set()))

                ref = self.reverse_delta_refs[delta.tail]
                cur_delta = self.delta_refs[ref]
                assert cur_delta.tail == delta.tail
                self.delta_refs[ref] = cur_delta - self.delta_type.from_element(share)
                assert self.delta_refs[ref].tail == delta.head
                del self.reverse_delta_refs[delta.tail]
                self.reverse_delta_refs[delta.head] = ref
        else:
            raise NotImplementedError()

        # delete delta entry and ref if it is empty
        if delta.head in self.deltas:
            delta1, ref = self.deltas.pop(delta.head)
            self.reverse_deltas[ref].remove(delta.head)
            if not self.reverse_deltas[ref]:
                del self.reverse_deltas[ref]
                delta2 = self.delta_refs.pop(ref)
                del self.reverse_delta_refs[delta2.tail]

        self.shares.pop(delta.head)
        self.reverse_shares[delta.tail].remove(delta.head)
        if not self.reverse_shares[delta.tail]:
            self.reverse_shares.pop(delta.tail)

        self.removed.happened(share)

    def get_height(self, share_hash):
        '''Return the ``height`` of the walk from share_hash to its chain's end.'''
        return self.get_delta(share_hash).height

    def get_work(self, share_hash):
        '''
        Return the accumulated ``work`` down to the chain's end.

        Requires a delta_type with a ``work`` attribute; the default
        AttributeDelta only defines ``height``.
        '''
        return self.get_delta(share_hash).work

    def get_last(self, share_hash):
        '''Return the first hash below share_hash that is not in the tracker.'''
        return self.get_delta(share_hash).tail

    def get_height_and_last(self, share_hash):
        '''Return ``(height, last)`` for share_hash's chain.'''
        delta = self.get_delta(share_hash)
        return delta.height, delta.tail

    def get_height_work_and_last(self, share_hash):
        '''Return ``(height, work, last)``; requires a ``work`` attribute (see get_work).'''
        delta = self.get_delta(share_hash)
        return delta.height, delta.work, delta.tail

    def _get_delta(self, share_hash):
        # Cached delta (relative part + shared ref part) if present,
        # otherwise the share's own one-step delta.
        if share_hash in self.deltas:
            delta1, ref = self.deltas[share_hash]
            delta2 = self.delta_refs[ref]
            res = delta1 + delta2
        else:
            res = self.delta_type.from_element(self.shares[share_hash])
        assert res.head == share_hash
        return res

    def _set_delta(self, share_hash, delta):
        # Cache `delta` (which must end at delta.tail) for share_hash,
        # storing it relative to the ref keyed by delta.tail.
        other_share_hash = delta.tail
        if other_share_hash not in self.reverse_delta_refs:
            ref = next(self.ref_generator)
            assert ref not in self.delta_refs
            self.delta_refs[ref] = self.delta_type.get_none(other_share_hash)
            self.reverse_delta_refs[other_share_hash] = ref
            del ref

        ref = self.reverse_delta_refs[other_share_hash]
        ref_delta = self.delta_refs[ref]
        assert ref_delta.tail == other_share_hash

        # Detach from the previous ref, dropping it if nothing else uses it.
        if share_hash in self.deltas:
            prev_ref = self.deltas[share_hash][1]
            self.reverse_deltas[prev_ref].remove(share_hash)
            if not self.reverse_deltas[prev_ref] and prev_ref != ref:
                self.reverse_deltas.pop(prev_ref)
                x = self.delta_refs.pop(prev_ref)
                self.reverse_delta_refs.pop(x.tail)
        self.deltas[share_hash] = delta - ref_delta, ref
        self.reverse_deltas.setdefault(ref, set()).add(share_hash)

    def get_delta(self, share_hash):
        '''
        Return the accumulated delta from share_hash down to the first hash
        not in the tracker, caching it for every share visited on the way.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        delta = self.delta_type.get_none(share_hash)
        updates = []
        while delta.tail in self.shares:
            updates.append((delta.tail, delta))
            this_delta = self._get_delta(delta.tail)
            delta += this_delta
        for update_hash, delta_then in updates:
            self._set_delta(update_hash, delta - delta_then)
        return delta

    def get_chain(self, start_hash, length):
        '''
        Yield *length* shares, starting at start_hash and following tails.

        The length check is an assertion evaluated on the first iteration,
        since this is a generator.
        '''
        assert length <= self.get_height(start_hash)
        for i in xrange(length):
            yield self.shares[start_hash]
            start_hash = self.delta_type.from_element(self.shares[start_hash]).tail

    def is_child_of(self, share_hash, possible_child_hash):
        '''
        Return whether possible_child_hash lies share_hash's height
        difference above share_hash on the same chain.

        Returns None when the two chains end at different tails, because
        the relation cannot be determined.
        '''
        height, last = self.get_height_and_last(share_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash