moved work attribute out of forest.Tracker
[p2pool.git] / p2pool / test / util / test_forest.py
1 import random
2 import unittest
3
4 from p2pool.util import forest, math
5
class DumbTracker(object):
    '''Naive reference implementation of the forest.Tracker query interface.

    Nothing is cached: every query walks previous_hash links from scratch,
    which keeps the logic simple enough to trust as a test oracle.
    '''
    
    def __init__(self, shares=()):
        # immutable default instead of a shared mutable list
        self.shares = {} # hash -> share
        self.reverse_shares = {} # previous_hash -> set of share_hashes
        
        for share in shares:
            self.add(share)
    
    def add(self, share):
        '''Track `share`; raises ValueError if its hash is already tracked.'''
        if share.hash in self.shares:
            raise ValueError('share already present')
        self.shares[share.hash] = share
        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
    
    def remove(self, share_hash):
        '''Stop tracking the share with hash `share_hash` (KeyError if absent).'''
        share = self.shares.pop(share_hash)
        siblings = self.reverse_shares[share.previous_hash]
        siblings.remove(share.hash)
        if not siblings:
            # never leave an empty set behind in reverse_shares
            del self.reverse_shares[share.previous_hash]
    
    @property
    def heads(self):
        '''Dict mapping each head hash to the last hash of its chain.

        A head is a tracked share that no tracked share names as its
        previous_hash.
        '''
        return dict((x, self.get_last(x)) for x in self.shares if x not in self.reverse_shares)
    
    @property
    def tails(self):
        '''Dict mapping each tail hash to the set of heads whose chains end there.

        A tail is a hash that some tracked share names as its previous_hash
        but that is not itself tracked.
        '''
        return dict(
            (x, set(y for y in self.shares if self.get_last(y) == x and y not in self.reverse_shares))
            for x in self.reverse_shares if x not in self.shares
        )
    
    def get_nth_parent_hash(self, share_hash, n):
        '''Follow previous_hash links n times from share_hash and return the hash reached.

        Raises KeyError if the walk hits an untracked hash before n steps.
        '''
        for i in xrange(n):
            share_hash = self.shares[share_hash].previous_hash
        return share_hash
    
    def get_height(self, share_hash):
        '''Number of tracked shares on the chain starting at share_hash.'''
        height, last = self.get_height_and_last(share_hash)
        return height
    
    def get_last(self, share_hash):
        '''First untracked hash reached by walking back from share_hash.'''
        height, last = self.get_height_and_last(share_hash)
        return last
    
    def get_height_and_last(self, share_hash):
        '''Return (height, last) for the chain starting at share_hash.

        height counts the tracked shares visited while following previous_hash
        links; last is the first hash that is not tracked. An untracked
        share_hash yields (0, share_hash).
        '''
        height = 0
        while share_hash in self.shares:
            share_hash = self.shares[share_hash].previous_hash
            height += 1
        return height, share_hash
    
    def get_chain(self, start_hash, length):
        '''Yield `length` consecutive shares, starting at start_hash and walking back.

        This is a generator, so the length assertion only fires once iteration
        begins.
        '''
        # NOTE(review): original comment read "same implementation :/" --
        # presumably this mirrors forest.Tracker.get_chain, so it is not an
        # independent check of that method.
        assert length <= self.get_height(start_hash)
        for i in xrange(length):
            yield self.shares[start_hash]
            start_hash = self.shares[start_hash].previous_hash
    
    def is_child_of(self, share_hash, possible_child_hash):
        '''Tell whether possible_child_hash is share_hash or one of its descendants.

        Returns None if the two hashes end in different chains (different
        get_last values), otherwise True or False.
        '''
        if self.get_last(share_hash) != self.get_last(possible_child_hash):
            return None
        while True:
            if possible_child_hash == share_hash:
                return True
            if possible_child_hash not in self.shares:
                return False
            possible_child_hash = self.shares[possible_child_hash].previous_hash
73
class FakeShare(object):
    '''Minimal stand-in for a share: every keyword argument becomes an attribute.'''
    
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        # keep the original keyword arguments so __repr__ can echo them back
        self._attrs = kwargs
    
    def __repr__(self):
        fields = ['%s=%r' % pair for pair in self._attrs.items()]
        return 'FakeShare(%s)' % (', '.join(fields),)
