made Tracker use head/tail from deltas instead of share.hash/previous_hash
[p2pool.git] / p2pool / util / forest.py
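'''
Forest data structure: tracks elements (shares) linked to their parents by
hash and answers chain queries such as height, cumulative work and n-th parent.
'''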
4
import itertools
import numbers
import weakref

from p2pool.util import skiplist, variable
from p2pool.bitcoin import data as bitcoin_data
10
11
class TrackerSkipList(skiplist.SkipList):
    '''Skip list over a Tracker's elements, stepping from an element to its parent.

    When the tracker removes an element, the skip list is asked to forget it
    (SkipList.forget_item).
    '''

    def __init__(self, tracker):
        skiplist.SkipList.__init__(self)
        self.tracker = tracker

        # The tracker's event only holds a weak reference to this skip list, so
        # it does not keep it alive; when the skip list is finalized, the weakref
        # callback unregisters the watcher.
        self_ref = weakref.ref(self, lambda _: tracker.removed.unwatch(watch_id))

        def _forget_removed(share):
            skip_list = self_ref()
            if skip_list is None:
                return  # already collected; the unwatch callback is pending
            # Key by the delta head -- the same id previous() and the Tracker use --
            # instead of assuming the element exposes a .hash attribute.
            skip_list.forget_item(tracker.delta_type.from_element(share).head)

        watch_id = tracker.removed.watch(_forget_removed)

    def previous(self, element):
        '''Return the id of ``element``'s parent (the tail of its one-element delta).'''
        return self.tracker.delta_type.from_element(self.tracker.shares[element]).tail
22
23
class DistanceSkipList(TrackerSkipList):
    '''Skip list that finds the id reached by following exactly n parent links.

    Deltas are ``(from_hash, distance, to_hash)`` triples, partial solutions are
    ``(distance_so_far, current_hash)`` pairs and the query arguments are the
    1-tuple ``(n,)``.

    Tuple arguments are unpacked inside each method rather than in the
    signature, because tuple parameters were removed in Python 3 (PEP 3113).
    Plain positional parameters accept exactly the same calls. ``args`` is
    unpacked even where ``n`` is unused, to keep the original arity check.
    '''

    def get_delta(self, element):
        # A single step: from element to its parent, distance 1.
        return element, 1, self.previous(element)

    def combine_deltas(self, delta1, delta2):
        '''Join two contiguous deltas; ``delta1`` must end where ``delta2`` starts.'''
        from_hash1, dist1, to_hash1 = delta1
        from_hash2, dist2, to_hash2 = delta2
        if to_hash1 != from_hash2:
            raise AssertionError('non-contiguous deltas: %r != %r' % (to_hash1, from_hash2))
        return from_hash1, dist1 + dist2, to_hash2

    def initial_solution(self, start, args):
        '''Begin the search at ``start`` with zero links walked.'''
        (_n,) = args
        return 0, start

    def apply_delta(self, solution, delta, args):
        '''Advance ``solution`` by ``delta``, which must start where the solution ends.'''
        dist1, to_hash1 = solution
        from_hash2, dist2, to_hash2 = delta
        (_n,) = args
        if to_hash1 != from_hash2:
            raise AssertionError('delta does not start at solution: %r != %r' % (to_hash1, from_hash2))
        return dist1 + dist2, to_hash2

    def judge(self, solution, args):
        '''Compare the distance walked with n: 1 if past it, 0 if equal, -1 if short.'''
        dist, _hash = solution
        (n,) = args
        if dist > n:
            return 1
        elif dist == n:
            return 0
        else:
            return -1

    def finalize(self, solution, args):
        '''Return the id reached; only valid once exactly n links have been walked.'''
        dist, final_hash = solution
        (n,) = args
        assert dist == n
        return final_hash
52
53
class AttributeDelta(object):
    '''Summary of a run of consecutive elements in a chain.

    A delta runs from ``head`` (the first element's id) down to ``tail`` (the id
    just past the run) and covers ``height`` elements with total ``work``.
    Deltas concatenate with ``+`` and can be split with ``-`` when the two
    operands share a head or a tail.
    '''
    __slots__ = 'head height work tail'.split(' ')

    @classmethod
    def get_none(cls, element_id):
        '''Return the empty delta at ``element_id`` (head == tail, zero height/work).'''
        return cls(element_id, 0, 0, element_id)

    @classmethod
    def from_element(cls, share):
        '''Return the one-element delta from ``share.hash`` to ``share.previous_hash``.'''
        return cls(share.hash, 1, bitcoin_data.target_to_average_attempts(share.target), share.previous_hash)

    def __init__(self, head, height, work, tail):
        self.head, self.height, self.work, self.tail = head, height, work, tail

    def __add__(self, other):
        '''Concatenate ``self`` with ``other``, which must start where ``self`` ends.'''
        assert self.tail == other.head
        # type(self), not AttributeDelta, so subclasses used as a Tracker's
        # delta_type keep their type through arithmetic.
        return type(self)(self.head, self.height + other.height, self.work + other.work, other.tail)

    def __sub__(self, other):
        '''Remove ``other`` from one end of ``self``.

        If both deltas share a head, the result runs from ``other.tail`` to
        ``self.tail``. Otherwise, if they share a tail, it runs from
        ``self.head`` to ``other.head``.

        Raises AssertionError when the deltas share neither endpoint.
        '''
        if self.head == other.head:
            return type(self)(other.tail, self.height - other.height, self.work - other.work, self.tail)
        elif self.tail == other.tail:
            return type(self)(self.head, self.height - other.height, self.work - other.work, other.head)
        else:
            raise AssertionError()

    def __repr__(self):
        return str(self.__class__) + str((self.head, self.height, self.work, self.tail))
82
class Tracker(object):
    '''Forest of elements ("shares") linked child -> parent through their deltas.

    Each element is identified by ``delta_type.from_element(share).head`` and
    points at its parent via ``.tail``. Bookkeeping kept by the tracker:

    * ``shares`` / ``reverse_shares``: id -> element, and parent id -> set of
      child ids.
    * ``heads`` / ``tails``: every childless element (a head) maps to the first
      id below its chain that is not tracked (its tail). ``tails`` is the
      reverse mapping, tail -> set of heads.
    * ``deltas`` / ``reverse_deltas`` / ``delta_refs`` / ``reverse_delta_refs``:
      a cache of each element's delta down to its tail. Entries are stored as
      ``deltas[id] = (partial, ref)``, and the full delta is
      ``partial + delta_refs[ref]``. All cached deltas ending at the same id
      share one ref, so removing the bottom element of a chain can rebase all
      of them by rewriting a single ``delta_refs`` entry.

    The ``added`` and ``removed`` events fire with the element after it has
    been added or removed.
    '''

    # Types accepted where an element id (hash) is expected. numbers.Integral
    # covers int (and long on Python 2) without naming the Python-2-only type.
    _HASH_TYPES = (numbers.Integral, type(None))

    def __init__(self, shares=(), delta_type=AttributeDelta):
        # Assigned first so it is already available while the helpers below
        # (e.g. the skip list) are being constructed.
        self.delta_type = delta_type

        self.shares = {} # hash -> share
        self.reverse_shares = {} # delta.tail -> set of share_hashes

        self.heads = {} # head hash -> tail_hash
        self.tails = {} # tail hash -> set of head hashes

        self.deltas = {} # share_hash -> delta, ref
        self.reverse_deltas = {} # ref -> set of share_hashes

        self.ref_generator = itertools.count()
        self.delta_refs = {} # ref -> delta
        self.reverse_delta_refs = {} # delta.tail -> ref

        self.added = variable.Event()
        self.removed = variable.Event()

        self.get_nth_parent_hash = DistanceSkipList(self)

        for share in shares:
            self.add(share)

    def add(self, share):
        '''Insert ``share`` and join it with any adjacent chains.

        Raises ValueError if an element with the same id is already tracked.
        Fires ``added`` with the share afterwards.
        '''
        assert not isinstance(share, self._HASH_TYPES)
        delta = self.delta_type.from_element(share)

        if delta.head in self.shares:
            raise ValueError('share already present')

        # Chains that were waiting for this element as their missing parent now
        # sit on top of it. Otherwise the element is a new head.
        if delta.head in self.tails:
            heads = self.tails.pop(delta.head)
        else:
            heads = set([delta.head])

        # If the parent was a head it stops being one and we inherit its tail.
        # Otherwise find the tail of the parent's chain (which is the parent's
        # own id when the parent is not tracked).
        if delta.tail in self.heads:
            tail = self.heads.pop(delta.tail)
        else:
            tail = self.get_last(delta.tail)

        self.shares[delta.head] = share
        self.reverse_shares.setdefault(delta.tail, set()).add(delta.head)

        self.tails.setdefault(tail, set()).update(heads)
        if delta.tail in self.tails[tail]:
            self.tails[tail].remove(delta.tail)

        for head in heads:
            self.heads[head] = tail

        self.added.happened(share)

    def remove(self, share_hash):
        '''Remove the element with id ``share_hash``.

        Only elements at the edge of the forest are supported: a head (an
        element with no children), or the bottom element of a chain that has no
        siblings. Anything else raises NotImplementedError. Raises KeyError if
        the id is not tracked. Fires ``removed`` with the element afterwards.
        '''
        assert isinstance(share_hash, self._HASH_TYPES)
        if share_hash not in self.shares:
            raise KeyError(share_hash)

        share = self.shares[share_hash]
        del share_hash

        delta = self.delta_type.from_element(share)

        if delta.head in self.heads and delta.tail in self.tails:
            # A chain consisting of this single element.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if not self.tails[delta.tail]:
                self.tails.pop(delta.tail)
        elif delta.head in self.heads:
            # A head whose parent is tracked. The parent becomes a head unless
            # it still has other children.
            tail = self.heads.pop(delta.head)
            self.tails[tail].remove(delta.head)
            if self.reverse_shares[delta.tail] != set([delta.head]):
                pass # has sibling
            else:
                self.tails[tail].add(delta.tail)
                self.heads[delta.tail] = tail
        elif delta.tail in self.tails and len(self.reverse_shares[delta.tail]) <= 1:
            # The sibling-less bottom of a chain. After removal its id becomes
            # the tail of every chain above it.

            # Cached deltas may still stop at this element's id (e.g. cached
            # before it was added). Re-point them down to delta.tail first, so
            # that the single ref at delta.tail covers them all and can be moved
            # up in one step below without colliding with a ref at delta.head.
            if delta.tail in self.reverse_delta_refs:
                for x in list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set())):
                    self.get_last(x)
                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, None), set()))

            heads = self.tails.pop(delta.tail)
            for head in heads:
                self.heads[head] = delta.head
            self.tails[delta.head] = set(heads)

            # Move the ref pointing at delta.tail up to delta.head by subtracting
            # this element's own delta from it.
            if delta.tail in self.reverse_delta_refs:
                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set()))

                ref = self.reverse_delta_refs[delta.tail]
                cur_delta = self.delta_refs[ref]
                assert cur_delta.tail == delta.tail
                self.delta_refs[ref] = cur_delta - delta
                assert self.delta_refs[ref].tail == delta.head
                del self.reverse_delta_refs[delta.tail]
                self.reverse_delta_refs[delta.head] = ref
        else:
            raise NotImplementedError()

        # Delete this element's cached delta, and its ref if no longer used.
        if delta.head in self.deltas:
            delta1, ref = self.deltas.pop(delta.head)
            self.reverse_deltas[ref].remove(delta.head)
            if not self.reverse_deltas[ref]:
                del self.reverse_deltas[ref]
                delta2 = self.delta_refs.pop(ref)
                del self.reverse_delta_refs[delta2.tail]

        self.shares.pop(delta.head)
        self.reverse_shares[delta.tail].remove(delta.head)
        if not self.reverse_shares[delta.tail]:
            self.reverse_shares.pop(delta.tail)

        self.removed.happened(share)

    def get_height(self, share_hash):
        '''Return the height of the delta from ``share_hash`` down to its tail.'''
        return self.get_delta(share_hash).height

    def get_work(self, share_hash):
        '''Return the work of the delta from ``share_hash`` down to its tail.'''
        return self.get_delta(share_hash).work

    def get_last(self, share_hash):
        '''Return the first untracked id reached by walking parents from ``share_hash``.'''
        return self.get_delta(share_hash).tail

    def get_height_and_last(self, share_hash):
        '''Return ``(height, last)`` for ``share_hash`` (see get_height/get_last).'''
        delta = self.get_delta(share_hash)
        return delta.height, delta.tail

    def get_height_work_and_last(self, share_hash):
        '''Return ``(height, work, last)`` for ``share_hash``.'''
        delta = self.get_delta(share_hash)
        return delta.height, delta.work, delta.tail

    def _get_delta(self, share_hash):
        # The cached delta if present (partial + shared ref part); otherwise the
        # element's own one-step delta. The tail may still be a tracked id.
        if share_hash in self.deltas:
            delta1, ref = self.deltas[share_hash]
            delta2 = self.delta_refs[ref]
            res = delta1 + delta2
        else:
            res = self.delta_type.from_element(self.shares[share_hash])
        assert res.head == share_hash
        return res

    def _set_delta(self, share_hash, delta):
        # Cache ``delta`` for ``share_hash`` against the ref keyed by its tail,
        # creating that ref (initially the empty delta) if needed. Also drop
        # the element's previous ref if it became unused.
        other_share_hash = delta.tail
        if other_share_hash not in self.reverse_delta_refs:
            ref = next(self.ref_generator)
            assert ref not in self.delta_refs
            self.delta_refs[ref] = self.delta_type.get_none(other_share_hash)
            self.reverse_delta_refs[other_share_hash] = ref
            del ref

        ref = self.reverse_delta_refs[other_share_hash]
        ref_delta = self.delta_refs[ref]
        assert ref_delta.tail == other_share_hash

        if share_hash in self.deltas:
            prev_ref = self.deltas[share_hash][1]
            self.reverse_deltas[prev_ref].remove(share_hash)
            if not self.reverse_deltas[prev_ref] and prev_ref != ref:
                self.reverse_deltas.pop(prev_ref)
                x = self.delta_refs.pop(prev_ref)
                self.reverse_delta_refs.pop(x.tail)
        self.deltas[share_hash] = delta - ref_delta, ref
        self.reverse_deltas.setdefault(ref, set()).add(share_hash)

    def get_delta(self, share_hash):
        '''Return the delta from ``share_hash`` down to the first untracked id.

        Every tracked element visited on the way gets its cached delta updated
        to point at that final id. An untracked ``share_hash`` yields the empty
        delta at itself.
        '''
        assert isinstance(share_hash, self._HASH_TYPES)
        delta = self.delta_type.get_none(share_hash)
        updates = []
        while delta.tail in self.shares:
            updates.append((delta.tail, delta))
            this_delta = self._get_delta(delta.tail)
            delta += this_delta
        for update_hash, delta_then in updates:
            self._set_delta(update_hash, delta - delta_then)
        return delta

    def get_chain(self, start_hash, length):
        '''Yield ``length`` elements, starting at ``start_hash`` and following parents.'''
        assert length <= self.get_height(start_hash)
        # itertools.repeat instead of the Python-2-only xrange.
        for _ in itertools.repeat(None, length):
            share = self.shares[start_hash]
            yield share
            start_hash = self.delta_type.from_element(share).tail

    def is_child_of(self, share_hash, possible_child_hash):
        '''Tell whether ``possible_child_hash`` is ``share_hash`` or a descendant of it.

        Returns None when the two ids have different tails, since the answer
        cannot be determined in that case.
        '''
        height, last = self.get_height_and_last(share_hash)
        child_height, child_last = self.get_height_and_last(possible_child_hash)
        if child_last != last:
            return None # not connected, so can't be determined
        height_up = child_height - height
        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash
277         height_up = child_height - height
278         return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash