speed up forest
[p2pool.git] / p2pool / test / util / test_forest.py
1 import random
2 import unittest
3
4 from p2pool.util import forest, math
5 from p2pool.bitcoin import data as bitcoin_data
6
class DumbTracker(object):
    """Deliberately naive share tracker used as a reference implementation.

    Every query walks the ``previous_hash`` links one share at a time and
    nothing is cached. The module-level ``test_tracker`` helper builds one of
    these from a tracker's shares and compares the two.

    A share is any object with ``hash`` and ``previous_hash`` attributes;
    ``get_height_work_and_last`` additionally reads ``target``.
    """
    def __init__(self, shares=()):
        # The default is an immutable empty tuple rather than a shared list.
        self.shares = {} # hash -> share
        self.reverse_shares = {} # previous_hash -> set of share_hashes
        
        for share in shares:
            self.add(share)
    
    def add(self, share):
        """Store ``share``; raise ValueError if its hash is already stored."""
        if share.hash in self.shares:
            raise ValueError('share already present')
        self.shares[share.hash] = share
        self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
    
    def remove(self, share_hash):
        """Forget the share stored under ``share_hash`` (KeyError if absent)."""
        share = self.shares.pop(share_hash)
        
        siblings = self.reverse_shares[share.previous_hash]
        siblings.remove(share.hash)
        if not siblings:
            # Drop empty sets so "key in reverse_shares" means "has children".
            del self.reverse_shares[share.previous_hash]
    
    @property
    def heads(self):
        """Map each stored hash that has no stored child to its ``get_last``."""
        return dict((x, self.get_last(x)) for x in self.shares if x not in self.reverse_shares)
    
    @property
    def tails(self):
        """Map each referenced-but-not-stored hash to the set of childless
        stored hashes whose ``get_last`` is that hash."""
        return dict((x, set(y for y in self.shares if self.get_last(y) == x and y not in self.reverse_shares)) for x in self.reverse_shares if x not in self.shares)
    
    def get_height(self, share_hash):
        """Return the height component of ``get_height_work_and_last``."""
        height, work, last = self.get_height_work_and_last(share_hash)
        return height
    
    def get_work(self, share_hash):
        """Return the work component of ``get_height_work_and_last``."""
        height, work, last = self.get_height_work_and_last(share_hash)
        return work
    
    def get_last(self, share_hash):
        """Return the last component of ``get_height_work_and_last``."""
        height, work, last = self.get_height_work_and_last(share_hash)
        return last
    
    def get_height_and_last(self, share_hash):
        """Return ``(height, last)`` from ``get_height_work_and_last``."""
        height, work, last = self.get_height_work_and_last(share_hash)
        return height, last
    
    def get_nth_parent_hash(self, share_hash, n):
        """Follow ``previous_hash`` ``n`` times from ``share_hash``.

        Raises KeyError if the walk needs a share that is not stored.
        """
        for _ in range(n):
            share_hash = self.shares[share_hash].previous_hash
        return share_hash
    
    def get_height_work_and_last(self, share_hash):
        """Walk back from ``share_hash`` for as long as the hash is stored.

        Returns ``(height, work, last)``: the number of stored shares visited,
        the sum of ``bitcoin_data.target_to_average_attempts(share.target)``
        over them, and the first hash reached that is not stored.
        """
        height = 0
        work = 0
        while share_hash in self.shares:
            share = self.shares[share_hash]
            share_hash = share.previous_hash
            height += 1
            work += bitcoin_data.target_to_average_attempts(share.target)
        return height, work, share_hash
    
    def get_chain(self, start_hash, length):
        """Yield ``length`` shares, starting at ``start_hash`` and walking back.

        This is a generator, so the length assertion runs when iteration
        starts, not when the method is called.
        """
        # same implementation :/
        assert length <= self.get_height(start_hash)
        for _ in range(length):
            yield self.shares[start_hash]
            start_hash = self.shares[start_hash].previous_hash
    
    def is_child_of(self, share_hash, possible_child_hash):
        """Report whether walking back from ``possible_child_hash`` reaches
        ``share_hash``.

        Returns None when the two hashes have different ``get_last`` values,
        True when the walk hits ``share_hash`` (including when both hashes are
        equal), and False when the walk leaves the stored shares first.
        """
        if self.get_last(share_hash) != self.get_last(possible_child_hash):
            return None
        while True:
            if possible_child_hash == share_hash:
                return True
            if possible_child_hash not in self.shares:
                return False
            possible_child_hash = self.shares[possible_child_hash].previous_hash
84
class FakeShare(object):
    """Minimal share stand-in whose attributes come from keyword arguments.

    ``FakeShare(hash=1, previous_hash=0)`` yields an object with exactly those
    attributes, plus the class-level ``target`` unless ``target`` is passed.
    """
    # Class-level default (the largest 256-bit value); a ``target`` keyword
    # overrides it per instance.
    target = 2**256 - 1
    def __init__(self, **kwargs):
        # items() instead of the Python-2-only iteritems(): it gives the same
        # result here and works on either interpreter.
        for k, v in kwargs.items():
            setattr(self, k, v)
        # Kept so __repr__ shows only what the caller passed.
        self._attrs = kwargs
    
    def __repr__(self):
        return 'FakeShare(' + ', '.join('%s=%r' % (k, v) for k, v in self._attrs.items()) + ')'
94
def test_tracker(self):
    """Cross-check a tracker against a DumbTracker built from its shares.

    This is a module-level helper, not a TestCase method: ``self`` is the
    tracker under test. The bookkeeping dictionaries are always compared.
    The per-share queries are compared only about one call in ten, because
    of the early return on ``random.random() < 0.9``.

    Raises AssertionError on the first mismatch.
    """
    t = DumbTracker(self.shares.itervalues())
    
    assert self.shares == t.shares, (self.shares, t.shares)
    assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
    assert self.heads == t.heads, (self.heads, t.heads)
    assert self.tails == t.tails, (self.tails, t.tails)
    
    if random.random() < 0.9:
        return
    
    # Build the candidate list once; the loop below never changes the shares,
    # so rebuilding it per iteration only added quadratic work.
    share_hashes = list(self.shares)
    
    for start in self.shares:
        a, b = self.get_height_work_and_last(start), t.get_height_work_and_last(start)
        assert a == b, (a, b)
        
        other = random.choice(share_hashes)
        assert self.is_child_of(start, other) == t.is_child_of(start, other)
        assert self.is_child_of(other, start) == t.is_child_of(other, start)
        
        # a[0] is the height of start, so length is drawn from 0 .. height - 1.
        length = random.randrange(a[0])
        assert list(self.get_chain(start, length)) == list(t.get_chain(start, length))
116
def generate_tracker_simple(n):
    """Build a checked forest.Tracker holding one linear chain of ``n`` shares.

    Share ``i`` has hash ``i`` and previous_hash ``i - 1``; share 0 has
    previous_hash None. The shares go through ``math.shuffled`` before being
    handed to the tracker, and the tracker is verified with ``test_tracker``
    before it is returned.
    """
    def linear_chain():
        for index in xrange(n):
            if index > 0:
                parent = index - 1
            else:
                parent = None
            yield FakeShare(hash=index, previous_hash=parent)
    
    tracker = forest.Tracker(math.shuffled(linear_chain()))
    test_tracker(tracker)
    return tracker
121
def generate_tracker_random(n):
    """Build a checked forest.Tracker holding ``n`` randomly linked shares.

    Share ``i`` gets hash ``i``. Its previous_hash is the hash of a random
    pick among the shares created so far, a share with hash None, and a share
    with a random hash in [1000000, 2000000). The shares go through
    ``math.shuffled`` before being handed to the tracker, and the tracker is
    verified with ``test_tracker`` before it is returned.
    """
    created = []
    for index in xrange(n):
        # The two extra candidates allow a parent that is not in the list.
        candidates = created + [FakeShare(hash=None), FakeShare(hash=random.randrange(1000000, 2000000))]
        parent_hash = random.choice(candidates).hash
        created.append(FakeShare(hash=index, previous_hash=parent_hash))
    tracker = forest.Tracker(math.shuffled(created))
    test_tracker(tracker)
    return tracker
130
class Test(unittest.TestCase):
    """Randomised tests for forest.Tracker.

    Inside the methods, the bare name ``test_tracker`` is the module-level
    helper that cross-checks a tracker against DumbTracker, not the method of
    the same name on this class.
    """
    
    @staticmethod
    def _remove_random_share(t):
        """Remove one randomly chosen share from ``t``.

        ``t.remove`` may raise NotImplementedError; when it does, another
        random share is tried until one removal succeeds. ``t.shares`` must
        not be empty.
        """
        while True:
            try:
                t.remove(random.choice(list(t.shares)))
            except NotImplementedError:
                continue
            return
    
    def test_tracker(self):
        """A 100-share linear chain has one head, one tail and correct parents."""
        t = generate_tracker_simple(100)
        
        # assertEqual is used instead of bare assert so the checks survive -O.
        self.assertEqual(t.heads, {99: None})
        self.assertEqual(t.tails, {None: set([99])})
        
        self.assertEqual(t.get_nth_parent_hash(90, 50), 90 - 50)
        self.assertEqual(t.get_nth_parent_hash(91, 42), 91 - 42)
    
    def test_get_nth_parent_hash(self):
        """On a 200-share linear chain, the b-th parent of a is a - b."""
        t = generate_tracker_simple(200)
        
        for i in xrange(1000):
            a = random.randrange(200)
            b = random.randrange(a + 1)
            res = t.get_nth_parent_hash(a, b)
            self.assertEqual(res, a - b, (a, b, res))
    
    def test_tracker2(self):
        """Empty random trackers one share at a time, checking after each."""
        for ii in xrange(20):
            t = generate_tracker_random(random.randrange(100))
            while t.shares:
                self._remove_random_share(t)
                test_tracker(t)
    
    def test_tracker3(self):
        """Interleave adds and removals on an initially empty tracker."""
        for ii in xrange(10):
            shares = []
            for i in xrange(random.randrange(100)):
                x = random.choice(shares + [FakeShare(hash=None), FakeShare(hash=random.randrange(1000000, 2000000))]).hash
                shares.append(FakeShare(hash=i, previous_hash=x))
            
            t = forest.Tracker()
            test_tracker(t)
            
            # First pass: add every share, removing a random one about a
            # third of the time.
            for share in math.shuffled(shares):
                t.add(share)
                test_tracker(t)
                if random.randrange(3) == 0:
                    self._remove_random_share(t)
                    test_tracker(t)
            
            # Second pass: re-add whatever was removed, still removing at
            # random.
            for share in math.shuffled(shares):
                if share.hash not in t.shares:
                    t.add(share)
                    test_tracker(t)
                if random.randrange(3) == 0:
                    self._remove_random_share(t)
                    test_tracker(t)
            
            # Finally, remove everything that is left.
            while t.shares:
                self._remove_random_share(t)
                test_tracker(t)