changed Tracker.get_chain_known to the new, more explicit get_chain
[p2pool.git] / p2pool / util / forest.py
1 '''
2 forest data structure
3 '''
4
5 import itertools
6
7 from p2pool.util import skiplist, variable
8 from p2pool.bitcoin import data as bitcoin_data
9
10
class DistanceSkipList(skiplist.SkipList):
    '''Skip list that answers "which hash is n previous_hash links back?".
    
    Tracker instantiates this as its get_nth_parent_hash.  The hooks below
    work on two kinds of tuples:
      * a delta is (from_hash, distance, to_hash) -- walking `distance`
        links back from from_hash lands on to_hash;
      * a partial solution is (distance_so_far, current_hash).
    Query arguments reach the hooks as the 1-tuple (n,).
    '''
    
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker
    
    def previous(self, element):
        # parent link of a tracked share
        parent = self.tracker.shares[element]
        return parent.previous_hash
    
    def get_delta(self, element):
        # a single share always spans exactly one link back to its parent
        parent_hash = self.tracker.shares[element].previous_hash
        return element, 1, parent_hash
    
    def combine_deltas(self, first, second):
        # unpack both up front, as the tuple-parameter form did
        start, first_dist, first_end = first
        second_start, second_dist, end = second
        if first_end != second_start:
            # the two deltas must be contiguous
            raise AssertionError()
        return start, first_dist + second_dist, end
    
    def initial_solution(self, start, args):
        (n,) = args  # n itself is unused; unpacking keeps the arity check
        return 0, start
    
    def apply_delta(self, solution, delta, args):
        travelled, current = solution
        delta_start, delta_dist, delta_end = delta
        (n,) = args  # n itself is unused; unpacking keeps the arity check
        if current != delta_start:
            # a delta may only be applied where the solution currently stands
            raise AssertionError()
        return travelled + delta_dist, delta_end
    
    def judge(self, solution, args):
        # 1: walked past n, 0: exactly n links back, -1: not far enough yet
        travelled, _ = solution
        (n,) = args
        if travelled > n:
            return 1
        return 0 if travelled == n else -1
    
    def finalize(self, solution):
        _, reached_hash = solution
        return reached_hash
