optimized a few routines' time efficiency
[p2pool.git] / p2pool / util / forest.py
index 5114938..057c6b8 100644 (file)
@@ -59,6 +59,14 @@ def get_attributedelta_type(attrs): # attrs: {name: func}
         def from_element(cls, item):
             return cls(item.hash, item.previous_hash, **dict((k, v(item)) for k, v in attrs.iteritems()))
         
+        @staticmethod
+        def get_head(item):
+            return item.hash
+        
+        @staticmethod
+        def get_tail(item):
+            return item.previous_hash
+        
         def __init__(self, head, tail, **kwargs):
             self.head, self.tail = head, tail
             for k, v in kwargs.iteritems():
@@ -85,13 +93,10 @@ AttributeDelta = get_attributedelta_type(dict(
     height=lambda item: 1,
 ))
 
-class Tracker(object):
-    def __init__(self, items=[], delta_type=AttributeDelta):
-        self.items = {} # hash -> item
-        self.reverse = {} # delta.tail -> set of item_hashes
-        
-        self.heads = {} # head hash -> tail_hash
-        self.tails = {} # tail hash -> set of head hashes
+class TrackerView(object):
+    def __init__(self, tracker, delta_type):
+        self._tracker = tracker
+        self._delta_type = delta_type
         
         self._deltas = {} # item_hash -> delta, ref
         self._reverse_deltas = {} # ref -> set of item_hashes
@@ -100,95 +105,49 @@ class Tracker(object):
         self._delta_refs = {} # ref -> delta
         self._reverse_delta_refs = {} # delta.tail -> ref
         
-        self.added = variable.Event()
-        self.removed = variable.Event()
-        
-        self.get_nth_parent_hash = DistanceSkipList(self)
-        
-        self._delta_type = delta_type
-        
-        for item in items:
-            self.add(item)
+        self._tracker.remove_special.watch_weakref(self, lambda self, item: self._handle_remove_special(item))
+        self._tracker.remove_special2.watch_weakref(self, lambda self, item: self._handle_remove_special2(item))
+        self._tracker.removed.watch_weakref(self, lambda self, item: self._handle_removed(item))
     
-    def add(self, item):
-        assert not isinstance(item, (int, long, type(None)))
+    def _handle_remove_special(self, item):
         delta = self._delta_type.from_element(item)
         
-        if delta.head in self.items:
-            raise ValueError('item already present')
+        if delta.tail not in self._reverse_delta_refs:
+            return
         
-        if delta.head in self.tails:
-            heads = self.tails.pop(delta.head)
-        else:
-            heads = set([delta.head])
+        # move delta refs referencing children down to this, so they can be moved up in one step
+        for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
+            self.get_last(x)
         
-        if delta.tail in self.heads:
-            tail = self.heads.pop(delta.tail)
-        else:
-            tail = self.get_last(delta.tail)
-        
-        self.items[delta.head] = item
-        self.reverse.setdefault(delta.tail, set()).add(delta.head)
+        assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
         
-        self.tails.setdefault(tail, set()).update(heads)
-        if delta.tail in self.tails[tail]:
-            self.tails[tail].remove(delta.tail)
+        if delta.tail not in self._reverse_delta_refs:
+            return
         
-        for head in heads:
-            self.heads[head] = tail
+        # move ref pointing to this up
         
-        self.added.happened(item)
+        ref = self._reverse_delta_refs[delta.tail]
+        cur_delta = self._delta_refs[ref]
+        assert cur_delta.tail == delta.tail
+        self._delta_refs[ref] = cur_delta - delta
+        assert self._delta_refs[ref].tail == delta.head
+        del self._reverse_delta_refs[delta.tail]
+        self._reverse_delta_refs[delta.head] = ref
     
-    def remove(self, item_hash):
-        assert isinstance(item_hash, (int, long, type(None)))
-        if item_hash not in self.items:
-            raise KeyError()
-        
-        item = self.items[item_hash]
-        del item_hash
-        
+    def _handle_remove_special2(self, item):
         delta = self._delta_type.from_element(item)
         
-        children = self.reverse.get(delta.head, set())
+        if delta.tail not in self._reverse_delta_refs:
+            return
         
-        if delta.head in self.heads and delta.tail in self.tails:
-            tail = self.heads.pop(delta.head)
-            self.tails[tail].remove(delta.head)
-            if not self.tails[delta.tail]:
-                self.tails.pop(delta.tail)
-        elif delta.head in self.heads:
-            tail = self.heads.pop(delta.head)
-            self.tails[tail].remove(delta.head)
-            if self.reverse[delta.tail] != set([delta.head]):
-                pass # has sibling
-            else:
-                self.tails[tail].add(delta.tail)
-                self.heads[delta.tail] = tail
-        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
-            # move delta refs referencing children down to this, so they can be moved up in one step
-            if delta.tail in self._reverse_delta_refs:
-                for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
-                    self.get_last(x)
-                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, None), set()))
-            
-            heads = self.tails.pop(delta.tail)
-            for head in heads:
-                self.heads[head] = delta.head
-            self.tails[delta.head] = set(heads)
-            
-            # move ref pointing to this up
-            if delta.tail in self._reverse_delta_refs:
-                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
-                
-                ref = self._reverse_delta_refs[delta.tail]
-                cur_delta = self._delta_refs[ref]
-                assert cur_delta.tail == delta.tail
-                self._delta_refs[ref] = cur_delta - self._delta_type.from_element(item)
-                assert self._delta_refs[ref].tail == delta.head
-                del self._reverse_delta_refs[delta.tail]
-                self._reverse_delta_refs[delta.head] = ref
-        else:
-            raise NotImplementedError()
+        ref = self._reverse_delta_refs.pop(delta.tail)
+        del self._delta_refs[ref]
+        
+        for x in self._reverse_deltas.pop(ref):
+            del self._deltas[x]
+    
+    def _handle_removed(self, item):
+        delta = self._delta_type.from_element(item)
         
         # delete delta entry and ref if it is empty
         if delta.head in self._deltas:
@@ -198,13 +157,7 @@ class Tracker(object):
                 del self._reverse_deltas[ref]
                 delta2 = self._delta_refs.pop(ref)
                 del self._reverse_delta_refs[delta2.tail]
-        
-        self.items.pop(delta.head)
-        self.reverse[delta.tail].remove(delta.head)
-        if not self.reverse[delta.tail]:
-            self.reverse.pop(delta.tail)
-        
-        self.removed.happened(item)
+    
     
     def get_height(self, item_hash):
         return self.get_delta_to_last(item_hash).height
@@ -225,7 +178,7 @@ class Tracker(object):
             delta2 = self._delta_refs[ref]
             res = delta1 + delta2
         else:
-            res = self._delta_type.from_element(self.items[item_hash])
+            res = self._delta_type.from_element(self._tracker.items[item_hash])
         assert res.head == item_hash
         return res
     
@@ -256,7 +209,7 @@ class Tracker(object):
         assert isinstance(item_hash, (int, long, type(None)))
         delta = self._delta_type.get_none(item_hash)
         updates = []
-        while delta.tail in self.items:
+        while delta.tail in self._tracker.items:
             updates.append((delta.tail, delta))
             this_delta = self._get_delta(delta.tail)
             delta += this_delta
@@ -265,14 +218,123 @@ class Tracker(object):
         return delta
     
     def get_delta(self, item, ancestor):
-        assert self.is_child_of(ancestor, item)
+        assert self._tracker.is_child_of(ancestor, item)
         return self.get_delta_to_last(item) - self.get_delta_to_last(ancestor)
+
+class Tracker(object):
+    def __init__(self, items=[], delta_type=AttributeDelta):
+        self.items = {} # hash -> item
+        self.reverse = {} # delta.tail -> set of item_hashes
+        
+        self.heads = {} # head hash -> tail_hash
+        self.tails = {} # tail hash -> set of head hashes
+        
+        self.added = variable.Event()
+        self.remove_special = variable.Event()
+        self.remove_special2 = variable.Event()
+        self.removed = variable.Event()
+        
+        self.get_nth_parent_hash = DistanceSkipList(self)
+        
+        self._delta_type = delta_type
+        self._default_view = TrackerView(self, delta_type)
+        
+        for item in items:
+            self.add(item)
+    
+    def __getattr__(self, name):
+        attr = getattr(self._default_view, name)
+        setattr(self, name, attr)
+        return attr
+    
+    def add(self, item):
+        assert not isinstance(item, (int, long, type(None)))
+        delta = self._delta_type.from_element(item)
+        
+        if delta.head in self.items:
+            raise ValueError('item already present')
+        
+        if delta.head in self.tails:
+            heads = self.tails.pop(delta.head)
+        else:
+            heads = set([delta.head])
+        
+        if delta.tail in self.heads:
+            tail = self.heads.pop(delta.tail)
+        else:
+            tail = self.get_last(delta.tail)
+        
+        self.items[delta.head] = item
+        self.reverse.setdefault(delta.tail, set()).add(delta.head)
+        
+        self.tails.setdefault(tail, set()).update(heads)
+        if delta.tail in self.tails[tail]:
+            self.tails[tail].remove(delta.tail)
+        
+        for head in heads:
+            self.heads[head] = tail
+        
+        self.added.happened(item)
+    
+    def remove(self, item_hash):
+        assert isinstance(item_hash, (int, long, type(None)))
+        if item_hash not in self.items:
+            raise KeyError()
+        
+        item = self.items[item_hash]
+        del item_hash
+        
+        delta = self._delta_type.from_element(item)
+        
+        children = self.reverse.get(delta.head, set())
+        
+        if delta.head in self.heads and delta.tail in self.tails:
+            tail = self.heads.pop(delta.head)
+            self.tails[tail].remove(delta.head)
+            if not self.tails[delta.tail]:
+                self.tails.pop(delta.tail)
+        elif delta.head in self.heads:
+            tail = self.heads.pop(delta.head)
+            self.tails[tail].remove(delta.head)
+            if self.reverse[delta.tail] != set([delta.head]):
+                pass # has sibling
+            else:
+                self.tails[tail].add(delta.tail)
+                self.heads[delta.tail] = tail
+        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
+            heads = self.tails.pop(delta.tail)
+            for head in heads:
+                self.heads[head] = delta.head
+            self.tails[delta.head] = set(heads)
+            
+            self.remove_special.happened(item)
+        elif delta.tail in self.tails and len(self.reverse[delta.tail]) > 1:
+            heads = [x for x in self.tails[delta.tail] if self.is_child_of(delta.head, x)]
+            self.tails[delta.tail] -= set(heads)
+            if not self.tails[delta.tail]:
+                self.tails.pop(delta.tail)
+            for head in heads:
+                self.heads[head] = delta.head
+            assert delta.head not in self.tails
+            self.tails[delta.head] = set(heads)
+            
+            self.remove_special2.happened(item)
+        else:
+            raise NotImplementedError()
+        
+        self.items.pop(delta.head)
+        self.reverse[delta.tail].remove(delta.head)
+        if not self.reverse[delta.tail]:
+            self.reverse.pop(delta.tail)
+        
+        self.removed.happened(item)
     
     def get_chain(self, start_hash, length):
         assert length <= self.get_height(start_hash)
         for i in xrange(length):
-            yield self.items[start_hash]
-            start_hash = self._delta_type.from_element(self.items[start_hash]).tail
+            item = self.items[start_hash]
+            yield item
+            start_hash = self._delta_type.get_tail(item)
     
     def is_child_of(self, item_hash, possible_child_hash):
         height, last = self.get_height_and_last(item_hash)
@@ -289,9 +351,8 @@ class SubsetTracker(Tracker):
         self._subset_of = subset_of
     
     def add(self, item):
-        delta = self._delta_type.from_element(item)
         if self._subset_of is not None:
-            assert delta.head in self._subset_of.items
+            assert self._delta_type.get_head(item) in self._subset_of.items
         Tracker.add(self, item)
     
     def remove(self, item_hash):