merged preremove_special and postremove_special since they didn't need to be separate
[p2pool.git] / p2pool / util / forest.py
index 67e4e2f..96d4b78 100644 (file)
@@ -1,19 +1,26 @@
+'''
+forest data structure
+'''
+
 import itertools
 
 from p2pool.util import skiplist, variable
 
-from p2pool.bitcoin import data as bitcoin_data
 
-class DistanceSkipList(skiplist.SkipList):
+class TrackerSkipList(skiplist.SkipList):
     def __init__(self, tracker):
         skiplist.SkipList.__init__(self)
         self.tracker = tracker
+        
+        self.tracker.removed.watch_weakref(self, lambda self, item: self.forget_item(item.hash))
     
     def previous(self, element):
-        return self.tracker.shares[element].previous_hash
-    
+        return self.tracker._delta_type.from_element(self.tracker.items[element]).tail
+
+
+class DistanceSkipList(TrackerSkipList):
     def get_delta(self, element):
-        return element, 1, self.tracker.shares[element].previous_hash
+        return element, 1, self.previous(element)
     
     def combine_deltas(self, (from_hash1, dist1, to_hash1), (from_hash2, dist2, to_hash2)):
         if to_hash1 != from_hash2:
@@ -36,238 +43,286 @@ class DistanceSkipList(skiplist.SkipList):
         else:
             return -1
     
-    def finalize(self, (dist, hash)):
+    def finalize(self, (dist, hash), (n,)):
+        assert dist == n
         return hash
 
+def get_attributedelta_type(attrs): # attrs: {name: func}
+    class ProtoAttributeDelta(object):
+        __slots__ = ['head', 'tail'] + attrs.keys()
+        
+        @classmethod
+        def get_none(cls, element_id):
+            return cls(element_id, element_id, **dict((k, 0) for k in attrs))
+        
+        @classmethod
+        def from_element(cls, item):
+            return cls(item.hash, item.previous_hash, **dict((k, v(item)) for k, v in attrs.iteritems()))
+        
+        def __init__(self, head, tail, **kwargs):
+            self.head, self.tail = head, tail
+            for k, v in kwargs.iteritems():
+                setattr(self, k, v)
+        
+        def __add__(self, other):
+            assert self.tail == other.head
+            return self.__class__(self.head, other.tail, **dict((k, getattr(self, k) + getattr(other, k)) for k in attrs))
+        
+        def __sub__(self, other):
+            if self.head == other.head:
+                return self.__class__(other.tail, self.tail, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
+            elif self.tail == other.tail:
+                return self.__class__(self.head, other.head, **dict((k, getattr(self, k) - getattr(other, k)) for k in attrs))
+            else:
+                raise AssertionError()
+        
+        def __repr__(self):
+            return '%s(%r, %r%s)' % (self.__class__, self.head, self.tail, ''.join(', %s=%r' % (k, getattr(self, k)) for k in attrs))
+    ProtoAttributeDelta.attrs = attrs
+    return ProtoAttributeDelta
 
-# linked list tracker
+AttributeDelta = get_attributedelta_type(dict(
+    height=lambda item: 1,
+))
+
+class TrackerView(object):
+    def __init__(self, tracker, delta_type):
+        self._tracker = tracker
+        self._delta_type = delta_type
+        
+        self._deltas = {} # item_hash -> delta, ref
+        self._reverse_deltas = {} # ref -> set of item_hashes
+        
+        self._ref_generator = itertools.count()
+        self._delta_refs = {} # ref -> delta
+        self._reverse_delta_refs = {} # delta.tail -> ref
+        
+        self._tracker.remove_special.watch_weakref(self, lambda self, item: self._handle_remove_special(item))
+        self._tracker.removed.watch_weakref(self, lambda self, item: self._handle_removed(item))
+    
+    def _handle_remove_special(self, item):
+        delta = self._delta_type.from_element(item)
+        
+        if delta.tail not in self._reverse_delta_refs:
+            return
+        
+        # move delta refs referencing children down to this, so they can be moved up in one step
+        for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
+            self.get_last(x)
+        
+        assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
+        
+        if delta.tail not in self._reverse_delta_refs:
+            return
+        
+        # move ref pointing to this up
+        
+        ref = self._reverse_delta_refs[delta.tail]
+        cur_delta = self._delta_refs[ref]
+        assert cur_delta.tail == delta.tail
+        self._delta_refs[ref] = cur_delta - self._delta_type.from_element(item)
+        assert self._delta_refs[ref].tail == delta.head
+        del self._reverse_delta_refs[delta.tail]
+        self._reverse_delta_refs[delta.head] = ref
+    
+    def _handle_removed(self, item):
+        delta = self._delta_type.from_element(item)
+        
+        # delete delta entry and ref if it is empty
+        if delta.head in self._deltas:
+            delta1, ref = self._deltas.pop(delta.head)
+            self._reverse_deltas[ref].remove(delta.head)
+            if not self._reverse_deltas[ref]:
+                del self._reverse_deltas[ref]
+                delta2 = self._delta_refs.pop(ref)
+                del self._reverse_delta_refs[delta2.tail]
+    
+    
+    def get_height(self, item_hash):
+        return self.get_delta_to_last(item_hash).height
+    
+    def get_work(self, item_hash):
+        return self.get_delta_to_last(item_hash).work
+    
+    def get_last(self, item_hash):
+        return self.get_delta_to_last(item_hash).tail
+    
+    def get_height_and_last(self, item_hash):
+        delta = self.get_delta_to_last(item_hash)
+        return delta.height, delta.tail
+    
+    def _get_delta(self, item_hash):
+        if item_hash in self._deltas:
+            delta1, ref = self._deltas[item_hash]
+            delta2 = self._delta_refs[ref]
+            res = delta1 + delta2
+        else:
+            res = self._delta_type.from_element(self._tracker.items[item_hash])
+        assert res.head == item_hash
+        return res
+    
+    def _set_delta(self, item_hash, delta):
+        other_item_hash = delta.tail
+        if other_item_hash not in self._reverse_delta_refs:
+            ref = self._ref_generator.next()
+            assert ref not in self._delta_refs
+            self._delta_refs[ref] = self._delta_type.get_none(other_item_hash)
+            self._reverse_delta_refs[other_item_hash] = ref
+            del ref
+        
+        ref = self._reverse_delta_refs[other_item_hash]
+        ref_delta = self._delta_refs[ref]
+        assert ref_delta.tail == other_item_hash
+        
+        if item_hash in self._deltas:
+            prev_ref = self._deltas[item_hash][1]
+            self._reverse_deltas[prev_ref].remove(item_hash)
+            if not self._reverse_deltas[prev_ref] and prev_ref != ref:
+                self._reverse_deltas.pop(prev_ref)
+                x = self._delta_refs.pop(prev_ref)
+                self._reverse_delta_refs.pop(x.tail)
+        self._deltas[item_hash] = delta - ref_delta, ref
+        self._reverse_deltas.setdefault(ref, set()).add(item_hash)
+    
+    def get_delta_to_last(self, item_hash):
+        assert isinstance(item_hash, (int, long, type(None)))
+        delta = self._delta_type.get_none(item_hash)
+        updates = []
+        while delta.tail in self._tracker.items:
+            updates.append((delta.tail, delta))
+            this_delta = self._get_delta(delta.tail)
+            delta += this_delta
+        for update_hash, delta_then in updates:
+            self._set_delta(update_hash, delta - delta_then)
+        return delta
+    
+    def get_delta(self, item, ancestor):
+        assert self._tracker.is_child_of(ancestor, item)
+        return self.get_delta_to_last(item) - self.get_delta_to_last(ancestor)
 
 class Tracker(object):
-    def __init__(self, shares=[]):
-        self.shares = {} # hash -> share
-        #self.ids = {} # hash -> (id, height)
-        self.reverse_shares = {} # previous_hash -> set of share_hashes
+    def __init__(self, items=[], delta_type=AttributeDelta):
+        self.items = {} # hash -> item
+        self.reverse = {} # delta.tail -> set of item_hashes
         
         self.heads = {} # head hash -> tail_hash
         self.tails = {} # tail hash -> set of head hashes
         
-        self.heights = {} # share_hash -> height_to, ref, work_inc
-        self.reverse_heights = {} # ref -> set of share_hashes
-        
-        self.ref_generator = itertools.count()
-        self.height_refs = {} # ref -> height, share_hash, work_inc
-        self.reverse_height_refs = {} # share_hash -> ref
+        self.added = variable.Event()
+        self.remove_special = variable.Event()
+        self.removed = variable.Event()
         
         self.get_nth_parent_hash = DistanceSkipList(self)
         
-        self.added = variable.Event()
-        self.removed = variable.Event()
+        self._delta_type = delta_type
+        self._default_view = TrackerView(self, delta_type)
         
-        for share in shares:
-            self.add(share)
+        for item in items:
+            self.add(item)
+    
+    def __getattr__(self, name):
+        attr = getattr(self._default_view, name)
+        setattr(self, name, attr)
+        return attr
     
-    def add(self, share):
-        assert not isinstance(share, (int, long, type(None)))
-        if share.hash in self.shares:
-            raise ValueError('share already present')
+    def add(self, item):
+        assert not isinstance(item, (int, long, type(None)))
+        delta = self._delta_type.from_element(item)
         
-        if share.hash in self.tails:
-            heads = self.tails.pop(share.hash)
+        if delta.head in self.items:
+            raise ValueError('item already present')
+        
+        if delta.head in self.tails:
+            heads = self.tails.pop(delta.head)
         else:
-            heads = set([share.hash])
+            heads = set([delta.head])
         
-        if share.previous_hash in self.heads:
-            tail = self.heads.pop(share.previous_hash)
+        if delta.tail in self.heads:
+            tail = self.heads.pop(delta.tail)
         else:
-            tail = self.get_last(share.previous_hash)
-            #tail2 = share.previous_hash
-            #while tail2 in self.shares:
-            #    tail2 = self.shares[tail2].previous_hash
-            #assert tail == tail2
+            tail = self.get_last(delta.tail)
         
-        self.shares[share.hash] = share
-        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
+        self.items[delta.head] = item
+        self.reverse.setdefault(delta.tail, set()).add(delta.head)
         
         self.tails.setdefault(tail, set()).update(heads)
-        if share.previous_hash in self.tails[tail]:
-            self.tails[tail].remove(share.previous_hash)
+        if delta.tail in self.tails[tail]:
+            self.tails[tail].remove(delta.tail)
         
         for head in heads:
             self.heads[head] = tail
         
-        self.added.happened(share)
+        self.added.happened(item)
     
-    def remove(self, share_hash):
-        assert isinstance(share_hash, (int, long, type(None)))
-        if share_hash not in self.shares:
+    def remove(self, item_hash):
+        assert isinstance(item_hash, (int, long, type(None)))
+        if item_hash not in self.items:
             raise KeyError()
         
-        share = self.shares[share_hash]
-        del share_hash
-        
-        children = self.reverse_shares.get(share.hash, set())
-        
-        # move height refs referencing children down to this, so they can be moved up in one step
-        if share.previous_hash in self.reverse_height_refs:
-            if share.previous_hash not in self.tails:
-                for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.previous_hash, object()), set())):
-                    self.get_last(x)
-            for x in list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set())):
-                self.get_last(x)
-            assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, None), set()))
-        
-        if share.hash in self.heads and share.previous_hash in self.tails:
-            tail = self.heads.pop(share.hash)
-            self.tails[tail].remove(share.hash)
-            if not self.tails[share.previous_hash]:
-                self.tails.pop(share.previous_hash)
-        elif share.hash in self.heads:
-            tail = self.heads.pop(share.hash)
-            self.tails[tail].remove(share.hash)
-            if self.reverse_shares[share.previous_hash] != set([share.hash]):
+        item = self.items[item_hash]
+        del item_hash
+        
+        delta = self._delta_type.from_element(item)
+        
+        children = self.reverse.get(delta.head, set())
+        
+        if delta.head in self.heads and delta.tail in self.tails:
+            tail = self.heads.pop(delta.head)
+            self.tails[tail].remove(delta.head)
+            if not self.tails[delta.tail]:
+                self.tails.pop(delta.tail)
+        elif delta.head in self.heads:
+            tail = self.heads.pop(delta.head)
+            self.tails[tail].remove(delta.head)
+            if self.reverse[delta.tail] != set([delta.head]):
                 pass # has sibling
             else:
-                self.tails[tail].add(share.previous_hash)
-                self.heads[share.previous_hash] = tail
-        elif share.previous_hash in self.tails and len(self.reverse_shares[share.previous_hash]) <= 1:
-            heads = self.tails.pop(share.previous_hash)
+                self.tails[tail].add(delta.tail)
+                self.heads[delta.tail] = tail
+        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
+            heads = self.tails.pop(delta.tail)
             for head in heads:
-                self.heads[head] = share.hash
-            self.tails[share.hash] = set(heads)
+                self.heads[head] = delta.head
+            self.tails[delta.head] = set(heads)
             
-            # move ref pointing to this up
-            if share.previous_hash in self.reverse_height_refs:
-                assert share.hash not in self.reverse_height_refs, list(self.reverse_heights.get(self.reverse_height_refs.get(share.hash, object()), set()))
-                
-                ref = self.reverse_height_refs[share.previous_hash]
-                cur_height, cur_hash, cur_work = self.height_refs[ref]
-                assert cur_hash == share.previous_hash
-                self.height_refs[ref] = cur_height - 1, share.hash, cur_work - bitcoin_data.target_to_average_attempts(share.target)
-                del self.reverse_height_refs[share.previous_hash]
-                self.reverse_height_refs[share.hash] = ref
+            self.remove_special.happened(item)
         else:
             raise NotImplementedError()
         
-        # delete height entry, and ref if it is empty
-        if share.hash in self.heights:
-            _, ref, _ = self.heights.pop(share.hash)
-            self.reverse_heights[ref].remove(share.hash)
-            if not self.reverse_heights[ref]:
-                del self.reverse_heights[ref]
-                _, ref_hash, _ = self.height_refs.pop(ref)
-                del self.reverse_height_refs[ref_hash]
-        
-        self.shares.pop(share.hash)
-        self.reverse_shares[share.previous_hash].remove(share.hash)
-        if not self.reverse_shares[share.previous_hash]:
-            self.reverse_shares.pop(share.previous_hash)
-        
-        self.removed.happened(share)
-    
-    def get_height(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
-        return height
-    
-    def get_work(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
-        return work
-    
-    def get_last(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
-        return last
-    
-    def get_height_and_last(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
-        return height, last
-    
-    def _get_height_jump(self, share_hash):
-        if share_hash in self.heights:
-            height_to1, ref, work_inc1 = self.heights[share_hash]
-            height_to2, share_hash, work_inc2 = self.height_refs[ref]
-            height_inc = height_to1 + height_to2
-            work_inc = work_inc1 + work_inc2
-        else:
-            height_inc, share_hash, work_inc = 1, self.shares[share_hash].previous_hash, bitcoin_data.target_to_average_attempts(self.shares[share_hash].target)
-        return height_inc, share_hash, work_inc
-    
-    def _set_height_jump(self, share_hash, height_inc, other_share_hash, work_inc):
-        if other_share_hash not in self.reverse_height_refs:
-            ref = self.ref_generator.next()
-            assert ref not in self.height_refs
-            self.height_refs[ref] = 0, other_share_hash, 0
-            self.reverse_height_refs[other_share_hash] = ref
-            del ref
+        self.items.pop(delta.head)
+        self.reverse[delta.tail].remove(delta.head)
+        if not self.reverse[delta.tail]:
+            self.reverse.pop(delta.tail)
         
-        ref = self.reverse_height_refs[other_share_hash]
-        ref_height_to, ref_share_hash, ref_work_inc = self.height_refs[ref]
-        assert ref_share_hash == other_share_hash
-        
-        if share_hash in self.heights:
-            prev_ref = self.heights[share_hash][1]
-            self.reverse_heights[prev_ref].remove(share_hash)
-            if not self.reverse_heights[prev_ref] and prev_ref != ref:
-                self.reverse_heights.pop(prev_ref)
-                _, x, _ = self.height_refs.pop(prev_ref)
-                self.reverse_height_refs.pop(x)
-        self.heights[share_hash] = height_inc - ref_height_to, ref, work_inc - ref_work_inc
-        self.reverse_heights.setdefault(ref, set()).add(share_hash)
+        self.removed.happened(item)
     
-    def get_height_work_and_last(self, share_hash):
-        assert isinstance(share_hash, (int, long, type(None)))
-        height = 0
-        work = 0
-        updates = []
-        while share_hash in self.shares:
-            updates.append((share_hash, height, work))
-            height_inc, share_hash, work_inc = self._get_height_jump(share_hash)
-            height += height_inc
-            work += work_inc
-        for update_hash, height_then, work_then in updates:
-            self._set_height_jump(update_hash, height - height_then, share_hash, work - work_then)
-        return height, work, share_hash
-    
-    def get_chain_known(self, start_hash):
-        assert isinstance(start_hash, (int, long, type(None)))
-        '''
-        Chain starting with item of hash I{start_hash} of items that this Tracker contains
-        '''
-        item_hash_to_get = start_hash
-        while True:
-            if item_hash_to_get not in self.shares:
-                break
-            share = self.shares[item_hash_to_get]
-            assert not isinstance(share, long)
-            yield share
-            item_hash_to_get = share.previous_hash
-    
-    def get_chain_to_root(self, start_hash, root=None):
-        assert isinstance(start_hash, (int, long, type(None)))
-        assert isinstance(root, (int, long, type(None)))
-        '''
-        Chain of hashes starting with share_hash of shares to the root (doesn't include root)
-        Raises an error if one is missing
-        '''
-        share_hash_to_get = start_hash
-        while share_hash_to_get != root:
-            share = self.shares[share_hash_to_get]
-            yield share
-            share_hash_to_get = share.previous_hash
-    
-    def get_best_hash(self):
-        '''
-        Returns hash of item with the most items in its chain
-        '''
-        if not self.heads:
-            return None
-        return max(self.heads, key=self.get_height_and_last)
+    def get_chain(self, start_hash, length):
+        assert length <= self.get_height(start_hash)
+        for i in xrange(length):
+            yield self.items[start_hash]
+            start_hash = self._delta_type.from_element(self.items[start_hash]).tail
     
-    def get_highest_height(self):
-        return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0
-    
-    def is_child_of(self, share_hash, possible_child_hash):
-        height, last = self.get_height_and_last(share_hash)
+    def is_child_of(self, item_hash, possible_child_hash):
+        height, last = self.get_height_and_last(item_hash)
         child_height, child_last = self.get_height_and_last(possible_child_hash)
         if child_last != last:
             return None # not connected, so can't be determined
         height_up = child_height - height
-        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == share_hash
+        return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == item_hash
+
+class SubsetTracker(Tracker):
+    def __init__(self, subset_of, **kwargs):
+        Tracker.__init__(self, **kwargs)
+        self.get_nth_parent_hash = subset_of.get_nth_parent_hash # overwrites Tracker.__init__'s
+        self._subset_of = subset_of
+    
+    def add(self, item):
+        delta = self._delta_type.from_element(item)
+        if self._subset_of is not None:
+            assert delta.head in self._subset_of.items
+        Tracker.add(self, item)
+    
+    def remove(self, item_hash):
+        if self._subset_of is not None:
+            assert item_hash in self._subset_of.items
+        Tracker.remove(self, item_hash)