45
46
class Tracker(object):
    '''Forest of shares linked through their previous_hash.
    
    Every tracked share lives in ``shares``.  The forest shape is summarized by
    its heads (tracked shares that no tracked share names as its parent) and
    its tails (hashes that are named as a parent but are not tracked
    themselves).  Heights and cumulative work back to a chain's tail are cached
    in ``heights``/``height_refs`` and path-compressed on every query.
    
    ``added`` fires with the share at the end of add(); ``removed`` fires with
    the share at the end of remove().
    '''
    
    def __init__(self, shares=()):
        # immutable default: the argument is only iterated, never stored
        self.shares = {} # hash -> share
        self.reverse_shares = {} # previous_hash -> set of share_hashes
        
        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes
        
        # cached jumps: share_hash -> (height_to, ref, work_inc), stored relative
        # to height_refs[ref] so a whole group of jumps can be shifted at once
        self.heights = {} # share_hash -> height_to, ref, work_inc
        self.reverse_heights = {} # ref -> set of share_hashes
        
        self.ref_generator = itertools.count()
        self.height_refs = {} # ref -> height, share_hash, work_inc
        self.reverse_height_refs = {} # share_hash -> ref
        
        self.get_nth_parent_hash = DistanceSkipList(self)
        
        self.added = variable.Event()
        self.removed = variable.Event()
        
        for share in shares:
            self.add(share)
    
    def add(self, share):
        '''Insert ``share`` into the forest and fire ``added``.
        
        ``share`` must expose ``hash`` and ``previous_hash``; its ``target`` is
        read later by the work computations.  Raises ValueError if a share with
        the same hash is already tracked.
        '''
        assert not isinstance(share, (int, long, type(None)))
        if share.hash in self.shares:
            raise ValueError('share already present')
        
        # if this share is some chain's missing parent it inherits that chain's
        # heads; otherwise it starts out as its own (only) head
        if share.hash in self.tails:
            heads = self.tails.pop(share.hash)
        else:
            heads = set([share.hash])
        
        # extending a head: the parent stops being a head and keeps its tail;
        # otherwise walk back from the parent to find the tail
        if share.previous_hash in self.heads:
            tail = self.heads.pop(share.previous_hash)
        else:
            tail = self.get_last(share.previous_hash)
        
        self.shares[share.hash] = share
        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
        
        self.tails.setdefault(tail, set()).update(heads)
        if share.previous_hash in self.tails[tail]:
            self.tails[tail].remove(share.previous_hash)
        
        for head in heads:
            self.heads[head] = tail
        
        self.added.happened(share)
    
    def remove(self, share_hash):
        '''Remove the tracked share ``share_hash`` and fire ``removed``.
        
        Only chain edges can be removed: a head, or a share whose parent is a
        tail and that is that tail's only tracked child.  Any other share
        raises NotImplementedError.  Raises KeyError if the share is unknown.
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        if share_hash not in self.shares:
            raise KeyError()
        
        share = self.shares[share_hash]
        del share_hash
        
        # move height refs referencing children down to this, so they can be moved up in one step
        if share.previous_hash in self.reverse_height_refs:
            if share.previous_hash not in self.tails:
                for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.previous_hash, object()), set())):
                    self.get_last(x)
            for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set())):
                self.get_last(x)
            assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, None), set()))
        
        if share.hash in self.heads and share.previous_hash in self.tails:
            # lone share: it is both its chain's head and its bottom
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if not self.tails[share.previous_hash]:
                self.tails.pop(share.previous_hash)
        elif share.hash in self.heads:
            # removing a head: its parent becomes a head unless it has other children
            tail = self.heads.pop(share.hash)
            self.tails[tail].remove(share.hash)
            if self.reverse_shares[share.previous_hash] != set([share.hash]):
                pass # has sibling
            else:
                self.tails[tail].add(share.previous_hash)
                self.heads[share.previous_hash] = tail
        elif share.previous_hash in self.tails and len(self.reverse_shares[share.previous_hash]) <= 1:
            # removing a chain's bottom share: the share itself becomes the tail
            heads = self.tails.pop(share.previous_hash)
            for head in heads:
                self.heads[head] = share.hash
            self.tails[share.hash] = set(heads)
            
            # move ref pointing to this up
            if share.previous_hash in self.reverse_height_refs:
                assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set()))
                
                ref = self.reverse_height_refs[share.previous_hash]
                cur_height, cur_hash, cur_work = self.height_refs[ref]
                assert cur_hash == share.previous_hash
                self.height_refs[ref] = cur_height - 1, share.hash, cur_work - bitcoin_data.target_to_average_attempts(share.target)
                del self.reverse_height_refs[share.previous_hash]
                self.reverse_height_refs[share.hash] = ref
        else:
            raise NotImplementedError()
        
        # delete height entry, and ref if it is empty
        if share.hash in self.heights:
            _, ref, _ = self.heights.pop(share.hash)
            self.reverse_heights[ref].remove(share.hash)
            if not self.reverse_heights[ref]:
                del self.reverse_heights[ref]
                _, ref_hash, _ = self.height_refs.pop(ref)
                del self.reverse_height_refs[ref_hash]
        
        self.shares.pop(share.hash)
        self.reverse_shares[share.previous_hash].remove(share.hash)
        if not self.reverse_shares[share.previous_hash]:
            self.reverse_shares.pop(share.previous_hash)
        
        self.removed.happened(share)
    
    def get_height(self, share_hash):
        '''Number of tracked shares from share_hash back to its chain's tail.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return height
    
    def get_work(self, share_hash):
        '''Summed work of the tracked shares from share_hash back to the tail.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return work
    
    def get_last(self, share_hash):
        '''First untracked hash reached by walking back from share_hash.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return last
    
    def get_height_and_last(self, share_hash):
        '''Return (height, last); see get_height_work_and_last.'''
        height, work, last = self.get_height_work_and_last(share_hash)
        return height, last
    
    def _get_height_jump(self, share_hash):
        '''Return (height_inc, destination_hash, work_inc) for one step back.
        
        Uses the cached jump when there is one (its stored values are relative
        to the ref, so the ref's values are added back); otherwise steps one
        share to its parent.
        '''
        if share_hash in self.heights:
            height_to1, ref, work_inc1 = self.heights[share_hash]
            height_to2, share_hash, work_inc2 = self.height_refs[ref]
            height_inc = height_to1 + height_to2
            work_inc = work_inc1 + work_inc2
        else:
            height_inc, share_hash, work_inc = 1, self.shares[share_hash].previous_hash, bitcoin_data.target_to_average_attempts(self.shares[share_hash].target)
        return height_inc, share_hash, work_inc
    
    def _set_height_jump(self, share_hash, height_inc, other_share_hash, work_inc):
        '''Cache that share_hash is height_inc shares / work_inc work above
        other_share_hash.
        
        The values are stored relative to other_share_hash's ref (created on
        demand), and share_hash's previous ref is dropped if this leaves it
        unused.
        '''
        if other_share_hash not in self.reverse_height_refs:
            ref = self.ref_generator.next()
            assert ref not in self.height_refs
            self.height_refs[ref] = 0, other_share_hash, 0
            self.reverse_height_refs[other_share_hash] = ref
            del ref
        
        ref = self.reverse_height_refs[other_share_hash]
        ref_height_to, ref_share_hash, ref_work_inc = self.height_refs[ref]
        assert ref_share_hash == other_share_hash
        
        if share_hash in self.heights:
            prev_ref = self.heights[share_hash][1]
            self.reverse_heights[prev_ref].remove(share_hash)
            if not self.reverse_heights[prev_ref] and prev_ref != ref:
                self.reverse_heights.pop(prev_ref)
                _, x, _ = self.height_refs.pop(prev_ref)
                self.reverse_height_refs.pop(x)
        self.heights[share_hash] = height_inc - ref_height_to, ref, work_inc - ref_work_inc
        self.reverse_heights.setdefault(ref, set()).add(share_hash)
    
    def get_height_work_and_last(self, share_hash):
        '''Walk back from share_hash until reaching a hash that is not tracked.
        
        Returns (height, work, last): the number of tracked shares walked, the
        sum of target_to_average_attempts(share.target) over them, and the
        first untracked hash.  An untracked share_hash gives (0, 0, share_hash).
        Every share visited has its cached jump rewritten to point straight at
        ``last`` (path compression).
        '''
        assert isinstance(share_hash, (int, long, type(None)))
        height = 0
        work = 0
        updates = []
        while share_hash in self.shares:
            updates.append((share_hash, height, work))
            height_inc, share_hash, work_inc = self._get_height_jump(share_hash)
            height += height_inc
            work += work_inc
        for update_hash, height_then, work_then in updates:
            self._set_height_jump(update_hash, height - height_then, share_hash, work - work_then)
        return height, work, share_hash
    
    def get_chain(self, start_hash, length):
        '''Return an iterator over ``length`` shares, starting with the share
        ``start_hash`` and following previous_hash links.
        
        The length is checked against get_height(start_hash) when get_chain is
        called.  Previously the check sat inside a generator body, so it only
        ran on the first next() (or never, if the result was not iterated).
        '''
        assert length <= self.get_height(start_hash)
        return self._iter_chain(start_hash, length)
    
    def _iter_chain(self, start_hash, length):
        # generator half of get_chain; length has already been validated
        for _ in xrange(length):
            share = self.shares[start_hash]
            yield share
            start_hash = share.previous_hash
    
    def is_child_of(self, share_hash, possible_child_hash):
        '''Return whether possible_child_hash lies at or above share_hash.
        
        Returns None when the two hashes do not end at the same last (untracked)
        hash, since their relationship can't be determined.  Otherwise it walks
        up from possible_child_hash by the height difference and compares.
        NOTE(review): a height difference of 0 relies on
        get_nth_parent_hash(h, 0) returning h itself -- presumably true; confirm.
        '''
        height, last = self.get_height_and_last(share_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash