improved forest tests
authorForrest Voight <forrest@forre.st>
Sun, 11 Dec 2011 02:12:27 +0000 (21:12 -0500)
committerForrest Voight <forrest@forre.st>
Sun, 11 Dec 2011 02:12:27 +0000 (21:12 -0500)
p2pool/test/util/test_forest.py

index 8b6453c..681e0e0 100644 (file)
@@ -29,11 +29,11 @@ class DumbTracker(object):
     
     @property
     def heads(self):
-        return dict((x.hash, self.get_last(x.hash)) for x in self.shares.itervalues() if x.hash not in self.reverse_shares)
+        return dict((x, self.get_last(x)) for x in self.shares if x not in self.reverse_shares)
     
     @property
     def tails(self):
-        return dict((x, set(y.hash for y in self.shares.itervalues() if self.get_last(y.hash) == x and y.hash not in self.reverse_shares)) for x in self.reverse_shares.iterkeys() if x not in self.shares)
+        return dict((x, set(y for y in self.shares if self.get_last(y) == x and y not in self.reverse_shares)) for x in self.reverse_shares if x not in self.shares)
     
     def get_height(self, share_hash):
         height, work, last = self.get_height_work_and_last(share_hash)
@@ -64,6 +64,16 @@ class DumbTracker(object):
             height += 1
             work += work_inc
         return height, work, share_hash
+    
+    def is_child_of(self, share_hash, possible_child_hash):
+        if self.get_last(share_hash) != self.get_last(possible_child_hash):
+            return None
+        while True:
+            if possible_child_hash == share_hash:
+                return True
+            if possible_child_hash not in self.shares:
+                return False
+            possible_child_hash = self.shares[possible_child_hash].previous_hash
 
 class FakeShare(object):
     target = 2**256 - 1
@@ -86,6 +96,10 @@ def test_tracker(self):
     for start in self.shares:
         a, b = self.get_height_work_and_last(start), t.get_height_work_and_last(start)
         assert a == b, (a, b)
+        
+        other = random.choice(self.shares.keys())
+        assert self.is_child_of(start, other) == t.is_child_of(start, other)
+        assert self.is_child_of(other, start) == t.is_child_of(other, start)
 
 def generate_tracker_simple(n):
     t = forest.Tracker(math.shuffled(FakeShare(hash=i, previous_hash=i - 1 if i > 0 else None) for i in xrange(n)))
@@ -95,7 +109,7 @@ def generate_tracker_simple(n):
 def generate_tracker_random(n):
     shares = []
     for i in xrange(n):
-        x = random.choice(shares + [FakeShare(hash=None)]).hash
+        x = random.choice(shares + [FakeShare(hash=None), FakeShare(hash=random.randrange(1000000, 2000000))]).hash
         shares.append(FakeShare(hash=i, previous_hash=x))
     t = forest.Tracker(math.shuffled(shares))
     test_tracker(t)
@@ -112,10 +126,10 @@ class Test(unittest.TestCase):
         assert t.get_nth_parent_hash(901, 42) == 901 - 42
     
     def test_get_nth_parent_hash(self):
-        t = generate_tracker_simple(2000)
+        t = generate_tracker_simple(200)
         
         for i in xrange(1000):
-            a = random.randrange(2000)
+            a = random.randrange(200)
             b = random.randrange(a + 1)
             res = t.get_nth_parent_hash(a, b)
             assert res == a - b, (a, b, res)