moved forest tests to test/ and fixed them up
[p2pool.git] / p2pool / util / forest.py
1 import itertools
2
3 from p2pool.util import skiplist, variable
4
5 from p2pool.bitcoin import data as bitcoin_data
6
class DistanceSkipList(skiplist.SkipList):
    '''
    Skip list over a Tracker's share chain that measures distance in shares.

    A delta is a (from_hash, distance, to_hash) triple describing a walk from
    one share back towards its ancestors; a partial solution is a
    (distance, hash) pair. The query argument is a 1-tuple (n,) giving the
    target distance. Tracker installs an instance as get_nth_parent_hash.
    '''
    
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker
    
    def previous(self, element):
        # step to the parent of the share with hash `element`
        share = self.tracker.shares[element]
        return share.previous_hash
    
    def get_delta(self, element):
        # a single hop: element -> its parent, covering a distance of 1
        parent_hash = self.tracker.shares[element].previous_hash
        return element, 1, parent_hash
    
    def combine_deltas(self, first, second):
        # concatenate two consecutive hops; they must meet at the same hash
        start_hash, first_dist, middle_hash = first
        joint_hash, second_dist, end_hash = second
        if middle_hash != joint_hash:
            raise AssertionError()
        return start_hash, first_dist + second_dist, end_hash
    
    def initial_solution(self, start, args):
        (_n,) = args # query must be a 1-tuple, as before
        return 0, start
    
    def apply_delta(self, solution, delta, args):
        # extend a partial solution by one hop; the hop must start where it ended
        travelled, current_hash = solution
        delta_from, delta_dist, delta_to = delta
        (_n,) = args
        if current_hash != delta_from:
            raise AssertionError()
        return travelled + delta_dist, delta_to
    
    def judge(self, solution, args):
        # -1: not far enough yet, 0: exactly n, 1: overshot
        travelled, _current_hash = solution
        (n,) = args
        if travelled < n:
            return -1
        if travelled == n:
            return 0
        return 1
    
    def finalize(self, solution):
        _travelled, share_hash = solution
        return share_hash
41
42
43 # linked list tracker
44
class Tracker(object):
    '''
    Forest of shares linked through their previous_hash attribute.

    Each chain is indexed by its heads (shares with no tracked children) and
    its tail (the first hash, walking back along previous_hash, that is not
    itself in the tracker). Heights and cumulative work back to the tail are
    computed on demand and memoized as "height jumps" (heights ->
    height_refs) so later walks can skip over already-visited shares.

    self.added.happened(share) is fired after add(), and
    self.removed.happened(share) after remove().
    '''
    
    def __init__(self):
        self.shares = {} # hash -> share
        self.reverse_shares = {} # previous_hash -> set of share_hashes
        
        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes
        
        # memoized height jumps, stored relative to a ref (see _set_height_jump)
        self.heights = {} # share_hash -> height_to, ref, work_inc
        self.reverse_heights = {} # ref -> set of share_hashes
        
        self.ref_generator = itertools.count()
        self.height_refs = {} # ref -> height, share_hash, work_inc
        self.reverse_height_refs = {} # share_hash -> ref
        
        self.get_nth_parent_hash = DistanceSkipList(self)
        
        self.added = variable.Event()
        self.removed = variable.Event()
    
    def add(self, share):
        '''
        Insert share into the forest, joining it to any chains it connects.

        Raises ValueError if a share with the same hash is already present.
        Fires self.added with the share afterwards.
        '''
        assert not isinstance(share, (int, long, type(None)))
        if share.hash in self.shares:
            raise ValueError('share already present')
        
        # if this share is the missing parent of existing chains, it takes over their heads
        if share.hash in self.tails:
            heads = self.tails.pop(share.hash)
        else:
            heads = set([share.hash])
        
        # find the tail of the chain this share extends
        if share.previous_hash in self.heads:
            tail = self.heads.pop(share.previous_hash)
        else:
            tail = self.get_last(share.previous_hash)
        
        self.shares[share.hash] = share
        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
        
        self.tails.setdefault(tail, set()).update(heads)
        # the parent now has a child, so it is no longer a head
        if share.previous_hash in self.tails[tail]:
            self.tails[tail].remove(share.previous_hash)
        
        for head in heads:
            self.heads[head] = tail
        
        self.added.happened(share)
    
    def test(self):
        '''
        Consistency check: rebuild a fresh Tracker from self.shares and assert
        that its shares, reverse_shares, heads and tails equal this one's.

        Returns None; mismatches raise AssertionError (only when asserts are
        enabled, i.e. not under -O).
        '''
        t = Tracker()
        for s in self.shares.itervalues():
            t.add(s)
        
        assert self.shares == t.shares, (self.shares, t.shares)
        assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
        assert self.heads == t.heads, (self.heads, t.heads)
        assert self.tails == t.tails, (self.tails, t.tails)
    
    def remove(self, share_hash):
        '''
        Remove the share with hash share_hash.

        Only a head, or the earliest share of a chain whose parent has no
        other tracked children, can be removed; anything else raises
        NotImplementedError. Raises KeyError(share_hash) if the hash is not
        tracked. Fires self.removed with the share afterwards.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        if share_hash not in self.shares:
            raise KeyError(share_hash)
        
        share = self.shares[share_hash]
        del share_hash
        
        # move height refs referencing children down to this, so they can be moved up in one step
        if share.previous_hash in self.reverse_height_refs:
            if share.previous_hash not in self.tails:
                for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.previous_hash, object()), set())):
                    self.get_last(x)
            for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set())):
                self.get_last(x)
            assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, None), set()))
        
        if share.hash in self.heads and share.previous_hash in self.tails:
            # share is both a head and the earliest share of its chain:
            # drop it, and drop the tail too if no other heads use it
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if not self.tails[share.previous_hash]:
                self.tails.pop(share.previous_hash)
        elif share.hash in self.heads:
            # share is a head: its parent becomes a head unless it has other children
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if self.reverse_shares[share.previous_hash] != set([share.hash]):
                pass # has sibling
            else:
                self.tails[tail].add(share.previous_hash)
                self.heads[share.previous_hash] = tail
        elif share.previous_hash in self.tails:
            # share is the earliest share of its chain: its hash becomes the new tail
            heads = self.tails[share.previous_hash]
            if len(self.reverse_shares[share.previous_hash]) > 1:
                raise NotImplementedError()
            else:
                del self.tails[share.previous_hash]
                for head in heads:
                    self.heads[head] = share.hash
                self.tails[share.hash] = set(heads)
        else:
            # removing a share from the middle of a chain is not supported
            raise NotImplementedError()
        
        # move ref pointing to this up
        if share.previous_hash in self.reverse_height_refs:
            assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set()))
            
            ref = self.reverse_height_refs[share.previous_hash]
            cur_height, cur_hash, cur_work = self.height_refs[ref]
            assert cur_hash == share.previous_hash
            self.height_refs[ref] = cur_height - 1, share.hash, cur_work - bitcoin_data.target_to_average_attempts(share.target)
            del self.reverse_height_refs[share.previous_hash]
            self.reverse_height_refs[share.hash] = ref
        
        # delete height entry, and ref if it is empty
        if share.hash in self.heights:
            _, ref, _ = self.heights.pop(share.hash)
            self.reverse_heights[ref].remove(share.hash)
            if not self.reverse_heights[ref]:
                del self.reverse_heights[ref]
                _, ref_hash, _ = self.height_refs.pop(ref)
                del self.reverse_height_refs[ref_hash]
        
        self.shares.pop(share.hash)
        self.reverse_shares[share.previous_hash].remove(share.hash)
        if not self.reverse_shares[share.previous_hash]:
            self.reverse_shares.pop(share.previous_hash)
        
        # (calling self.test() here verifies the indexes; too slow for normal use)
        self.removed.happened(share)
    
    def get_height(self, share_hash):
        '''Return the number of tracked shares from share_hash back to its tail.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return height
    
    def get_work(self, share_hash):
        '''Return the summed work of the tracked shares from share_hash back to its tail.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return work
    
    def get_last(self, share_hash):
        '''Return the first hash at or behind share_hash that is not tracked (the tail).'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return last
    
    def get_height_and_last(self, share_hash):
        '''Return (height, last) for share_hash; see get_height_work_and_last.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return height, last
    
    def _get_height_jump(self, share_hash):
        '''
        Return (height_inc, next_hash, work_inc) for one jump back from the
        tracked share share_hash: through the memoized jump if one exists,
        otherwise a single step to its parent.
        '''
        if share_hash in self.heights:
            height_to1, ref, work_inc1 = self.heights[share_hash]
            height_to2, share_hash, work_inc2 = self.height_refs[ref]
            height_inc = height_to1 + height_to2
            work_inc = work_inc1 + work_inc2
        else:
            height_inc, share_hash, work_inc = 1, self.shares[share_hash].previous_hash, bitcoin_data.target_to_average_attempts(self.shares[share_hash].target)
        return height_inc, share_hash, work_inc
    
    def _set_height_jump(self, share_hash, height_inc, other_share_hash, work_inc):
        '''
        Memoize that share_hash lies height_inc shares and work_inc work above
        other_share_hash.

        The jump is stored relative to a shared ref for other_share_hash, so
        remove() can move the ref to update every dependent jump at once.
        A previous ref of share_hash is discarded once nothing uses it.
        '''
        if other_share_hash not in self.reverse_height_refs:
            ref = self.ref_generator.next()
            assert ref not in self.height_refs
            self.height_refs[ref] = 0, other_share_hash, 0
            self.reverse_height_refs[other_share_hash] = ref
            del ref
        
        ref = self.reverse_height_refs[other_share_hash]
        ref_height_to, ref_share_hash, ref_work_inc = self.height_refs[ref]
        assert ref_share_hash == other_share_hash
        
        if share_hash in self.heights:
            prev_ref = self.heights[share_hash][1]
            self.reverse_heights[prev_ref].remove(share_hash)
            if not self.reverse_heights[prev_ref] and prev_ref != ref:
                self.reverse_heights.pop(prev_ref)
                _, x, _ = self.height_refs.pop(prev_ref)
                self.reverse_height_refs.pop(x)
        self.heights[share_hash] = height_inc - ref_height_to, ref, work_inc - ref_work_inc
        self.reverse_heights.setdefault(ref, set()).add(share_hash)
    
    def get_height_work_and_last(self, share_hash):
        '''
        Return (height, work, last) for share_hash.

        height is the number of tracked shares walking back from share_hash,
        work is the sum of target_to_average_attempts over their targets,
        and last is the first hash reached that is not tracked (share_hash
        itself if it is untracked, giving (0, 0, share_hash)). Every share
        visited gets a memoized jump straight to last.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        height = 0
        work = 0
        updates = []
        while share_hash in self.shares:
            updates.append((share_hash, height, work))
            height_inc, share_hash, work_inc = self._get_height_jump(share_hash)
            height += height_inc
            work += work_inc
        for update_hash, height_then, work_then in updates:
            self._set_height_jump(update_hash, height - height_then, share_hash, work - work_then)
        return height, work, share_hash
    
    def get_chain_known(self, start_hash):
        '''
        Yield the tracked shares starting at hash start_hash and following
        previous_hash, stopping at the first hash that is not tracked.
        '''
        assert isinstance(start_hash, (int, long, type(None)))
        item_hash_to_get = start_hash
        while True:
            if item_hash_to_get not in self.shares:
                break
            share = self.shares[item_hash_to_get]
            assert not isinstance(share, long)
            yield share
            item_hash_to_get = share.previous_hash
    
    def get_chain_to_root(self, start_hash, root=None):
        '''
        Yield the shares starting at hash start_hash and following
        previous_hash until root is reached (root itself is not yielded).

        Raises KeyError during iteration if a share on the way is missing.
        '''
        assert isinstance(start_hash, (int, long, type(None)))
        assert isinstance(root, (int, long, type(None)))
        share_hash_to_get = start_hash
        while share_hash_to_get != root:
            share = self.shares[share_hash_to_get]
            yield share
            share_hash_to_get = share.previous_hash
    
    def get_best_hash(self):
        '''
        Return the head with the greatest height, or None if there are no heads.

        Ties are broken by comparing the heads' tail hashes (the key is
        get_height_and_last).
        '''
        if not self.heads:
            return None
        return max(self.heads, key=self.get_height_and_last)
    
    def get_highest_height(self):
        '''Return the greatest height among the heads, or 0 if there are none.'''
        return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0
    
    def is_child_of(self, share_hash, possible_child_hash):
        '''
        Return whether possible_child_hash lies on the chain above share_hash.

        Returns None when the two hashes lead back to different tails, since
        the relationship can't be determined from the tracked shares.
        '''
        height, last = self.get_height_and_last(share_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash
285         height_up = child_height - height
286         return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash