speed up forest
[p2pool.git] / p2pool / util / forest.py
1 '''
2 forest data structure
3 '''
4
5 import itertools
6
7 from p2pool.util import skiplist, variable
8 from p2pool.bitcoin import data as bitcoin_data
9
10
class DistanceSkipList(skiplist.SkipList):
    '''
    Skip list that follows a given number of previous_hash links in a tracker.
    
    Deltas are (from_hash, distance, to_hash) triples and partial solutions
    are (distance, hash) pairs. The query arguments are a 1-tuple (n,), where
    n is the distance the search is judged against. The traversal that calls
    these hooks is presumably implemented by skiplist.SkipList, which is not
    visible here.
    
    The tuple arguments are unpacked inside each method rather than in the
    signature, because tuple parameters are Python-2-only syntax; callers
    still pass the same positional tuples.
    '''
    
    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        # only tracker.shares (hash -> share) is read by this class
        self.tracker = tracker
    
    def previous(self, element):
        '''Return the previous_hash of the share stored under element.'''
        return self.tracker.shares[element].previous_hash
    
    def get_delta(self, element):
        '''Return the single-step delta (element, 1, previous hash of element).'''
        return element, 1, self.tracker.shares[element].previous_hash
    
    def combine_deltas(self, delta1, delta2):
        '''
        Join two adjacent deltas into one covering both distances.
        
        Raises AssertionError if delta1 does not end where delta2 starts.
        '''
        from_hash1, dist1, to_hash1 = delta1
        from_hash2, dist2, to_hash2 = delta2
        if to_hash1 != from_hash2:
            raise AssertionError()
        return from_hash1, dist1 + dist2, to_hash2
    
    def initial_solution(self, start, args):
        '''Return the zero-distance solution (0, start); args must be (n,).'''
        (_n,) = args  # unpacked only to keep the original 1-tuple check
        return 0, start
    
    def apply_delta(self, solution, delta, args):
        '''
        Advance a (distance, hash) solution by a delta starting at that hash.
        
        Raises AssertionError if the delta does not start at the solution's hash.
        '''
        dist1, to_hash1 = solution
        from_hash2, dist2, to_hash2 = delta
        (_n,) = args  # unpacked only to keep the original 1-tuple check
        if to_hash1 != from_hash2:
            raise AssertionError()
        return dist1 + dist2, to_hash2
    
    def judge(self, solution, args):
        '''Return 1, 0 or -1 as the solution's distance is above, at or below n.'''
        dist, _last_hash = solution
        (n,) = args
        if dist > n:
            return 1
        elif dist == n:
            return 0
        else:
            return -1
    
    def finalize(self, solution):
        '''Return the hash component of a (distance, hash) solution.'''
        _dist, last_hash = solution
        return last_hash
45
46
47 class Tracker(object):
48     def __init__(self, shares=[]):
49         self.shares = {} # hash -> share
50         self.reverse_shares = {} # previous_hash -> set of share_hashes
51         
52         self.heads = {} # head hash -> tail_hash
53         self.tails = {} # tail hash -> set of head hashes
54         
55         self.heights = {} # share_hash -> height_to, ref, work_inc
56         self.reverse_heights = {} # ref -> set of share_hashes
57         
58         self.ref_generator = itertools.count()
59         self.height_refs = {} # ref -> height, share_hash, work_inc
60         self.reverse_height_refs = {} # share_hash -> ref
61         
62         self.get_nth_parent_hash = DistanceSkipList(self)
63         
64         self.added = variable.Event()
65         self.removed = variable.Event()
66         
67         for share in shares:
68             self.add(share)
69     
70     def add(self, share):
71         assert not isinstance(share, (int, long, type(None)))
72         if share.hash in self.shares:
73             raise ValueError('share already present')
74         
75         if share.hash in self.tails:
76             heads = self.tails.pop(share.hash)
77         else:
78             heads = set([share.hash])
79         
80         if share.previous_hash in self.heads:
81             tail = self.heads.pop(share.previous_hash)
82         else:
83             tail = self.get_last(share.previous_hash)
84         
85         self.shares[share.hash] = share
86         self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
87         
88         self.tails.setdefault(tail, set()).update(heads)
89         if share.previous_hash in self.tails[tail]:
90             self.tails[tail].remove(share.previous_hash)
91         
92         for head in heads:
93             self.heads[head] = tail
94         
95         self.added.happened(share)
96     
97     def remove(self, share_hash):
98         assert isinstance(share_hash, (int, long, type(None)))
99         if share_hash not in self.shares:
100             raise KeyError()
101         
102         share = self.shares[share_hash]
103         del share_hash
104         
105         children = self.reverse_shares.get(share.hash, set())
106         
107         if share.hash in self.heads and share.previous_hash in self.tails:
108             tail = self.heads.pop(share.hash)
109             self.tails[tail].remove(share.hash)
110             if not self.tails[share.previous_hash]:
111                 self.tails.pop(share.previous_hash)
112         elif share.hash in self.heads:
113             tail = self.heads.pop(share.hash)
114             self.tails[tail].remove(share.hash)
115             if self.reverse_shares[share.previous_hash] != set([share.hash]):
116                 pass # has sibling
117             else:
118                 self.tails[tail].add(share.previous_hash)
119                 self.heads[share.previous_hash] = tail
120         elif share.previous_hash in self.tails and len(self.reverse_shares[share.previous_hash]) <= 1:
121             # move height refs referencing children down to this, so they can be moved up in one step
122             if share.previous_hash in self.reverse_height_refs:
123                 for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set())):
124                     self.get_last(x)
125                 assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, None), set()))
126             
127             heads = self.tails.pop(share.previous_hash)
128             for head in heads:
129                 self.heads[head] = share.hash
130             self.tails[share.hash] = set(heads)
131             
132             # move ref pointing to this up
133             if share.previous_hash in self.reverse_height_refs:
134                 assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set()))
135                 
136                 ref = self.reverse_height_refs[share.previous_hash]
137                 cur_height, cur_hash, cur_work = self.height_refs[ref]
138                 assert cur_hash == share.previous_hash
139                 self.height_refs[ref] = cur_height - 1, share.hash, cur_work - bitcoin_data.target_to_average_attempts(share.target)
140                 del self.reverse_height_refs[share.previous_hash]
141                 self.reverse_height_refs[share.hash] = ref
142         else:
143             raise NotImplementedError()
144         
145         # delete height entry and ref if it is empty
146         if share.hash in self.heights:
147             _, ref, _ = self.heights.pop(share.hash)
148             self.reverse_heights[ref].remove(share.hash)
149             if not self.reverse_heights[ref]:
150                 del self.reverse_heights[ref]
151                 _, ref_hash, _ = self.height_refs.pop(ref)
152                 del self.reverse_height_refs[ref_hash]
153         
154         self.shares.pop(share.hash)
155         self.reverse_shares[share.previous_hash].remove(share.hash)
156         if not self.reverse_shares[share.previous_hash]:
157             self.reverse_shares.pop(share.previous_hash)
158         
159         self.removed.happened(share)
160     
161     def get_height(self, share_hash):
162         height, work, last = self.get_height_work_and_last(share_hash)
163         return height
164     
165     def get_work(self, share_hash):
166         height, work, last = self.get_height_work_and_last(share_hash)
167         return work
168     
169     def get_last(self, share_hash):
170         height, work, last = self.get_height_work_and_last(share_hash)
171         return last
172     
173     def get_height_and_last(self, share_hash):
174         height, work, last = self.get_height_work_and_last(share_hash)
175         return height, last
176     
177     def _get_height_jump(self, share_hash):
178         if share_hash in self.heights:
179             height_to1, ref, work_inc1 = self.heights[share_hash]
180             height_to2, share_hash, work_inc2 = self.height_refs[ref]
181             height_inc = height_to1 + height_to2
182             work_inc = work_inc1 + work_inc2
183         else:
184             height_inc, share_hash, work_inc = 1, self.shares[share_hash].previous_hash, bitcoin_data.target_to_average_attempts(self.shares[share_hash].target)
185         return height_inc, share_hash, work_inc
186     
187     def _set_height_jump(self, share_hash, height_inc, other_share_hash, work_inc):
188         if other_share_hash not in self.reverse_height_refs:
189             ref = self.ref_generator.next()
190             assert ref not in self.height_refs
191             self.height_refs[ref] = 0, other_share_hash, 0
192             self.reverse_height_refs[other_share_hash] = ref
193             del ref
194         
195         ref = self.reverse_height_refs[other_share_hash]
196         ref_height_to, ref_share_hash, ref_work_inc = self.height_refs[ref]
197         assert ref_share_hash == other_share_hash
198         
199         if share_hash in self.heights:
200             prev_ref = self.heights[share_hash][1]
201             self.reverse_heights[prev_ref].remove(share_hash)
202             if not self.reverse_heights[prev_ref] and prev_ref != ref:
203                 self.reverse_heights.pop(prev_ref)
204                 _, x, _ = self.height_refs.pop(prev_ref)
205                 self.reverse_height_refs.pop(x)
206         self.heights[share_hash] = height_inc - ref_height_to, ref, work_inc - ref_work_inc
207         self.reverse_heights.setdefault(ref, set()).add(share_hash)
208     
209     def get_height_work_and_last(self, share_hash):
210         assert isinstance(share_hash, (int, long, type(None)))
211         height = 0
212         work = 0
213         updates = []
214         while share_hash in self.shares:
215             updates.append((share_hash, height, work))
216             height_inc, share_hash, work_inc = self._get_height_jump(share_hash)
217             height += height_inc
218             work += work_inc
219         for update_hash, height_then, work_then in updates:
220             self._set_height_jump(update_hash, height - height_then, share_hash, work - work_then)
221         return height, work, share_hash
222     
223     def get_chain(self, start_hash, length):
224         assert length <= self.get_height(start_hash)
225         for i in xrange(length):
226             yield self.shares[start_hash]
227             start_hash = self.shares[start_hash].previous_hash
228     
229     def is_child_of(self, share_hash, possible_child_hash):
230         height, last = self.get_height_and_last(share_hash)
231         child_height, child_last = self.get_height_and_last(possible_child_hash)
232         if child_last != last:
233             return None # not connected, so can't be determined
234         height_up = child_height - height
235         return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash