fixed bug in Tracker.remove found by unit tests
[p2pool.git] / p2pool / util / forest.py
1 import itertools
2
3 from p2pool.util import skiplist, variable
4
5 from p2pool.bitcoin import data as bitcoin_data
6
class DistanceSkipList(skiplist.SkipList):
    '''
    Skip list that walks a fixed number of C{previous_hash} links through the
    shares of a Tracker.

    Representations used by the skip-list callbacks below:
      - a delta is a (from_hash, distance, to_hash) triple;
      - a partial solution is a (distance_so_far, current_hash) pair;
      - the lookup arguments are a 1-tuple (n,) giving the distance to walk.
    '''

    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker

    def previous(self, element):
        # parent hash of the share identified by element
        share = self.tracker.shares[element]
        return share.previous_hash

    def get_delta(self, element):
        # a single step: from element to its parent, distance 1
        parent_hash = self.tracker.shares[element].previous_hash
        return element, 1, parent_hash

    def combine_deltas(self, first, second):
        # concatenate two adjacent deltas; they must meet at the same hash
        start_hash, first_dist, first_end = first
        second_start, second_dist, end_hash = second
        if first_end != second_start:
            raise AssertionError()
        return start_hash, first_dist + second_dist, end_hash

    def initial_solution(self, start, args):
        (n,) = args # n is unused here; unpacking keeps the 1-tuple arity check
        return 0, start

    def apply_delta(self, solution, delta, args):
        # advance a partial solution by a delta that starts where it ends
        dist_so_far, current_hash = solution
        delta_start, delta_dist, delta_end = delta
        (n,) = args
        if current_hash != delta_start:
            raise AssertionError()
        return dist_so_far + delta_dist, delta_end

    def judge(self, solution, args):
        # 0 when exactly n steps were walked, 1 when overshot, -1 otherwise
        dist, _current_hash = solution
        (n,) = args
        if dist == n:
            return 0
        return 1 if dist > n else -1

    def finalize(self, solution):
        # the answer is the hash reached by the walk
        _dist, final_hash = solution
        return final_hash
