moved bitcoin.data.Tracker to util.forest.Tracker
[p2pool.git] / p2pool / util / forest.py
1 import itertools
2
3 from p2pool.util import skiplist, variable
4
5 from p2pool.bitcoin import data as bitcoin_data
6
class DistanceSkipList(skiplist.SkipList):
    '''
    Skip list over a tracker's previous_hash chains, used to find the n-th
    ancestor of a share: calling an instance as d(start_hash, n) is expected
    to return the hash n steps below start_hash.

    Deltas are (from_hash, distance, to_hash) triples; partial solutions are
    (distance, hash) pairs; the extra query arguments are the 1-tuple (n,).
    '''
    
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker
    
    def previous(self, element):
        return self.tracker.shares[element].previous_hash
    
    def get_delta(self, element):
        # one step from a share to its parent covers a distance of 1
        return element, 1, self.tracker.shares[element].previous_hash
    
    def combine_deltas(self, delta1, delta2):
        # Concatenate two jumps; the first must end where the second starts.
        # (Tuple parameters are unpacked in the body rather than in the
        # signature, which Python 3 no longer allows -- PEP 3113.)
        from_hash1, dist1, to_hash1 = delta1
        from_hash2, dist2, to_hash2 = delta2
        if to_hash1 != from_hash2:
            raise AssertionError('deltas are not adjacent: %r != %r' % (to_hash1, from_hash2))
        return from_hash1, dist1 + dist2, to_hash2
    
    def initial_solution(self, start, args):
        # args is the 1-tuple (n,); n is not needed to begin the walk
        (_n,) = args
        return 0, start
    
    def apply_delta(self, sol, delta, args):
        # advance a partial solution (distance walked, current hash) by one jump
        dist1, to_hash1 = sol
        from_hash2, dist2, to_hash2 = delta
        (_n,) = args
        if to_hash1 != from_hash2:
            raise AssertionError('delta does not start at the current hash: %r != %r' % (to_hash1, from_hash2))
        return dist1 + dist2, to_hash2
    
    def judge(self, sol, args):
        # 1 if the walk has gone past n, 0 if exactly at n, -1 if short of n
        dist, _share_hash = sol
        (n,) = args
        if dist > n:
            return 1
        elif dist == n:
            return 0
        else:
            return -1
    
    def finalize(self, sol):
        # the answer is the hash reached, not the distance
        _dist, share_hash = sol
        return share_hash
41
if __name__ == '__main__':
    # Self-check for DistanceSkipList: on the straight chain
    # None <- 0 <- 1 <- ... <- 1999, the b-th ancestor of share a is a - b.
    #
    # Tracker was moved out of p2pool.bitcoin.data into this module, and this
    # block runs before this module's own Tracker/FakeShare are defined, so
    # import the module under its package name to get them.
    import random
    from p2pool.util import forest
    t = forest.Tracker()
    d = DistanceSkipList(t)
    for i in xrange(2000):
        t.add(forest.FakeShare(hash=i, previous_hash=i - 1 if i > 0 else None))
    for i in xrange(2000):
        a = random.randrange(2000)
        b = random.randrange(a + 1)
        res = d(a, b)
        assert res == a - b, (a, b, res)