82
def test_tracker(self):
    '''Cross-check the tracker `self` against a DumbTracker built from its shares.

    Despite the parameter name, `self` is the tracker under test (e.g. a
    forest.Tracker), not a TestCase. shares, reverse_shares, heads and tails
    are always compared. The per-share queries below are only exercised on
    roughly 10% of calls, presumably to keep the randomized tests fast.
    '''
    t = DumbTracker(self.shares.values())
    
    assert self.shares == t.shares, (self.shares, t.shares)
    assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
    assert self.heads == t.heads, (self.heads, t.heads)
    assert self.tails == t.tails, (self.tails, t.tails)
    
    if random.random() < 0.9:
        return
    
    # hoisted out of the loop: the queries below are not expected to add or
    # remove shares
    share_hashes = list(self.shares)
    for start in share_hashes:
        a, b = self.get_height_and_last(start), t.get_height_and_last(start)
        assert a == b, (a, b)
        
        other = random.choice(share_hashes)
        assert self.is_child_of(start, other) == t.is_child_of(start, other)
        assert self.is_child_of(other, start) == t.is_child_of(other, start)
        
        # a[0] >= 1 because start is tracked, so randrange cannot fail
        # (it picks 0 .. a[0] - 1)
        length = random.randrange(a[0])
        assert list(self.get_chain(start, length)) == list(t.get_chain(start, length))

# A module-level helper, not a test: keep nose/pytest from collecting it
# (its name matches test_*) and calling it without the tracker argument.
test_tracker.__test__ = False
104
def generate_tracker_simple(n):
    '''Build a forest.Tracker holding one linear chain of n shares.

    Share i has hash i and previous_hash i - 1; share 0 has previous_hash
    None. The shares are passed through math.shuffled before being handed to
    forest.Tracker, and the result is cross-checked with test_tracker before
    being returned.
    '''
    chain = (
        FakeShare(hash=i, previous_hash=(None if i <= 0 else i - 1))
        for i in xrange(n)
    )
    tracker = forest.Tracker(math.shuffled(chain))
    test_tracker(tracker)
    return tracker
109
def generate_tracker_random(n):
    '''Build a forest.Tracker from n shares with randomly chosen parents.

    Share i (hash i) builds on an earlier share, on None, or on a random hash
    drawn from [1000000, 2000000), so the result can contain several chains
    and tails. The shares are passed through math.shuffled before being handed
    to forest.Tracker, and the tracker is cross-checked with test_tracker.
    '''
    shares = []
    for i in xrange(n):
        # candidate parents: every earlier share, a None parent, or a random hash
        candidates = list(shares)
        candidates.append(FakeShare(hash=None))
        candidates.append(FakeShare(hash=random.randrange(1000000, 2000000)))
        parent_hash = random.choice(candidates).hash
        shares.append(FakeShare(hash=i, previous_hash=parent_hash))
    tracker = forest.Tracker(math.shuffled(shares))
    test_tracker(tracker)
    return tracker
118
119 class Test(unittest.TestCase):
120     def test_tracker(self):
121         t = generate_tracker_simple(100)
122         
123         assert t.heads == {99: None}
124         assert t.tails == {None: set([99])}
125         
126         assert t.get_nth_parent_hash(90, 50) == 90 - 50
127         assert t.get_nth_parent_hash(91, 42) == 91 - 42
128     
129     def test_get_nth_parent_hash(self):
130         t = generate_tracker_simple(200)
131         
132         for i in xrange(1000):
133             a = random.randrange(200)
134             b = random.randrange(a + 1)
135             res = t.get_nth_parent_hash(a, b)
136             assert res == a - b, (a, b, res)
137     
138     def test_tracker2(self):
139         for ii in xrange(20):
140             t = generate_tracker_random(random.randrange(100))
141             #print "--start--"
142             while t.shares:
143                 while True:
144                     try:
145                         t.remove(random.choice(list(t.shares)))
146                     except NotImplementedError:
147                         pass # print "aborted", x
148                     else:
149                         break
150                 test_tracker(t)
151     
152     def test_tracker3(self):
153         for ii in xrange(10):
154             shares = []
155             for i in xrange(random.randrange(100)):
156                 x = random.choice(shares + [FakeShare(hash=None), FakeShare(hash=random.randrange(1000000, 2000000))]).hash
157                 shares.append(FakeShare(hash=i, previous_hash=x))
158             
159             t = forest.Tracker()
160             test_tracker(t)
161             
162             for share in math.shuffled(shares):
163                 t.add(share)
164                 test_tracker(t)
165                 if random.randrange(3) == 0:
166                     while True:
167                         try:
168                             t.remove(random.choice(list(t.shares)))
169                         except NotImplementedError:
170                             pass
171                         else:
172                             break
173                     test_tracker(t)
174             
175             for share in math.shuffled(shares):
176                 if share.hash not in t.shares:
177                     t.add(share)
178                     test_tracker(t)
179                 if random.randrange(3) == 0:
180                     while True:
181                         try:
182                             t.remove(random.choice(list(t.shares)))
183                         except NotImplementedError:
184                             pass
185                         else:
186                             break
187                     test_tracker(t)
188             
189             while t.shares:
190                 while True:
191                     try:
192                         t.remove(random.choice(list(t.shares)))
193                     except NotImplementedError:
194                         pass
195                     else:
196                         break
197                 test_tracker(t)