removed some dead code in tests
[p2pool.git] / p2pool / test / util / test_forest.py
index 95815f5..3b64187 100644 (file)
@@ -2,113 +2,102 @@ import random
 import unittest
 
 from p2pool.util import forest, math
-from p2pool.bitcoin import data as bitcoin_data
 
 class DumbTracker(object):
-    def __init__(self, shares=[]):
-        self.shares = {} # hash -> share
-        self.reverse_shares = {} # previous_hash -> set of share_hashes
+    def __init__(self, items=[]):
+        self.items = {} # hash -> item
+        self.reverse = {} # previous_hash -> set of item_hashes
         
-        for share in shares:
-            self.add(share)
-    
-    def add(self, share):
-        if share.hash in self.shares:
-            raise ValueError('share already present')
-        self.shares[share.hash] = share
-        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
-    
-    def remove(self, share_hash):
-        share = self.shares[share_hash]
-        del share_hash
+        for item in items:
+            self.add(item)
+    
+    def add(self, item):
+        if item.hash in self.items:
+            raise ValueError('item already present')
+        self.items[item.hash] = item
+        self.reverse.setdefault(item.previous_hash, set()).add(item.hash)
+    
+    def remove(self, item_hash):
+        item = self.items[item_hash]
+        del item_hash
         
-        self.shares.pop(share.hash)
-        self.reverse_shares[share.previous_hash].remove(share.hash)
-        if not self.reverse_shares[share.previous_hash]:
-            self.reverse_shares.pop(share.previous_hash)
+        self.items.pop(item.hash)
+        self.reverse[item.previous_hash].remove(item.hash)
+        if not self.reverse[item.previous_hash]:
+            self.reverse.pop(item.previous_hash)
     
     @property
     def heads(self):
-        return dict((x, self.get_last(x)) for x in self.shares if x not in self.reverse_shares)
+        return dict((x, self.get_last(x)) for x in self.items if x not in self.reverse)
     
     @property
     def tails(self):
-        return dict((x, set(y for y in self.shares if self.get_last(y) == x and y not in self.reverse_shares)) for x in self.reverse_shares if x not in self.shares)
+        return dict((x, set(y for y in self.items if self.get_last(y) == x and y not in self.reverse)) for x in self.reverse if x not in self.items)
     
