cleaned up AttributeDelta handling
[p2pool.git] / p2pool / util / forest.py
1 '''
2 forest data structure
3 '''
4
5 import itertools
6 import weakref
7
8 from p2pool.util import skiplist, variable
9 from p2pool.bitcoin import data as bitcoin_data
10
11
class TrackerSkipList(skiplist.SkipList):
    '''
    SkipList whose elements are share hashes held by a Tracker.

    The predecessor of an element is the tail of the tracker's delta for that
    share. Whenever the tracker fires its `removed` event, the cached entry
    for the removed share's hash is forgotten.
    '''
    
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker
        
        # The handler registered on tracker.removed reaches this object only
        # through a weak reference, so the subscription does not by itself
        # keep the skip list alive; when the skip list is finalized, the
        # weakref callback unregisters the handler.
        def unsubscribe(_dead_ref):
            tracker.removed.unwatch(watch_id)
        me = weakref.ref(self, unsubscribe)
        
        def on_removed(share):
            me().forget_item(share.hash)
        watch_id = self.tracker.removed.watch(on_removed)
    
    def previous(self, element):
        '''Return the hash preceding share `element`, as given by the tracker's delta_type.'''
        share = self.tracker.shares[element]
        single_step = self.tracker.delta_type.from_element(share)
        return single_step.tail
22
23
class DistanceSkipList(TrackerSkipList):
    '''
    Skip list that counts steps back along a tracker's share chain.

    Tracker instantiates this as its `get_nth_parent_hash`. The search
    arguments are a one-element tuple (n,), and `finalize` returns the hash
    reached after exactly n single-share steps back from the start.

    Deltas are (from_hash, distance, to_hash) triples; partial solutions are
    (distance, hash) pairs. These hooks are presumably driven by
    skiplist.SkipList (not shown here).
    '''
    
    def get_delta(self, element):
        # One step: element -> its predecessor, covering a distance of 1.
        return element, 1, self.previous(element)
    
    def combine_deltas(self, first, second):
        start, first_dist, first_end = first
        second_start, second_dist, end = second
        if first_end != second_start:
            raise AssertionError()
        return start, first_dist + second_dist, end
    
    def initial_solution(self, start, args):
        (_,) = args
        return 0, start
    
    def apply_delta(self, solution, delta, args):
        dist_so_far, at_hash = solution
        delta_start, delta_dist, delta_end = delta
        (_,) = args
        if at_hash != delta_start:
            raise AssertionError()
        return dist_so_far + delta_dist, delta_end
    
    def judge(self, solution, args):
        # 1 = overshot n, 0 = exactly n, -1 = still short of n.
        dist, _ = solution
        (n,) = args
        if dist > n:
            return 1
        return 0 if dist == n else -1
    
    def finalize(self, solution, args):
        dist, at_hash = solution
        (n,) = args
        assert dist == n
        return at_hash
52
def get_attributedelta_type(attrs): # attrs: {name: func}
    '''
    Build a delta class that sums per-share attributes along a chain.

    attrs maps attribute names to functions computing that attribute's value
    for a single share. An instance spans from `head` back to `tail` (the hash
    just past the last counted share) and holds, for each name in attrs, the
    total over the spanned shares. Attribute values must support + and -,
    and 0 is used as the empty value (see get_none).

    Returns the generated class; the original attrs mapping is exposed on it
    as `.attrs`.
    '''
    attr_names = list(attrs)
    
    class ProtoAttributeDelta(object):
        __slots__ = ['head', 'tail'] + attr_names
        
        @classmethod
        def get_none(cls, element_id):
            '''Empty delta at element_id: head == tail and every attribute is 0.'''
            return cls(element_id, element_id, **dict((k, 0) for k in attr_names))
        
        @classmethod
        def from_element(cls, share):
            '''Delta covering exactly one share: share.hash -> share.previous_hash.'''
            return cls(share.hash, share.previous_hash, **dict((k, v(share)) for k, v in attrs.items()))
        
        def __init__(self, head, tail, **kwargs):
            self.head, self.tail = head, tail
            for k, v in kwargs.items():
                setattr(self, k, v)
        
        def __add__(self, other):
            '''
            Concatenate two adjacent deltas (self then other).

            Raises AssertionError if self.tail != other.head.
            '''
            # Explicit raise rather than `assert` so the adjacency check is not
            # stripped under `python -O` (matches __sub__ and DistanceSkipList).
            if self.tail != other.head:
                raise AssertionError('cannot add non-adjacent deltas: %r + %r' % (self, other))
            return self.__class__(self.head, other.tail, **dict((k, getattr(self, k) + getattr(other, k)) for k in attr_names))
        
        def __sub__(self, other):
            '''
            Remove a sub-span that shares one endpoint with self.

            If the heads match, the result spans other.tail -> self.tail; if
            the tails match, it spans self.head -> other.head. Otherwise
            raises AssertionError.
            '''
            if self.head == other.head:
                head, tail = other.tail, self.tail
            elif self.tail == other.tail:
                head, tail = self.head, other.head
            else:
                raise AssertionError('cannot subtract deltas without a shared endpoint: %r - %r' % (self, other))
            return self.__class__(head, tail, **dict((k, getattr(self, k) - getattr(other, k)) for k in attr_names))
        
        def __repr__(self):
            # Use the class *name*; formatting the class object itself yields
            # "<class '...'>(...)".
            return '%s(%r, %r%s)' % (self.__class__.__name__, self.head, self.tail, ''.join(', %s=%r' % (k, getattr(self, k)) for k in attr_names))
    ProtoAttributeDelta.attrs = attrs
    return ProtoAttributeDelta
