merged preremove_special and postremove_special since they didn't need to be separate
[p2pool.git] / p2pool / util / forest.py
index c9fe9e7..96d4b78 100644 (file)
@@ -3,7 +3,6 @@ forest data structure
 '''
 
 import itertools
-import weakref
 
 from p2pool.util import skiplist, variable
 
@@ -13,8 +12,7 @@ class TrackerSkipList(skiplist.SkipList):
         skiplist.SkipList.__init__(self)
         self.tracker = tracker
         
-        self_ref = weakref.ref(self, lambda _: tracker.removed.unwatch(watch_id))
-        watch_id = self.tracker.removed.watch(lambda item: self_ref().forget_item(item.hash))
+        self.tracker.removed.watch_weakref(self, lambda self, item: self.forget_item(item.hash))
     
     def previous(self, element):
         return self.tracker._delta_type.from_element(self.tracker.items[element]).tail
@@ -87,13 +85,10 @@ AttributeDelta = get_attributedelta_type(dict(
     height=lambda item: 1,
 ))
 
-class Tracker(object):
-    def __init__(self, items=[], delta_type=AttributeDelta, subset_of=None):
-        self.items = {} # hash -> item
-        self.reverse = {} # delta.tail -> set of item_hashes
-        
-        self.heads = {} # head hash -> tail_hash
-        self.tails = {} # tail hash -> set of head hashes
+class TrackerView(object):
+    def __init__(self, tracker, delta_type):
+        self._tracker = tracker
+        self._delta_type = delta_type
         
         self._deltas = {} # item_hash -> delta, ref
         self._reverse_deltas = {} # ref -> set of item_hashes
@@ -102,104 +97,37 @@ class Tracker(object):
         self._delta_refs = {} # ref -> delta
         self._reverse_delta_refs = {} # delta.tail -> ref
         
-        self.added = variable.Event()
-        self.removed = variable.Event()
-        
-        if subset_of is None:
-            self.get_nth_parent_hash = DistanceSkipList(self)
-        else:
-            self.get_nth_parent_hash = subset_of.get_nth_parent_hash
-        
-        self._delta_type = delta_type
-        self._subset_of = subset_of
-        
-        for item in items:
-            self.add(item)
+        self._tracker.remove_special.watch_weakref(self, lambda self, item: self._handle_remove_special(item))
+        self._tracker.removed.watch_weakref(self, lambda self, item: self._handle_removed(item))
     
-    def add(self, item):
-        assert not isinstance(item, (int, long, type(None)))
+    def _handle_remove_special(self, item):
         delta = self._delta_type.from_element(item)
-        if self._subset_of is not None:
-            assert delta.head in self._subset_of.items
-        
-        if delta.head in self.items:
-            raise ValueError('item already present')
         
-        if delta.head in self.tails:
-            heads = self.tails.pop(delta.head)
-        else:
-            heads = set([delta.head])
+        if delta.tail not in self._reverse_delta_refs:
+            return
         
-        if delta.tail in self.heads:
-            tail = self.heads.pop(delta.tail)
-        else:
-            tail = self.get_last(delta.tail)
+        # move delta refs referencing children down to this, so they can be moved up in one step
+        for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
+            self.get_last(x)
         
-        self.items[delta.head] = item
-        self.reverse.setdefault(delta.tail, set()).add(delta.head)
+        assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
         
-        self.tails.setdefault(tail, set()).update(heads)
-        if delta.tail in self.tails[tail]:
-            self.tails[tail].remove(delta.tail)
+        if delta.tail not in self._reverse_delta_refs:
+            return
         
-        for head in heads:
-            self.heads[head] = tail
+        # move ref pointing to this up
         
-        self.added.happened(item)
+        ref = self._reverse_delta_refs[delta.tail]
+        cur_delta = self._delta_refs[ref]
+        assert cur_delta.tail == delta.tail
+        self._delta_refs[ref] = cur_delta - self._delta_type.from_element(item)
+        assert self._delta_refs[ref].tail == delta.head
+        del self._reverse_delta_refs[delta.tail]
+        self._reverse_delta_refs[delta.head] = ref
     
-    def remove(self, item_hash):
-        assert isinstance(item_hash, (int, long, type(None)))
-        if item_hash not in self.items:
-            raise KeyError()
-        if self._subset_of is not None:
-            assert item_hash in self._subset_of.items
-        
-        item = self.items[item_hash]
-        del item_hash
-        
+    def _handle_removed(self, item):
         delta = self._delta_type.from_element(item)
         
-        children = self.reverse.get(delta.head, set())
-        
-        if delta.head in self.heads and delta.tail in self.tails:
-            tail = self.heads.pop(delta.head)
-            self.tails[tail].remove(delta.head)
-            if not self.tails[delta.tail]:
-                self.tails.pop(delta.tail)
-        elif delta.head in self.heads:
-            tail = self.heads.pop(delta.head)
-            self.tails[tail].remove(delta.head)
-            if self.reverse[delta.tail] != set([delta.head]):
-                pass # has sibling
-            else:
-                self.tails[tail].add(delta.tail)
-                self.heads[delta.tail] = tail
-        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
-            # move delta refs referencing children down to this, so they can be moved up in one step
-            if delta.tail in self._reverse_delta_refs:
-                for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
-                    self.get_last(x)
-                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, None), set()))
-            
-            heads = self.tails.pop(delta.tail)
-            for head in heads:
-                self.heads[head] = delta.head
-            self.tails[delta.head] = set(heads)
-            
-            # move ref pointing to this up
-            if delta.tail in self._reverse_delta_refs:
-                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
-                
-                ref = self._reverse_delta_refs[delta.tail]
-                cur_delta = self._delta_refs[ref]
-                assert cur_delta.tail == delta.tail
-                self._delta_refs[ref] = cur_delta - self._delta_type.from_element(item)
-                assert self._delta_refs[ref].tail == delta.head
-                del self._reverse_delta_refs[delta.tail]
-                self._reverse_delta_refs[delta.head] = ref
-        else:
-            raise NotImplementedError()
-        
         # delete delta entry and ref if it is empty
         if delta.head in self._deltas:
             delta1, ref = self._deltas.pop(delta.head)
@@ -208,13 +136,7 @@ class Tracker(object):
                 del self._reverse_deltas[ref]
                 delta2 = self._delta_refs.pop(ref)
                 del self._reverse_delta_refs[delta2.tail]
-        
-        self.items.pop(delta.head)
-        self.reverse[delta.tail].remove(delta.head)
-        if not self.reverse[delta.tail]:
-            self.reverse.pop(delta.tail)
-        
-        self.removed.happened(item)
+    
     
     def get_height(self, item_hash):
         return self.get_delta_to_last(item_hash).height
@@ -235,7 +157,7 @@ class Tracker(object):
             delta2 = self._delta_refs[ref]
             res = delta1 + delta2
         else:
-            res = self._delta_type.from_element(self.items[item_hash])
+            res = self._delta_type.from_element(self._tracker.items[item_hash])
         assert res.head == item_hash
         return res
     
@@ -266,7 +188,7 @@ class Tracker(object):
         assert isinstance(item_hash, (int, long, type(None)))
         delta = self._delta_type.get_none(item_hash)
         updates = []
-        while delta.tail in self.items:
+        while delta.tail in self._tracker.items:
             updates.append((delta.tail, delta))
             this_delta = self._get_delta(delta.tail)
             delta += this_delta
@@ -275,8 +197,104 @@ class Tracker(object):
         return delta
     
     def get_delta(self, item, ancestor):
-        assert self.is_child_of(ancestor, item)
+        assert self._tracker.is_child_of(ancestor, item)
         return self.get_delta_to_last(item) - self.get_delta_to_last(ancestor)
+
+class Tracker(object):
+    def __init__(self, items=[], delta_type=AttributeDelta):
+        self.items = {} # hash -> item
+        self.reverse = {} # delta.tail -> set of item_hashes
+        
+        self.heads = {} # head hash -> tail_hash
+        self.tails = {} # tail hash -> set of head hashes
+        
+        self.added = variable.Event()
+        self.remove_special = variable.Event()
+        self.removed = variable.Event()
+        
+        self.get_nth_parent_hash = DistanceSkipList(self)
+        
+        self._delta_type = delta_type
+        self._default_view = TrackerView(self, delta_type)
+        
+        for item in items:
+            self.add(item)
+    
+    def __getattr__(self, name):
+        attr = getattr(self._default_view, name)
+        setattr(self, name, attr)
+        return attr
+    
+    def add(self, item):
+        assert not isinstance(item, (int, long, type(None)))
+        delta = self._delta_type.from_element(item)
+        
+        if delta.head in self.items:
+            raise ValueError('item already present')
+        
+        if delta.head in self.tails:
+            heads = self.tails.pop(delta.head)
+        else:
+            heads = set([delta.head])
+        
+        if delta.tail in self.heads:
+            tail = self.heads.pop(delta.tail)
+        else:
+            tail = self.get_last(delta.tail)
+        
+        self.items[delta.head] = item
+        self.reverse.setdefault(delta.tail, set()).add(delta.head)
+        
+        self.tails.setdefault(tail, set()).update(heads)
+        if delta.tail in self.tails[tail]:
+            self.tails[tail].remove(delta.tail)
+        
+        for head in heads:
+            self.heads[head] = tail
+        
+        self.added.happened(item)
+    
+    def remove(self, item_hash):
+        assert isinstance(item_hash, (int, long, type(None)))
+        if item_hash not in self.items:
+            raise KeyError()
+        
+        item = self.items[item_hash]
+        del item_hash
+        
+        delta = self._delta_type.from_element(item)
+        
+        children = self.reverse.get(delta.head, set())
+        
+        if delta.head in self.heads and delta.tail in self.tails:
+            tail = self.heads.pop(delta.head)
+            self.tails[tail].remove(delta.head)
+            if not self.tails[delta.tail]:
+                self.tails.pop(delta.tail)
+        elif delta.head in self.heads:
+            tail = self.heads.pop(delta.head)
+            self.tails[tail].remove(delta.head)
+            if self.reverse[delta.tail] != set([delta.head]):
+                pass # has sibling
+            else:
+                self.tails[tail].add(delta.tail)
+                self.heads[delta.tail] = tail
+        elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
+            heads = self.tails.pop(delta.tail)
+            for head in heads:
+                self.heads[head] = delta.head
+            self.tails[delta.head] = set(heads)
+            
+            self.remove_special.happened(item)
+        else:
+            raise NotImplementedError()
+        
+        self.items.pop(delta.head)
+        self.reverse[delta.tail].remove(delta.head)
+        if not self.reverse[delta.tail]:
+            self.reverse.pop(delta.tail)
+        
+        self.removed.happened(item)
     
     def get_chain(self, start_hash, length):
         assert length <= self.get_height(start_hash)
@@ -291,3 +309,20 @@ class Tracker(object):
             return None # not connected, so can't be determined
         height_up = child_height - height
         return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == item_hash
+
+class SubsetTracker(Tracker):
+    def __init__(self, subset_of, **kwargs):
+        Tracker.__init__(self, **kwargs)
+        self.get_nth_parent_hash = subset_of.get_nth_parent_hash # overwrites Tracker.__init__'s
+        self._subset_of = subset_of
+    
+    def add(self, item):
+        delta = self._delta_type.from_element(item)
+        if self._subset_of is not None:
+            assert delta.head in self._subset_of.items
+        Tracker.add(self, item)
+    
+    def remove(self, item_hash):
+        if self._subset_of is not None:
+            assert item_hash in self._subset_of.items
+        Tracker.remove(self, item_hash)