54
# Tracker: a forest of shares, each linked to its parent by previous_hash
56
class Tracker(object):
    '''
    Tracks a forest of shares, each linked to its parent by previous_hash.

    Structural indexes, maintained by add() and remove():
      shares          hash -> share
      reverse_shares  previous_hash -> set of hashes of shares pointing at it
      heads           head hash -> tail hash; a head is a share that no
                      tracked share points at
      tails           tail hash -> set of head hashes resting on it; a tail
                      is the first hash below a head that is not in shares

    Height cache, used by get_height_work_and_last() to jump over runs of
    shares instead of walking them one at a time:
      heights              share_hash -> (height_to, ref, work_inc)
      reverse_heights      ref -> set of share_hashes using that ref
      height_refs          ref -> (height, share_hash, work_inc)
      reverse_height_refs  share_hash -> ref
    A cached jump from share_hash lands on height_refs[ref]'s share_hash,
    covering height_to + height shares and work_inc + work_inc of work.
    Keeping the jump target behind a shared ref lets remove() retarget all
    jumps that use it in a single step.
    '''
    
    def __init__(self):
        self.shares = {} # hash -> share
        self.reverse_shares = {} # previous_hash -> set of share_hashes
        
        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes
        
        self.heights = {} # share_hash -> height_to, ref, work_inc
        self.reverse_heights = {} # ref -> set of share_hashes
        
        self.ref_generator = itertools.count()
        self.height_refs = {} # ref -> height, share_hash, work_inc
        self.reverse_height_refs = {} # share_hash -> ref
        
        # skip list answering "n-th ancestor of a share" queries on this tracker
        self.get_nth_parent_hash = DistanceSkipList(self)
        
        self.added = variable.Event()
        self.removed = variable.Event()
    
    def add(self, share):
        '''
        Insert share, joining it to chains already resting on its hash and to
        the chain below it.

        Raises ValueError if a share with the same hash is already tracked.
        Calls self.added.happened(share) once the indexes are updated.
        '''
        assert not isinstance(share, (int, long, type(None)))
        if share.hash in self.shares:
            raise ValueError('share already present')
        
        # heads that rested on this share's hash will rest on its tail instead;
        # if there are none, the new share is itself a head
        if share.hash in self.tails:
            heads = self.tails.pop(share.hash)
        else:
            heads = set([share.hash])
        
        # tail below the new share: reuse the parent's tail when the parent is
        # a head, otherwise walk down (before the new share is inserted)
        if share.previous_hash in self.heads:
            tail = self.heads.pop(share.previous_hash)
        else:
            tail = self.get_last(share.previous_hash)
        
        self.shares[share.hash] = share
        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
        
        self.tails.setdefault(tail, set()).update(heads)
        # the parent now has a child, so it is no longer a head
        if share.previous_hash in self.tails[tail]:
            self.tails[tail].remove(share.previous_hash)
        
        for head in heads:
            self.heads[head] = tail
        
        self.added.happened(share)
    
    def test(self):
        '''
        Consistency check: rebuild a fresh Tracker from the same shares and
        assert that its structural indexes equal this one's.
        '''
        t = Tracker()
        for s in self.shares.itervalues():
            t.add(s)
        
        assert self.shares == t.shares, (self.shares, t.shares)
        assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
        assert self.heads == t.heads, (self.heads, t.heads)
        assert self.tails == t.tails, (self.tails, t.tails)
    
    def remove(self, share_hash):
        '''
        Remove the share with hash share_hash and fix up heads, tails and the
        height cache.

        Raises KeyError if the share is not tracked. Raises NotImplementedError
        for removals the bookkeeping cannot express: a share that is neither a
        head nor resting directly on a tail, or a non-head share resting on a
        tail that has other children. Calls self.removed.happened(share) at
        the end.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        if share_hash not in self.shares:
            raise KeyError()
        
        share = self.shares[share_hash]
        del share_hash # only share.hash is used from here on
        
        # move height refs referencing children down to this, so they can be moved up in one step
        if share.previous_hash in self.reverse_height_refs:
            if share.previous_hash not in self.tails:
                for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.previous_hash, object()), set())):
                    self.get_last(x)
            for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set())):
                self.get_last(x)
            assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, None), set()))
        
        if share.hash in self.heads and share.previous_hash in self.tails:
            # a head whose parent is a tail: drop it from that tail's heads
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if not self.tails[share.previous_hash]:
                self.tails.pop(share.previous_hash)
        elif share.hash in self.heads:
            # a head whose parent is tracked: the parent becomes a head unless
            # it has other children
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if self.reverse_shares[share.previous_hash] != set([share.hash]):
                pass # has sibling
            else:
                self.tails[tail].add(share.previous_hash)
                self.heads[share.previous_hash] = tail
        elif share.previous_hash in self.tails:
            # bottom share of a chain: its own hash becomes the new tail
            heads = self.tails[share.previous_hash]
            if len(self.reverse_shares[share.previous_hash]) > 1:
                raise NotImplementedError()
            else:
                del self.tails[share.previous_hash]
                for head in heads:
                    self.heads[head] = share.hash
                self.tails[share.hash] = set(heads)
        else:
            # share in the middle of a chain
            raise NotImplementedError()
        
        # move ref pointing to this up: cached jumps that ended at the parent
        # now end at this share's hash, one share (and its work) shorter
        if share.previous_hash in self.reverse_height_refs:
            assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set()))
            
            ref = self.reverse_height_refs[share.previous_hash]
            cur_height, cur_hash, cur_work = self.height_refs[ref]
            assert cur_hash == share.previous_hash
            self.height_refs[ref] = cur_height - 1, share.hash, cur_work - bitcoin_data.target_to_average_attempts(share.target)
            del self.reverse_height_refs[share.previous_hash]
            self.reverse_height_refs[share.hash] = ref
        
        # delete height entry, and ref if it is empty
        if share.hash in self.heights:
            _, ref, _ = self.heights.pop(share.hash)
            self.reverse_heights[ref].remove(share.hash)
            if not self.reverse_heights[ref]:
                del self.reverse_heights[ref]
                _, ref_hash, _ = self.height_refs.pop(ref)
                del self.reverse_height_refs[ref_hash]
        
        self.shares.pop(share.hash)
        self.reverse_shares[share.previous_hash].remove(share.hash)
        if not self.reverse_shares[share.previous_hash]:
            self.reverse_shares.pop(share.previous_hash)
        
        # when debugging, `assert self.test() is None` here checks the indexes
        self.removed.happened(share)
    
    def get_height(self, share_hash):
        '''Number of tracked shares from share_hash down to its tail (0 if untracked).'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return height
    
    def get_work(self, share_hash):
        '''Summed target_to_average_attempts of the shares from share_hash down to its tail.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return work
    
    def get_last(self, share_hash):
        '''First hash at or below share_hash that is not in shares (its tail).'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return last
    
    def get_height_and_last(self, share_hash):
        '''(height, last) as returned by get_height() and get_last().'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return height, last
    
    def _get_height_jump(self, share_hash):
        '''
        Return (height_inc, reached_hash, work_inc) for one step down from
        share_hash: its cached jump if it has one, else a single step to its
        parent.
        '''
        if share_hash in self.heights:
            height_to1, ref, work_inc1 = self.heights[share_hash]
            height_to2, share_hash, work_inc2 = self.height_refs[ref]
            height_inc = height_to1 + height_to2
            work_inc = work_inc1 + work_inc2
        else:
            share = self.shares[share_hash]
            height_inc, share_hash, work_inc = 1, share.previous_hash, bitcoin_data.target_to_average_attempts(share.target)
        return height_inc, share_hash, work_inc
    
    def _set_height_jump(self, share_hash, height_inc, other_share_hash, work_inc):
        '''
        Cache that share_hash reaches other_share_hash after height_inc shares
        and work_inc work. The jump is stored relative to other_share_hash's
        ref (created here if needed); share_hash's previous ref is released if
        nothing else uses it.
        '''
        if other_share_hash not in self.reverse_height_refs:
            ref = next(self.ref_generator)
            assert ref not in self.height_refs
            self.height_refs[ref] = 0, other_share_hash, 0
            self.reverse_height_refs[other_share_hash] = ref
            del ref
        
        ref = self.reverse_height_refs[other_share_hash]
        ref_height_to, ref_share_hash, ref_work_inc = self.height_refs[ref]
        assert ref_share_hash == other_share_hash
        
        if share_hash in self.heights:
            prev_ref = self.heights[share_hash][1]
            self.reverse_heights[prev_ref].remove(share_hash)
            if not self.reverse_heights[prev_ref] and prev_ref != ref:
                self.reverse_heights.pop(prev_ref)
                _, x, _ = self.height_refs.pop(prev_ref)
                self.reverse_height_refs.pop(x)
        # store offsets relative to the ref so moving the ref moves this jump too
        self.heights[share_hash] = height_inc - ref_height_to, ref, work_inc - ref_work_inc
        self.reverse_heights.setdefault(ref, set()).add(share_hash)
    
    def get_height_work_and_last(self, share_hash):
        '''
        Walk down from share_hash to the first hash not in shares.

        Returns (height, work, last): the number of shares passed, the summed
        target_to_average_attempts(target) of those shares, and that first
        untracked hash. Every hash the walk stopped at gets a cached jump
        straight to last, so later walks are shorter.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        height = 0
        work = 0
        updates = []
        while share_hash in self.shares:
            updates.append((share_hash, height, work))
            height_inc, share_hash, work_inc = self._get_height_jump(share_hash)
            height += height_inc
            work += work_inc
        for update_hash, height_then, work_then in updates:
            self._set_height_jump(update_hash, height - height_then, share_hash, work - work_then)
        return height, work, share_hash
    
    def get_chain_known(self, start_hash):
        '''
        Chain starting with item of hash I{start_hash} of items that this Tracker contains

        Yields shares from start_hash downwards, stopping silently at the
        first hash that is not tracked.
        '''
        assert isinstance(start_hash, (int, long, type(None)))
        item_hash_to_get = start_hash
        while True:
            if item_hash_to_get not in self.shares:
                break
            share = self.shares[item_hash_to_get]
            assert not isinstance(share, long)
            yield share
            item_hash_to_get = share.previous_hash
    
    def get_chain_to_root(self, start_hash, root=None):
        '''
        Yield the shares from start_hash down to, but not including, root.

        Raises KeyError if a share on the way is not tracked.
        '''
        assert isinstance(start_hash, (int, long, type(None)))
        assert isinstance(root, (int, long, type(None)))
        share_hash_to_get = start_hash
        while share_hash_to_get != root:
            share = self.shares[share_hash_to_get]
            yield share
            share_hash_to_get = share.previous_hash
    
    def get_best_hash(self):
        '''
        Return the head with the greatest height (ties broken by comparing
        tail hashes), or None if nothing is tracked.
        '''
        if not self.heads:
            return None
        return max(self.heads, key=self.get_height_and_last)
    
    def get_highest_height(self):
        '''Greatest height among the heads, or 0 if there are none.'''
        return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0
    
    def is_child_of(self, share_hash, possible_child_hash):
        '''
        Return whether walking down from possible_child_hash reaches
        share_hash, or None when the two rest on different tails and so
        cannot be compared.
        '''
        height, last = self.get_height_and_last(share_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash
299
class FakeShare(object):
    '''
    Bare-bones stand-in for a share, for exercising Tracker without real
    share objects.

    Every keyword argument becomes an instance attribute, e.g.
    FakeShare(hash=1, previous_hash=0).
    '''
    
    def __init__(self, **kwargs):
        instance_attrs = vars(self)
        instance_attrs.update(kwargs)
303
304 if __name__ == '__main__':
305     
306     t = Tracker()
307     
308     for i in xrange(10000):
309         t.add(FakeShare(hash=i, previous_hash=i - 1 if i > 0 else None))
310     
311     #t.remove(99)
312     
313     print 'HEADS', t.heads
314     print 'TAILS', t.tails
315     
316     import random
317     
318     while False:
319         print
320         print '-'*30
321         print
322         t = Tracker()
323         for i in xrange(random.randrange(100)):
324             x = random.choice(list(t.shares) + [None])
325             print i, '->', x
326             t.add(FakeShare(i, x))
327         while t.shares:
328             x = random.choice(list(t.shares))
329             print 'DEL', x, t.__dict__
330             try:
331                 t.remove(x)
332             except NotImplementedError:
333                 print 'aborted; not implemented'
334         import time
335         time.sleep(.1)
336         print 'HEADS', t.heads
337         print 'TAILS', t.tails
338     
339     #for share_hash, share in sorted(t.shares.iteritems()):
340     #    print share_hash, share.previous_hash, t.heads.get(share_hash), t.tails.get(share_hash)
341     
342     #import sys;sys.exit()
343     
344     print t.get_nth_parent_hash(9000, 5000)
345     print t.get_nth_parent_hash(9001, 412)
346     #print t.get_nth_parent_hash(90, 51)
347     
348     for share_hash in sorted(t.shares):
349         print str(share_hash).rjust(4),
350         x = t.skips.get(share_hash, None)
351         if x is not None:
352             print str(x[0]).rjust(4),
353             for a in x[1]:
354                 print str(a).rjust(10),
355         print