-    def get_height(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
-        return height
+    def get_nth_parent_hash(self, item_hash, n):
+        for i in xrange(n):
+            item_hash = self.items[item_hash].previous_hash
+        return item_hash
     
-    def get_work(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
-        return work
+    def get_height(self, item_hash):
+        height, last = self.get_height_and_last(item_hash)
+        return height
     
-    def get_last(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
+    def get_last(self, item_hash):
+        height, last = self.get_height_and_last(item_hash)
         return last
     
-    def get_height_and_last(self, share_hash):
-        height, work, last = self.get_height_work_and_last(share_hash)
-        return height, last
-    
-    def get_nth_parent_hash(self, share_hash, n):
-        for i in xrange(n):
-            share_hash = self.shares[share_hash].previous_hash
-        return share_hash
-    
-    def get_height_work_and_last(self, share_hash):
+    def get_height_and_last(self, item_hash):
         height = 0
-        work = 0
-        while share_hash in self.shares:
-            share_hash, work_inc = self.shares[share_hash].previous_hash, bitcoin_data.target_to_average_attempts(self.shares[share_hash].target)
+        while item_hash in self.items:
+            item_hash = self.items[item_hash].previous_hash
             height += 1
-            work += work_inc
-        return height, work, share_hash
+        return height, item_hash
     
     def get_chain(self, start_hash, length):
         # same implementation :/
         assert length <= self.get_height(start_hash)
         for i in xrange(length):
-            yield self.shares[start_hash]
-            start_hash = self.shares[start_hash].previous_hash
+            yield self.items[start_hash]
+            start_hash = self.items[start_hash].previous_hash
     
-    def is_child_of(self, share_hash, possible_child_hash):
-        if self.get_last(share_hash) != self.get_last(possible_child_hash):
+    def is_child_of(self, item_hash, possible_child_hash):
+        if self.get_last(item_hash) != self.get_last(possible_child_hash):
             return None
         while True:
-            if possible_child_hash == share_hash:
+            if possible_child_hash == item_hash:
                 return True
-            if possible_child_hash not in self.shares:
+            if possible_child_hash not in self.items:
                 return False
-            possible_child_hash = self.shares[possible_child_hash].previous_hash
+            possible_child_hash = self.items[possible_child_hash].previous_hash
 
 class FakeShare(object):
-    target = 2**256 - 1
     def __init__(self, **kwargs):
         for k, v in kwargs.iteritems():
             setattr(self, k, v)
         self._attrs = kwargs
-    
-    def __repr__(self):
-        return 'FakeShare(' + ', '.join('%s=%r' % (k, v) for k, v in self._attrs.iteritems()) + ')'
 
 def test_tracker(self):
-    t = DumbTracker(self.shares.itervalues())
+    t = DumbTracker(self.items.itervalues())
     
-    assert self.shares == t.shares, (self.shares, t.shares)
-    assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
+    assert self.items == t.items, (self.items, t.items)
+    assert self.reverse == t.reverse, (self.reverse, t.reverse)
     assert self.heads == t.heads, (self.heads, t.heads)
     assert self.tails == t.tails, (self.tails, t.tails)
     
-    for start in self.shares:
-        a, b = self.get_height_work_and_last(start), t.get_height_work_and_last(start)
+    if random.random() < 0.9:
+        return
+    
+    for start in self.items:
+        a, b = self.get_height_and_last(start), t.get_height_and_last(start)
         assert a == b, (a, b)
         
-        other = random.choice(self.shares.keys())
+        other = random.choice(self.items.keys())
         assert self.is_child_of(start, other) == t.is_child_of(start, other)
         assert self.is_child_of(other, start) == t.is_child_of(other, start)
         
-        assert list(self.get_chain(start, min(a[0], 10))) == list(t.get_chain(start, min(a[0], 10)))
+        length = random.randrange(a[0])
+        assert list(self.get_chain(start, length)) == list(t.get_chain(start, length))
 
 def generate_tracker_simple(n):
     t = forest.Tracker(math.shuffled(FakeShare(hash=i, previous_hash=i - 1 if i > 0 else None) for i in xrange(n)))
@@ -116,23 +105,23 @@ def generate_tracker_simple(n):
     return t
 
 def generate_tracker_random(n):
-    shares = []
+    items = []
     for i in xrange(n):
-        x = random.choice(shares + [FakeShare(hash=None), FakeShare(hash=random.randrange(1000000, 2000000))]).hash
-        shares.append(FakeShare(hash=i, previous_hash=x))
-    t = forest.Tracker(math.shuffled(shares))
+        x = random.choice(items + [FakeShare(hash=None), FakeShare(hash=random.randrange(1000000, 2000000))]).hash
+        items.append(FakeShare(hash=i, previous_hash=x))
+    t = forest.Tracker(math.shuffled(items))
     test_tracker(t)
     return t
 
 class Test(unittest.TestCase):
     def test_tracker(self):
-        t = generate_tracker_simple(1000)
+        t = generate_tracker_simple(100)
         
-        assert t.heads == {999: None}
-        assert t.tails == {None: set([999])}
+        assert t.heads == {99: None}
+        assert t.tails == {None: set([99])}
         
-        assert t.get_nth_parent_hash(900, 500) == 900 - 500
-        assert t.get_nth_parent_hash(901, 42) == 901 - 42
+        assert t.get_nth_parent_hash(90, 50) == 90 - 50
+        assert t.get_nth_parent_hash(91, 42) == 91 - 42
     
     def test_get_nth_parent_hash(self):
         t = generate_tracker_simple(200)
@@ -144,13 +133,62 @@ class Test(unittest.TestCase):
             assert res == a - b, (a, b, res)
     
     def test_tracker2(self):
-        for ii in xrange(50):
+        for ii in xrange(20):
             t = generate_tracker_random(random.randrange(100))
             #print "--start--"
-            while t.shares:
-                x = random.choice(list(t.shares))
-                try:
-                    t.remove(x)
-                except NotImplementedError:
-                    pass # print "aborted", x
+            while t.items:
+                while True:
+                    try:
+                        t.remove(random.choice(list(t.items)))
+                    except NotImplementedError:
+                        pass # print "aborted", x
+                    else:
+                        break
+                test_tracker(t)
+    
+    def test_tracker3(self):
+        for ii in xrange(10):
+            items = []
+            for i in xrange(random.randrange(100)):
+                x = random.choice(items + [FakeShare(hash=None), FakeShare(hash=random.randrange(1000000, 2000000))]).hash
+                items.append(FakeShare(hash=i, previous_hash=x))
+            
+            t = forest.Tracker()
+            test_tracker(t)
+            
+            for item in math.shuffled(items):
+                t.add(item)
+                test_tracker(t)
+                if random.randrange(3) == 0:
+                    while True:
+                        try:
+                            t.remove(random.choice(list(t.items)))
+                        except NotImplementedError:
+                            pass
+                        else:
+                            break
+                    test_tracker(t)
+            
+            for item in math.shuffled(items):
+                if item.hash not in t.items:
+                    t.add(item)
+                    test_tracker(t)
+                if random.randrange(3) == 0:
+                    while True:
+                        try:
+                            t.remove(random.choice(list(t.items)))
+                        except NotImplementedError:
+                            pass
+                        else:
+                            break
+                    test_tracker(t)
+            
+            while t.items:
+                while True:
+                    try:
+                        t.remove(random.choice(list(t.items)))
+                    except NotImplementedError:
+                        pass
+                    else:
+                        break
                 test_tracker(t)