41
42
43 # linked list tracker
44
class Tracker(object):
    '''
    Forest of share chains linked by C{previous_hash}.

    Each tracked share is linked to its parent. A chain ends at a I{tail}, the
    first ancestor hash that is not itself a tracked share. A I{head} is a
    tracked share with no tracked children.

    Heights (number of tracked shares down to the tail) and cumulative work
    are computed lazily and cached. Each lookup rewrites the cache entries of
    every share it visited so that they jump straight to the hash it ended
    at.

    After each add/remove, C{added.happened(share)} / C{removed.happened(share)}
    is called with the affected share.
    '''

    def __init__(self, shares=()):
        '''
        @param shares: iterable of shares passed to L{add} one by one
        '''
        self.shares = {} # hash -> share
        self.reverse_shares = {} # previous_hash -> set of hashes of tracked shares with that parent

        self.heads = {} # head hash -> tail hash of its chain
        self.tails = {} # tail hash -> set of head hashes whose chains end there

        # Height/work cache. heights[h] = (height_to, ref, work_inc) with
        # height_refs[ref] = (ref_height, target_hash, ref_work) means: jumping
        # from h reaches target_hash after counting height_to + ref_height
        # shares and work_inc + ref_work work. Storing jumps relative to a
        # shared ref lets remove() shift every jump to a hash by editing only
        # that ref.
        self.heights = {} # share_hash -> (height_to, ref, work_inc)
        self.reverse_heights = {} # ref -> set of share_hashes whose entry uses it

        self.ref_generator = itertools.count()
        self.height_refs = {} # ref -> (height, share_hash, work_inc)
        self.reverse_height_refs = {} # share_hash -> ref

        # used as get_nth_parent_hash(share_hash, n)
        self.get_nth_parent_hash = DistanceSkipList(self)

        self.added = variable.Event()
        self.removed = variable.Event()

        for share in shares:
            self.add(share)

    def add(self, share):
        '''
        Add I{share} to the tracker and fire C{added}.

        Shares are expected to expose C{hash} and C{previous_hash}; C{target}
        is read later when computing work.

        @raise ValueError: if a share with the same hash is already present
        '''
        assert not isinstance(share, (int, long, type(None)))
        if share.hash in self.shares:
            raise ValueError('share already present')

        # If this share was the tail of existing chains, those chains now run
        # through it and keep their heads; otherwise it is a new head.
        if share.hash in self.tails:
            heads = self.tails.pop(share.hash)
        else:
            heads = set([share.hash])

        # The chain's tail: reuse the parent's if the parent is a head,
        # otherwise look it up through the (cached) chain walk.
        if share.previous_hash in self.heads:
            tail = self.heads.pop(share.previous_hash)
        else:
            tail = self.get_last(share.previous_hash)

        self.shares[share.hash] = share
        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)

        self.tails.setdefault(tail, set()).update(heads)
        # the parent, if it was a head of this tail, no longer is one
        if share.previous_hash in self.tails[tail]:
            self.tails[tail].remove(share.previous_hash)

        for head in heads:
            self.heads[head] = tail

        self.added.happened(share)

    def remove(self, share_hash):
        '''
        Remove the share with hash I{share_hash} and fire C{removed}.

        Only these cases are supported:
          - the share is a head (it has no tracked children);
          - the share's parent is a tail and the share is that tail's only
            child, in which case the share's hash becomes the new tail.
        Any other share raises NotImplementedError. Cached height data may
        already have been refreshed by then.

        @raise KeyError: if no share with that hash is tracked
        @raise NotImplementedError: for shares in the middle of a chain
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        if share_hash not in self.shares:
            raise KeyError(share_hash)

        share = self.shares[share_hash]
        del share_hash # only share.hash is used from here on

        # move height refs referencing children down to this, so they can be moved up in one step
        # (afterwards no ref may point at share.hash -- asserted below)
        if share.previous_hash in self.reverse_height_refs:
            if share.previous_hash not in self.tails:
                for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.previous_hash, object()), set())):
                    self.get_last(x)
            for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set())):
                self.get_last(x)
            assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, None), set()))

        if share.hash in self.heads and share.previous_hash in self.tails:
            # head sitting directly on a tail: drop it from that tail's heads
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if not self.tails[share.previous_hash]:
                self.tails.pop(share.previous_hash)
        elif share.hash in self.heads:
            # head whose parent is tracked: parent becomes a head unless it
            # has other children
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if self.reverse_shares[share.previous_hash] != set([share.hash]):
                pass # has sibling
            else:
                self.tails[tail].add(share.previous_hash)
                self.heads[share.previous_hash] = tail
        elif share.previous_hash in self.tails and len(self.reverse_shares[share.previous_hash]) <= 1:
            # bottom of a chain and the tail's only child: this share's hash
            # becomes the tail of every head above it
            heads = self.tails.pop(share.previous_hash)
            for head in heads:
                self.heads[head] = share.hash
            self.tails[share.hash] = set(heads)

            # move ref pointing to this up
            if share.previous_hash in self.reverse_height_refs:
                assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set()))

                ref = self.reverse_height_refs[share.previous_hash]
                cur_height, cur_hash, cur_work = self.height_refs[ref]
                assert cur_hash == share.previous_hash
                # jumps now stop one share earlier, at share.hash
                self.height_refs[ref] = cur_height - 1, share.hash, cur_work - bitcoin_data.target_to_average_attempts(share.target)
                del self.reverse_height_refs[share.previous_hash]
                self.reverse_height_refs[share.hash] = ref
        else:
            raise NotImplementedError()

        # delete height entry, and ref if it is empty
        if share.hash in self.heights:
            _, ref, _ = self.heights.pop(share.hash)
            self.reverse_heights[ref].remove(share.hash)
            if not self.reverse_heights[ref]:
                del self.reverse_heights[ref]
                _, ref_hash, _ = self.height_refs.pop(ref)
                del self.reverse_height_refs[ref_hash]

        # unlink from the share and parent indexes
        self.shares.pop(share.hash)
        self.reverse_shares[share.previous_hash].remove(share.hash)
        if not self.reverse_shares[share.previous_hash]:
            self.reverse_shares.pop(share.previous_hash)

        self.removed.happened(share)

    def get_height(self, share_hash):
        '''Return the number of tracked shares from I{share_hash} (inclusive) down to its tail.'''
        height, _work, _last = self.get_height_work_and_last(share_hash)
        return height

    def get_work(self, share_hash):
        '''Return the summed C{target_to_average_attempts(target)} of those shares.'''
        _height, work, _last = self.get_height_work_and_last(share_hash)
        return work

    def get_last(self, share_hash):
        '''Return the first hash at or below I{share_hash} that is not a tracked share.'''
        _height, _work, last = self.get_height_work_and_last(share_hash)
        return last

    def get_height_and_last(self, share_hash):
        '''Return (height, last) as computed by L{get_height_work_and_last}.'''
        height, _work, last = self.get_height_work_and_last(share_hash)
        return height, last

    def _get_height_jump(self, share_hash):
        '''
        Return (height_inc, target_hash, work_inc) for one step from the
        tracked share I{share_hash}: its cached jump if any, else one link to
        its parent.
        '''
        if share_hash in self.heights:
            height_to1, ref, work_inc1 = self.heights[share_hash]
            height_to2, target_hash, work_inc2 = self.height_refs[ref]
            return height_to1 + height_to2, target_hash, work_inc1 + work_inc2
        share = self.shares[share_hash]
        return 1, share.previous_hash, bitcoin_data.target_to_average_attempts(share.target)

    def _set_height_jump(self, share_hash, height_inc, other_share_hash, work_inc):
        '''
        Cache that jumping from I{share_hash} reaches I{other_share_hash}
        after I{height_inc} shares and I{work_inc} work.

        The entry is stored relative to the ref for I{other_share_hash}, which
        is created if missing. The share's previous ref is discarded when it
        is left unused.
        '''
        if other_share_hash not in self.reverse_height_refs:
            ref = self.ref_generator.next()
            assert ref not in self.height_refs
            self.height_refs[ref] = 0, other_share_hash, 0
            self.reverse_height_refs[other_share_hash] = ref
            del ref

        ref = self.reverse_height_refs[other_share_hash]
        ref_height_to, ref_share_hash, ref_work_inc = self.height_refs[ref]
        assert ref_share_hash == other_share_hash

        if share_hash in self.heights:
            prev_ref = self.heights[share_hash][1]
            self.reverse_heights[prev_ref].remove(share_hash)
            if not self.reverse_heights[prev_ref] and prev_ref != ref:
                self.reverse_heights.pop(prev_ref)
                _, x, _ = self.height_refs.pop(prev_ref)
                self.reverse_height_refs.pop(x)
        # store relative to the ref so later ref edits shift this entry too
        self.heights[share_hash] = height_inc - ref_height_to, ref, work_inc - ref_work_inc
        self.reverse_heights.setdefault(ref, set()).add(share_hash)

    def get_height_work_and_last(self, share_hash):
        '''
        Return (height, work, last) for I{share_hash}.

        I{height} is the number of tracked shares from I{share_hash} down to
        I{last}, counting I{share_hash} itself. I{work} is their summed work.
        I{last} is the first hash reached that is not a tracked share. Every
        share visited on the way has its cache entry rewritten to jump
        directly to I{last}.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        height = 0
        work = 0
        updates = []
        while share_hash in self.shares:
            updates.append((share_hash, height, work))
            height_inc, share_hash, work_inc = self._get_height_jump(share_hash)
            height += height_inc
            work += work_inc
        for update_hash, height_then, work_then in updates:
            self._set_height_jump(update_hash, height - height_then, share_hash, work - work_then)
        return height, work, share_hash

    def get_chain_known(self, start_hash):
        '''
        Yield tracked shares starting with the one with hash I{start_hash},
        following C{previous_hash} until reaching a hash this Tracker does
        not contain.
        '''
        assert isinstance(start_hash, (int, long, type(None)))
        item_hash_to_get = start_hash
        while item_hash_to_get in self.shares:
            share = self.shares[item_hash_to_get]
            assert not isinstance(share, long)
            yield share
            item_hash_to_get = share.previous_hash

    def get_chain_to_root(self, start_hash, root=None):
        '''
        Yield shares starting with the one with hash I{start_hash}, following
        C{previous_hash} until I{root} is reached (root is not yielded).

        @raise KeyError: if a share along the way is not tracked
        '''
        assert isinstance(start_hash, (int, long, type(None)))
        assert isinstance(root, (int, long, type(None)))
        share_hash_to_get = start_hash
        while share_hash_to_get != root:
            share = self.shares[share_hash_to_get]
            yield share
            share_hash_to_get = share.previous_hash

    def get_best_hash(self):
        '''
        Returns hash of item with the most items in its chain (the head with
        the greatest height; ties are broken by comparing last hashes), or
        None if there are no heads.
        '''
        if not self.heads:
            return None
        return max(self.heads, key=self.get_height_and_last)

    def get_highest_height(self):
        '''Return the greatest height among all heads, or 0 if there are none.'''
        return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0

    def is_child_of(self, share_hash, possible_child_hash):
        '''
        Return whether I{share_hash} is an ancestor-or-self of
        I{possible_child_hash}.

        Returns None when the two chains end at different last hashes, since
        the relation cannot be determined from tracked data.
        '''
        height, last = self.get_height_and_last(share_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash