split Tracker's subset_of handling into separate SubsetTracker class
[p2pool.git] / p2pool / util / forest.py
index 666f289..245e909 100644 (file)
@@ -17,7 +17,7 @@ class TrackerSkipList(skiplist.SkipList):
         watch_id = self.tracker.removed.watch(lambda item: self_ref().forget_item(item.hash))
     
     def previous(self, element):
-        return self.tracker.delta_type.from_element(self.tracker.items[element]).tail
+        return self.tracker._delta_type.from_element(self.tracker.items[element]).tail
 
 
 class DistanceSkipList(TrackerSkipList):
@@ -88,39 +88,33 @@ AttributeDelta = get_attributedelta_type(dict(
 ))
 
 class Tracker(object):
-    def __init__(self, items=[], delta_type=AttributeDelta, subset_of=None):
+    def __init__(self, items=[], delta_type=AttributeDelta):
         self.items = {} # hash -> item
         self.reverse = {} # delta.tail -> set of item_hashes
         
         self.heads = {} # head hash -> tail_hash
         self.tails = {} # tail hash -> set of head hashes
         
-        self.deltas = {} # item_hash -> delta, ref
-        self.reverse_deltas = {} # ref -> set of item_hashes
+        self._deltas = {} # item_hash -> delta, ref
+        self._reverse_deltas = {} # ref -> set of item_hashes
         
-        self.ref_generator = itertools.count()
-        self.delta_refs = {} # ref -> delta
-        self.reverse_delta_refs = {} # delta.tail -> ref
+        self._ref_generator = itertools.count()
+        self._delta_refs = {} # ref -> delta
+        self._reverse_delta_refs = {} # delta.tail -> ref
         
         self.added = variable.Event()
         self.removed = variable.Event()
         
-        if subset_of is None:
-            self.get_nth_parent_hash = DistanceSkipList(self)
-        else:
-            self.get_nth_parent_hash = subset_of.get_nth_parent_hash
+        self.get_nth_parent_hash = DistanceSkipList(self)
         
-        self.delta_type = delta_type
-        self.subset_of = subset_of
+        self._delta_type = delta_type
         
         for item in items:
             self.add(item)
     
     def add(self, item):
         assert not isinstance(item, (int, long, type(None)))
-        delta = self.delta_type.from_element(item)
-        if self.subset_of is not None:
-            assert delta.head in self.subset_of.items
+        delta = self._delta_type.from_element(item)
         
         if delta.head in self.items:
             raise ValueError('item already present')
@@ -151,13 +145,11 @@ class Tracker(object):
         assert isinstance(item_hash, (int, long, type(None)))
         if item_hash not in self.items:
             raise KeyError()
-        if self.subset_of is not None:
-            assert item_hash in self.subset_of.items
         
         item = self.items[item_hash]
         del item_hash
         
-        delta = self.delta_type.from_element(item)
+        delta = self._delta_type.from_element(item)
         
         children = self.reverse.get(delta.head, set())
         
@@ -176,10 +168,10 @@ class Tracker(object):
                 self.heads[delta.tail] = tail
         elif delta.tail in self.tails and len(self.reverse[delta.tail]) <= 1:
             # move delta refs referencing children down to this, so they can be moved up in one step
-            if delta.tail in self.reverse_delta_refs:
-                for x in list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set())):
+            if delta.tail in self._reverse_delta_refs:
+                for x in list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set())):
                     self.get_last(x)
-                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, None), set()))
+                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, None), set()))
             
             heads = self.tails.pop(delta.tail)
             for head in heads:
@@ -187,27 +179,27 @@ class Tracker(object):
             self.tails[delta.head] = set(heads)
             
             # move ref pointing to this up
-            if delta.tail in self.reverse_delta_refs:
-                assert delta.head not in self.reverse_delta_refs, list(self.reverse_deltas.get(self.reverse_delta_refs.get(delta.head, object()), set()))
+            if delta.tail in self._reverse_delta_refs:
+                assert delta.head not in self._reverse_delta_refs, list(self._reverse_deltas.get(self._reverse_delta_refs.get(delta.head, object()), set()))
                 
-                ref = self.reverse_delta_refs[delta.tail]
-                cur_delta = self.delta_refs[ref]
+                ref = self._reverse_delta_refs[delta.tail]
+                cur_delta = self._delta_refs[ref]
                 assert cur_delta.tail == delta.tail
-                self.delta_refs[ref] = cur_delta - self.delta_type.from_element(item)
-                assert self.delta_refs[ref].tail == delta.head
-                del self.reverse_delta_refs[delta.tail]
-                self.reverse_delta_refs[delta.head] = ref
+                self._delta_refs[ref] = cur_delta - self._delta_type.from_element(item)
+                assert self._delta_refs[ref].tail == delta.head
+                del self._reverse_delta_refs[delta.tail]
+                self._reverse_delta_refs[delta.head] = ref
         else:
             raise NotImplementedError()
         
         # delete delta entry and ref if it is empty
-        if delta.head in self.deltas:
-            delta1, ref = self.deltas.pop(delta.head)
-            self.reverse_deltas[ref].remove(delta.head)
-            if not self.reverse_deltas[ref]:
-                del self.reverse_deltas[ref]
-                delta2 = self.delta_refs.pop(ref)
-                del self.reverse_delta_refs[delta2.tail]
+        if delta.head in self._deltas:
+            delta1, ref = self._deltas.pop(delta.head)
+            self._reverse_deltas[ref].remove(delta.head)
+            if not self._reverse_deltas[ref]:
+                del self._reverse_deltas[ref]
+                delta2 = self._delta_refs.pop(ref)
+                del self._reverse_delta_refs[delta2.tail]
         
         self.items.pop(delta.head)
         self.reverse[delta.tail].remove(delta.head)
@@ -230,41 +222,41 @@ class Tracker(object):
         return delta.height, delta.tail
     
     def _get_delta(self, item_hash):
-        if item_hash in self.deltas:
-            delta1, ref = self.deltas[item_hash]
-            delta2 = self.delta_refs[ref]
+        if item_hash in self._deltas:
+            delta1, ref = self._deltas[item_hash]
+            delta2 = self._delta_refs[ref]
             res = delta1 + delta2
         else:
-            res = self.delta_type.from_element(self.items[item_hash])
+            res = self._delta_type.from_element(self.items[item_hash])
         assert res.head == item_hash
         return res
     
     def _set_delta(self, item_hash, delta):
         other_item_hash = delta.tail
-        if other_item_hash not in self.reverse_delta_refs:
-            ref = self.ref_generator.next()
-            assert ref not in self.delta_refs
-            self.delta_refs[ref] = self.delta_type.get_none(other_item_hash)
-            self.reverse_delta_refs[other_item_hash] = ref
+        if other_item_hash not in self._reverse_delta_refs:
+            ref = self._ref_generator.next()
+            assert ref not in self._delta_refs
+            self._delta_refs[ref] = self._delta_type.get_none(other_item_hash)
+            self._reverse_delta_refs[other_item_hash] = ref
             del ref
         
-        ref = self.reverse_delta_refs[other_item_hash]
-        ref_delta = self.delta_refs[ref]
+        ref = self._reverse_delta_refs[other_item_hash]
+        ref_delta = self._delta_refs[ref]
         assert ref_delta.tail == other_item_hash
         
-        if item_hash in self.deltas:
-            prev_ref = self.deltas[item_hash][1]
-            self.reverse_deltas[prev_ref].remove(item_hash)
-            if not self.reverse_deltas[prev_ref] and prev_ref != ref:
-                self.reverse_deltas.pop(prev_ref)
-                x = self.delta_refs.pop(prev_ref)
-                self.reverse_delta_refs.pop(x.tail)
-        self.deltas[item_hash] = delta - ref_delta, ref
-        self.reverse_deltas.setdefault(ref, set()).add(item_hash)
+        if item_hash in self._deltas:
+            prev_ref = self._deltas[item_hash][1]
+            self._reverse_deltas[prev_ref].remove(item_hash)
+            if not self._reverse_deltas[prev_ref] and prev_ref != ref:
+                self._reverse_deltas.pop(prev_ref)
+                x = self._delta_refs.pop(prev_ref)
+                self._reverse_delta_refs.pop(x.tail)
+        self._deltas[item_hash] = delta - ref_delta, ref
+        self._reverse_deltas.setdefault(ref, set()).add(item_hash)
     
     def get_delta_to_last(self, item_hash):
         assert isinstance(item_hash, (int, long, type(None)))
-        delta = self.delta_type.get_none(item_hash)
+        delta = self._delta_type.get_none(item_hash)
         updates = []
         while delta.tail in self.items:
             updates.append((delta.tail, delta))
@@ -282,7 +274,7 @@ class Tracker(object):
         assert length <= self.get_height(start_hash)
         for i in xrange(length):
             yield self.items[start_hash]
-            start_hash = self.delta_type.from_element(self.items[start_hash]).tail
+            start_hash = self._delta_type.from_element(self.items[start_hash]).tail
     
     def is_child_of(self, item_hash, possible_child_hash):
         height, last = self.get_height_and_last(item_hash)
@@ -291,3 +283,20 @@ class Tracker(object):
             return None # not connected, so can't be determined
         height_up = child_height - height
         return height_up >= 0 and self.get_nth_parent_hash(possible_child_hash, height_up) == item_hash
+
+class SubsetTracker(Tracker):
+    def __init__(self, subset_of, **kwargs):
+        Tracker.__init__(self, **kwargs)
+        self.get_nth_parent_hash = subset_of.get_nth_parent_hash # overwrites Tracker.__init__'s
+        self._subset_of = subset_of
+    
+    def add(self, item):
+        delta = self._delta_type.from_element(item)
+        if self._subset_of is not None:
+            assert delta.head in self._subset_of.items
+        Tracker.add(self, item)
+    
+    def remove(self, item_hash):
+        if self._subset_of is not None:
+            assert item_hash in self._subset_of.items
+        Tracker.remove(self, item_hash)