86
# Default delta type used by Tracker: every share contributes 1 to `height`
# and bitcoin_data.target_to_average_attempts(share.target) to `work`, so a
# delta's totals are summed over the shares it spans.
AttributeDelta = get_attributedelta_type({
    'height': lambda share: 1,
    'work': lambda share: bitcoin_data.target_to_average_attempts(share.target),
})
91
class Tracker(object):
    '''
    Forest of shares linked by their previous hashes, with cached chain deltas.

    Shares are grouped into chains:
      * `heads` maps every head (a present share with no present child) to
        its chain's tail;
      * `tails` maps every tail (the first ancestor hash that is NOT present
        in the tracker) to the set of heads above it.

    Delta caching: the delta from a share back to its tail is stored as
    `deltas[share_hash] = (relative_delta, ref)`, and the full delta is
    `relative_delta + delta_refs[ref]`. All shares that end at the same tail
    share one ref (`reverse_delta_refs[tail] -> ref`), so when the bottom share
    of a chain is removed only that single ref delta has to be rebased.

    Events: `added` fires with the share after add(); `removed` fires with the
    share after remove().
    '''
    
    def __init__(self, shares=(), delta_type=AttributeDelta):
        # shares: iterable of shares added in order via add(). An immutable
        # default avoids the shared-mutable-default pitfall.
        # delta_type: class with from_element/get_none and +/- (see
        # get_attributedelta_type).
        self.shares = {} # hash -> share
        self.reverse_shares = {} # delta.tail (previous hash) -> set of child share hashes
        
        self.heads = {} # head hash -> tail hash
        self.tails = {} # tail hash -> set of head hashes
        
        self.deltas = {} # share_hash -> (delta relative to ref, ref)
        self.reverse_deltas = {} # ref -> set of share_hashes using it
        
        self.ref_generator = itertools.count()
        self.delta_refs = {} # ref -> shared delta ending at a tail
        self.reverse_delta_refs = {} # delta.tail -> ref
        
        self.added = variable.Event()
        self.removed = variable.Event()
        
        self.get_nth_parent_hash = DistanceSkipList(self)
        
        self.delta_type = delta_type
        
        for share in shares:
            self.add(share)
    
    def add(self, share):
        '''
        Insert share into the forest, then fire self.added.

        Raises ValueError if a share with the same hash is already present.
        '''
        assert not isinstance(share, (int, long, type(None)))
        delta = self.delta_type.from_element(share)
        
        if delta.head in self.shares:
            raise ValueError('share already present')
        
        # Chains that were waiting on this share (it was their missing tail)
        # now continue through it; otherwise the share is a new head itself.
        if delta.head in self.tails:
            heads = self.tails.pop(delta.head)
        else:
            heads = set([delta.head])
        
        # If the parent was a head, this share extends that chain and inherits
        # its tail; otherwise the tail is found by walking back from the parent.
        if delta.tail in self.heads:
            tail = self.heads.pop(delta.tail)
        else:
            tail = self.get_last(delta.tail)
        
        self.shares[delta.head] = share
        self.reverse_shares.setdefault(delta.tail, set()).add(delta.head)
        
        self.tails.setdefault(tail, set()).update(heads)
        # The parent, if it was listed as a head under this tail, no longer is.
        if delta.tail in self.tails[tail]:
            self.tails[tail].remove(delta.tail)
        
        for head in heads:
            self.heads[head] = tail
        
        self.added.happened(share)
    
    def remove(self, share_hash):
        '''
        Remove the share with hash share_hash, then fire self.removed.

        Supported cases: the share is a chain head, or it is the bottom share
        of a chain whose missing parent has no other present child. Any other
        position raises NotImplementedError. Raises KeyError if share_hash is
        not present.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        if share_hash not in self.shares:
            raise KeyError()
        
        share = self.shares[share_hash]
        del share_hash
        
        delta = self.delta_type.from_element(share)
        
        if delta.head in self.heads and delta.tail in self.tails:
            # Head whose parent is missing: the share forms a chain on its own.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if not self.tails[delta.tail]:
                self.tails.pop(delta.tail)
        elif delta.head in self.heads:
            # Top of a longer chain: the parent becomes a head unless it still
            # has another child.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if self.reverse_shares[delta.tail] != set([delta.head]):
                pass # has sibling
            else:
                self.tails[tail].add(delta.tail)
                self.heads[delta.tail] = tail
        elif delta.tail in self.tails and len(self.reverse_shares[delta.tail]) <= 1:
            # Bottom of a chain: this share's own hash becomes the new tail.
            # move delta refs referencing children down to this, so they can be moved up in one step
            if delta.tail in self.reverse_delta_refs:
                for x in list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set())):
                    self.get_last(x)
                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, None), set()))
            
            heads = self.tails.pop(delta.tail)
            for head in heads:
                self.heads[head] = delta.head
            self.tails[delta.head] = set(heads)
            
            # move ref pointing to this up
            if delta.tail in self.reverse_delta_refs:
                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set()))
                
                ref = self.reverse_delta_refs[delta.tail]
                cur_delta = self.delta_refs[ref]
                assert cur_delta.tail == delta.tail
                # Rebase the shared ref delta so it ends at this share's hash;
                # every cached delta using this ref is corrected in one step.
                self.delta_refs[ref] = cur_delta - self.delta_type.from_element(share)
                assert self.delta_refs[ref].tail == delta.head
                del self.reverse_delta_refs[delta.tail]
                self.reverse_delta_refs[delta.head] = ref
        else:
            raise NotImplementedError()
        
        # delete delta entry and ref if it is empty
        if delta.head in self.deltas:
            delta1, ref = self.deltas.pop(delta.head)
            self.reverse_deltas[ref].remove(delta.head)
            if not self.reverse_deltas[ref]:
                del self.reverse_deltas[ref]
                delta2 = self.delta_refs.pop(ref)
                del self.reverse_delta_refs[delta2.tail]
        
        self.shares.pop(delta.head)
        self.reverse_shares[delta.tail].remove(delta.head)
        if not self.reverse_shares[delta.tail]:
            self.reverse_shares.pop(delta.tail)
        
        self.removed.happened(share)
    
    def get_height(self, share_hash):
        '''Return the `height` total of the delta from share_hash back to its tail.'''
        return self.get_delta(share_hash).height
    
    def get_work(self, share_hash):
        '''Return the `work` total of the delta from share_hash back to its tail.'''
        return self.get_delta(share_hash).work
    
    def get_last(self, share_hash):
        '''
        Return the first hash, walking back from share_hash, that is not present.

        This is share_hash itself when share_hash is not in the tracker.
        '''
        return self.get_delta(share_hash).tail
    
    def get_height_and_last(self, share_hash):
        '''Return (height, last) for share_hash using a single get_delta walk.'''
        delta = self.get_delta(share_hash)
        return delta.height, delta.tail
    
    def get_height_work_and_last(self, share_hash):
        '''Return (height, work, last) for share_hash using a single get_delta walk.'''
        delta = self.get_delta(share_hash)
        return delta.height, delta.work, delta.tail
    
    def _get_delta(self, share_hash):
        # Cached delta for a present share (relative part + shared ref part),
        # or the single-share delta when nothing has been cached yet.
        if share_hash in self.deltas:
            delta1, ref = self.deltas[share_hash]
            delta2 = self.delta_refs[ref]
            res = delta1 + delta2
        else:
            res = self.delta_type.from_element(self.shares[share_hash])
        assert res.head == share_hash
        return res
    
    def _set_delta(self, share_hash, delta):
        # Cache `delta` for share_hash relative to the ref for delta.tail,
        # creating that ref (an empty delta at the tail) if needed and
        # releasing the share's previous ref if nothing else uses it.
        other_share_hash = delta.tail
        if other_share_hash not in self.reverse_delta_refs:
            ref = self.ref_generator.next()
            assert ref not in self.delta_refs
            self.delta_refs[ref] = self.delta_type.get_none(other_share_hash)
            self.reverse_delta_refs[other_share_hash] = ref
            del ref
        
        ref = self.reverse_delta_refs[other_share_hash]
        ref_delta = self.delta_refs[ref]
        assert ref_delta.tail == other_share_hash
        
        if share_hash in self.deltas:
            prev_ref = self.deltas[share_hash][1]
            self.reverse_deltas[prev_ref].remove(share_hash)
            if not self.reverse_deltas[prev_ref] and prev_ref != ref:
                self.reverse_deltas.pop(prev_ref)
                x = self.delta_refs.pop(prev_ref)
                self.reverse_delta_refs.pop(x.tail)
        self.deltas[share_hash] = delta - ref_delta, ref
        self.reverse_deltas.setdefault(ref, set()).add(share_hash)
    
    def get_delta(self, share_hash):
        '''
        Return the delta from share_hash back to the first hash not present.

        Walks the chain using cached deltas, then caches the final delta for
        every share visited so later walks take fewer steps.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        delta = self.delta_type.get_none(share_hash)
        updates = []
        while delta.tail in self.shares:
            updates.append((delta.tail, delta))
            this_delta = self._get_delta(delta.tail)
            delta += this_delta
        for update_hash, delta_then in updates:
            # delta - delta_then spans update_hash -> final tail.
            self._set_delta(update_hash, delta - delta_then)
        return delta
    
    def get_chain(self, start_hash, length):
        '''Yield `length` shares starting at start_hash and following previous links.'''
        assert length <= self.get_height(start_hash)
        for i in xrange(length):
            yield self.shares[start_hash]
            start_hash = self.delta_type.from_element(self.shares[start_hash]).tail
    
    def is_child_of(self, share_hash, possible_child_hash):
        '''
        Return whether possible_child_hash descends from share_hash.

        Returns None when the two hashes end at different tails (not connected
        within the tracker), since the answer cannot be determined.
        '''
        height, last = self.get_height_and_last(share_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash