optimized share weight adding, fixed runaway bitcoind request bug, made peer requests...
[p2pool.git] / p2pool / bitcoin / data.py
1 from __future__ import division
2
3 import struct
4 import hashlib
5 import itertools
6 import warnings
7
8 from . import base58
9 from p2pool.util import bases, math, skiplist
10
11 class EarlyEnd(Exception):
12     pass
13
14 class LateEnd(Exception):
15     pass
16
17 def read((data, pos), length):
18     data2 = data[pos:pos + length]
19     if len(data2) != length:
20         raise EarlyEnd()
21     return data2, (data, pos + length)
22
23 class Type(object):
24     # the same data can have only one unpacked representation, but multiple packed binary representations
25     
26     #def __hash__(self):
27     #    return hash(tuple(self.__dict__.items()))
28     
29     #def __eq__(self, other):
30     #    if not isinstance(other, Type):
31     #        raise NotImplementedError()
32     #    return self.__dict__ == other.__dict__
33     
34     def _unpack(self, data):
35         obj, (data2, pos) = self.read((data, 0))
36         
37         assert data2 is data
38         
39         if pos != len(data):
40             raise LateEnd()
41         
42         return obj
43     
44     def _pack(self, obj):
45         f = self.write(None, obj)
46         
47         res = []
48         while f is not None:
49             res.append(f[1])
50             f = f[0]
51         res.reverse()
52         return ''.join(res)
53     
54     
55     def unpack(self, data):
56         obj = self._unpack(data)
57         
58         if __debug__:
59             data2 = self._pack(obj)
60             if data2 != data:
61                 assert self._unpack(data2) == obj
62         
63         return obj
64     
65     def pack(self, obj):
66         data = self._pack(obj)
67         
68         assert self._unpack(data) == obj
69         
70         return data
71     
72     
73     def pack_base58(self, obj):
74         return base58.base58_encode(self.pack(obj))
75     
76     def unpack_base58(self, base58_data):
77         return self.unpack(base58.base58_decode(base58_data))
78     
79     
80     def hash160(self, obj):
81         return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
82     
83     def hash256(self, obj):
84         return HashType().unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
85
86 class VarIntType(Type):
87     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
88     def read(self, file):
89         data, file = read(file, 1)
90         first = ord(data)
91         if first < 0xfd:
92             return first, file
93         elif first == 0xfd:
94             desc, length = '<H', 2
95         elif first == 0xfe:
96             desc, length = '<I', 4
97         elif first == 0xff:
98             desc, length = '<Q', 8
99         else:
100             raise AssertionError()
101         data, file = read(file, length)
102         return struct.unpack(desc, data)[0], file
103     
104     def write(self, file, item):
105         if item < 0xfd:
106             file = file, struct.pack('<B', item)
107         elif item <= 0xffff:
108             file = file, struct.pack('<BH', 0xfd, item)
109         elif item <= 0xffffffff:
110             file = file, struct.pack('<BI', 0xfe, item)
111         elif item <= 0xffffffffffffffff:
112             file = file, struct.pack('<BQ', 0xff, item)
113         else:
114             raise ValueError('int too large for varint')
115         return file
116
117 class VarStrType(Type):
118     _inner_size = VarIntType()
119     
120     def read(self, file):
121         length, file = self._inner_size.read(file)
122         return read(file, length)
123     
124     def write(self, file, item):
125         return self._inner_size.write(file, len(item)), item
126
127 class FixedStrType(Type):
128     def __init__(self, length):
129         self.length = length
130     
131     def read(self, file):
132         return read(file, self.length)
133     
134     def write(self, file, item):
135         if len(item) != self.length:
136             raise ValueError('incorrect length item!')
137         return file, item
138
139 class EnumType(Type):
140     def __init__(self, inner, values):
141         self.inner = inner
142         self.values = values
143         
144         self.keys = {}
145         for k, v in values.iteritems():
146             if v in self.keys:
147                 raise ValueError('duplicate value in values')
148             self.keys[v] = k
149     
150     def read(self, file):
151         data, file = self.inner.read(file)
152         return self.keys[data], file
153     
154     def write(self, file, item):
155         return self.inner.write(file, self.values[item])
156
157 class HashType(Type):
158     def read(self, file):
159         data, file = read(file, 256//8)
160         return int(data[::-1].encode('hex'), 16), file
161     
162     def write(self, file, item):
163         if not 0 <= item < 2**256:
164             raise ValueError('invalid hash value - %r' % (item,))
165         if item != 0 and item < 2**160:
166             warnings.warn('very low hash value - maybe you meant to use ShortHashType? %x' % (item,))
167         return file, ('%064x' % (item,)).decode('hex')[::-1]
168
169 class ShortHashType(Type):
170     def read(self, file):
171         data, file = read(file, 160//8)
172         return int(data[::-1].encode('hex'), 16), file
173     
174     def write(self, file, item):
175         if not 0 <= item < 2**160:
176             raise ValueError('invalid hash value - %r' % (item,))
177         return file, ('%040x' % (item,)).decode('hex')[::-1]
178
179 class ListType(Type):
180     _inner_size = VarIntType()
181     
182     def __init__(self, type):
183         self.type = type
184     
185     def read(self, file):
186         length, file = self._inner_size.read(file)
187         res = []
188         for i in xrange(length):
189             item, file = self.type.read(file)
190             res.append(item)
191         return res, file
192     
193     def write(self, file, item):
194         file = self._inner_size.write(file, len(item))
195         for subitem in item:
196             file = self.type.write(file, subitem)
197         return file
198
199 class StructType(Type):
200     def __init__(self, desc):
201         self.desc = desc
202         self.length = struct.calcsize(self.desc)
203     
204     def read(self, file):
205         data, file = read(file, self.length)
206         res, = struct.unpack(self.desc, data)
207         return res, file
208     
209     def write(self, file, item):
210         data = struct.pack(self.desc, item)
211         if struct.unpack(self.desc, data)[0] != item:
212             # special test because struct doesn't error on some overflows
213             raise ValueError('''item didn't survive pack cycle (%r)''' % (item,))
214         return file, data
215
216 class IPV6AddressType(Type):
217     def read(self, file):
218         data, file = read(file, 16)
219         if data[:12] != '00000000000000000000ffff'.decode('hex'):
220             raise ValueError('ipv6 addresses not supported yet')
221         return '.'.join(str(ord(x)) for x in data[12:]), file
222     
223     def write(self, file, item):
224         bits = map(int, item.split('.'))
225         if len(bits) != 4:
226             raise ValueError('invalid address: %r' % (bits,))
227         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
228         assert len(data) == 16, len(data)
229         return file, data
230
231 class ComposedType(Type):
232     def __init__(self, fields):
233         self.fields = fields
234     
235     def read(self, file):
236         item = {}
237         for key, type_ in self.fields:
238             item[key], file = type_.read(file)
239         return item, file
240     
241     def write(self, file, item):
242         for key, type_ in self.fields:
243             file = type_.write(file, item[key])
244         return file
245
246 class ChecksummedType(Type):
247     def __init__(self, inner):
248         self.inner = inner
249     
250     def read(self, file):
251         obj, file = self.inner.read(file)
252         data = self.inner.pack(obj)
253         
254         checksum, file = read(file, 4)
255         if checksum != hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]:
256             raise ValueError('invalid checksum')
257         
258         return obj, file
259     
260     def write(self, file, item):
261         data = self.inner.pack(item)
262         return (file, data), hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]
263
264 class FloatingIntegerType(Type):
265     # redundancy doesn't matter here because bitcoin checks binary bits against its own computed bits
266     # so it will always be encoded 'normally' in blocks (they way bitcoin does it)
267     _inner = StructType('<I')
268     
269     def read(self, file):
270         bits, file = self._inner.read(file)
271         target = self._bits_to_target(bits)
272         if __debug__:
273             if self._target_to_bits(target) != bits:
274                 raise ValueError('bits in non-canonical form')
275         return target, file
276     
277     def write(self, file, item):
278         return self._inner.write(file, self._target_to_bits(item))
279     
280     def truncate_to(self, x):
281         return self._bits_to_target(self._target_to_bits(x, _check=False))
282     
283     def _bits_to_target(self, bits2):
284         target = math.shift_left(bits2 & 0x00ffffff, 8 * ((bits2 >> 24) - 3))
285         assert target == self._bits_to_target1(struct.pack('<I', bits2))
286         assert self._target_to_bits(target, _check=False) == bits2
287         return target
288     
289     def _bits_to_target1(self, bits):
290         bits = bits[::-1]
291         length = ord(bits[0])
292         return bases.string_to_natural((bits[1:] + '\0'*length)[:length])
293     
294     def _target_to_bits(self, target, _check=True):
295         n = bases.natural_to_string(target)
296         if n and ord(n[0]) >= 128:
297             n = '\x00' + n
298         bits2 = (chr(len(n)) + (n + 3*chr(0))[:3])[::-1]
299         bits = struct.unpack('<I', bits2)[0]
300         if _check:
301             if self._bits_to_target(bits) != target:
302                 raise ValueError(repr((target, self._bits_to_target(bits, _check=False))))
303         return bits
304
305 class PossiblyNone(Type):
306     def __init__(self, none_value, inner):
307         self.none_value = none_value
308         self.inner = inner
309     
310     def read(self, file):
311         value, file = self.inner.read(file)
312         return None if value == self.none_value else value, file
313     
314     def write(self, file, item):
315         if item == self.none_value:
316             raise ValueError('none_value used')
317         return self.inner.write(file, self.none_value if item is None else item)
318
319 address_type = ComposedType([
320     ('services', StructType('<Q')),
321     ('address', IPV6AddressType()),
322     ('port', StructType('>H')),
323 ])
324
325 tx_type = ComposedType([
326     ('version', StructType('<I')),
327     ('tx_ins', ListType(ComposedType([
328         ('previous_output', PossiblyNone(dict(hash=0, index=2**32 - 1), ComposedType([
329             ('hash', HashType()),
330             ('index', StructType('<I')),
331         ]))),
332         ('script', VarStrType()),
333         ('sequence', PossiblyNone(2**32 - 1, StructType('<I'))),
334     ]))),
335     ('tx_outs', ListType(ComposedType([
336         ('value', StructType('<Q')),
337         ('script', VarStrType()),
338     ]))),
339     ('lock_time', StructType('<I')),
340 ])
341
342 block_header_type = ComposedType([
343     ('version', StructType('<I')),
344     ('previous_block', PossiblyNone(0, HashType())),
345     ('merkle_root', HashType()),
346     ('timestamp', StructType('<I')),
347     ('target', FloatingIntegerType()),
348     ('nonce', StructType('<I')),
349 ])
350
351 block_type = ComposedType([
352     ('header', block_header_type),
353     ('txs', ListType(tx_type)),
354 ])
355
356
357 merkle_record_type = ComposedType([
358     ('left', HashType()),
359     ('right', HashType()),
360 ])
361
362 def merkle_hash(tx_list):
363     if not tx_list:
364         return 0
365     hash_list = map(tx_type.hash256, tx_list)
366     while len(hash_list) > 1:
367         hash_list = [merkle_record_type.hash256(dict(left=left, right=left if right is None else right))
368             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
369     return hash_list[0]
370
371 def target_to_average_attempts(target):
372     return 2**256//(target + 1)
373
374 # human addresses
375
376 human_address_type = ChecksummedType(ComposedType([
377     ('version', StructType('<B')),
378     ('pubkey_hash', ShortHashType()),
379 ]))
380
381 pubkey_type = FixedStrType(65)
382
383 def pubkey_hash_to_address(pubkey_hash, net):
384     return human_address_type.pack_base58(dict(version=net.BITCOIN_ADDRESS_VERSION, pubkey_hash=pubkey_hash))
385
386 def pubkey_to_address(pubkey, net):
387     return pubkey_hash_to_address(pubkey_type.hash160(pubkey), net)
388
389 def address_to_pubkey_hash(address, net):
390     x = human_address_type.unpack_base58(address)
391     if x['version'] != net.BITCOIN_ADDRESS_VERSION:
392         raise ValueError('address not for this net!')
393     return x['pubkey_hash']
394
395 # transactions
396
397 def pubkey_to_script2(pubkey):
398     return ('\x41' + pubkey_type.pack(pubkey)) + '\xac'
399
400 def pubkey_hash_to_script2(pubkey_hash):
401     return '\x76\xa9' + ('\x14' + ShortHashType().pack(pubkey_hash)) + '\x88\xac'
402
403 # linked list tracker
404
405 class Tracker(object):
406     def __init__(self):
407         self.shares = {} # hash -> share
408         self.ids = {} # hash -> (id, height)
409         self.reverse_shares = {} # previous_hash -> set of share_hashes
410         
411         self.heads = {} # head hash -> tail_hash
412         self.tails = {} # tail hash -> set of head hashes
413         self.heights = {} # share_hash -> height_to, other_share_hash
414         self.skips = {} # share_hash -> skip list
415         
416         self.id_generator = itertools.count()
417         self.tails_by_id = {}
418         
419         self.get_nth_parent = skiplist.DistanceSkipList(self)
420     
421     def add(self, share):
422         assert not isinstance(share, (int, long, type(None)))
423         if share.hash in self.shares:
424             return # XXX raise exception?
425         
426         '''
427         parent_id = self.ids.get(share.previous_hash, None)
428         children_ids = set(self.ids.get(share2_hash) for share2_hash in self.reverse_shares.get(share.hash, set()))
429         infos = set()
430         if parent_id is not None:
431             infos.add((parent_id[0], parent_id[1] + 1))
432         for child_id in children_ids:
433             infos.add((child_id[0], child_id[1] - 1))
434         if not infos:
435             infos.add((self.id_generator.next(), 0))
436         chosen = min(infos)
437         '''
438         
439         self.shares[share.hash] = share
440         self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
441         
442         if share.hash in self.tails:
443             heads = self.tails.pop(share.hash)
444         else:
445             heads = set([share.hash])
446         
447         if share.previous_hash in self.heads:
448             tail = self.heads.pop(share.previous_hash)
449         else:
450             #dist, tail = self.get_height_and_last(share.previous_hash) # XXX this should be moved out of the critical area even though it shouldn't matter
451             tail = share.previous_hash
452             while tail in self.shares:
453                 tail = self.shares[tail].previous_hash
454         
455         self.tails.setdefault(tail, set()).update(heads)
456         if share.previous_hash in self.tails[tail]:
457             self.tails[tail].remove(share.previous_hash)
458         
459         for head in heads:
460             self.heads[head] = tail
461     
462     def test(self):
463         t = Tracker()
464         for s in self.shares.itervalues():
465             t.add(s)
466         
467         assert self.shares == t.shares, (self.shares, t.shares)
468         assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
469         assert self.heads == t.heads, (self.heads, t.heads)
470         assert self.tails == t.tails, (self.tails, t.tails)
471     
472     def remove(self, share_hash):
473         assert isinstance(share_hash, (int, long, type(None)))
474         if share_hash not in self.shares:
475             raise KeyError()
476         share = self.shares[share_hash]
477         del share_hash
478         
479         if share.hash in self.heads and share.previous_hash in self.tails:
480             tail = self.heads.pop(share.hash)
481             self.tails[tail].remove(share.hash)
482             if not self.tails[share.previous_hash]:
483                 self.tails.pop(share.previous_hash)
484         elif share.hash in self.heads:
485             tail = self.heads.pop(share.hash)
486             self.tails[tail].remove(share.hash)
487             if self.reverse_shares[share.previous_hash] != set([share.hash]):
488                 pass # has sibling
489             else:
490                 self.tails[tail].add(share.previous_hash)
491                 self.heads[share.previous_hash] = tail
492         elif share.previous_hash in self.tails:
493             raise NotImplementedError() # will break other things..
494             heads = self.tails[share.previous_hash]
495             if len(self.reverse_shares[share.previous_hash]) > 1:
496                 raise NotImplementedError()
497             else:
498                 del self.tails[share.previous_hash]
499                 for head in heads:
500                     self.heads[head] = share.hash
501                 self.tails[share.hash] = set(heads)
502         else:
503             raise NotImplementedError()
504         
505         '''
506         height, tail = self.get_height_and_last(share.hash)
507         
508         if share.hash in self.heads:
509             my_heads = set([share.hash])
510         elif share.previous_hash in self.tails:
511             my_heads = self.tails[share.previous_hash]
512         else:
513             some_heads = self.tails[tail]
514             some_heads_heights = dict((that_head, self.get_height_and_last(that_head)[0]) for that_head in some_heads)
515             my_heads = set(that_head for that_head in some_heads
516                 if some_heads_heights[that_head] > height and
517                 self.get_nth_parent_hash(that_head, some_heads_heights[that_head] - height) == share.hash)
518         
519         if share.previous_hash != tail:
520             self.heads[share.previous_hash] = tail
521         
522         for head in my_heads:
523             if head != share.hash:
524                 self.heads[head] = share.hash
525             else:
526                 self.heads.pop(head)
527         
528         if share.hash in self.heads:
529             self.heads.pop(share.hash)
530         
531         
532         self.tails[tail].difference_update(my_heads)
533         if share.previous_hash != tail:
534             self.tails[tail].add(share.previous_hash)
535         if not self.tails[tail]:
536             self.tails.pop(tail)
537         if my_heads != set([share.hash]):
538             self.tails[share.hash] = set(my_heads) - set([share.hash])
539         '''
540         
541         self.shares.pop(share.hash)
542         self.reverse_shares[share.previous_hash].remove(share.hash)
543         if not self.reverse_shares[share.previous_hash]:
544             self.reverse_shares.pop(share.previous_hash)
545         
546         assert self.test() is None
547     
548     def get_height_and_last(self, share_hash):
549         assert isinstance(share_hash, (int, long, type(None)))
550         orig = share_hash
551         height = 0
552         updates = []
553         while True:
554             if share_hash is None or share_hash not in self.shares:
555                 break
556             updates.append((share_hash, height))
557             if share_hash in self.heights:
558                 height_inc, share_hash = self.heights[share_hash]
559             else:
560                 height_inc, share_hash = 1, self.shares[share_hash].previous_hash
561             height += height_inc
562         for update_hash, height_then in updates:
563             self.heights[update_hash] = height - height_then, share_hash
564         #assert (height, share_hash) == self.get_height_and_last2(orig), ((height, share_hash), self.get_height_and_last2(orig))
565         return height, share_hash
566     
567     def get_height_and_last2(self, share_hash):
568         assert isinstance(share_hash, (int, long, type(None)))
569         height = 0
570         while True:
571             if share_hash not in self.shares:
572                 break
573             share_hash = self.shares[share_hash].previous_hash
574             height += 1
575         return height, share_hash
576     
577     def get_chain_known(self, start_hash):
578         assert isinstance(start_hash, (int, long, type(None)))
579         '''
580         Chain starting with item of hash I{start_hash} of items that this Tracker contains
581         '''
582         item_hash_to_get = start_hash
583         while True:
584             if item_hash_to_get not in self.shares:
585                 break
586             share = self.shares[item_hash_to_get]
587             assert not isinstance(share, long)
588             yield share
589             item_hash_to_get = share.previous_hash
590     
591     def get_chain_to_root(self, start_hash, root=None):
592         assert isinstance(start_hash, (int, long, type(None)))
593         assert isinstance(root, (int, long, type(None)))
594         '''
595         Chain of hashes starting with share_hash of shares to the root (doesn't include root)
596         Raises an error if one is missing
597         '''
598         share_hash_to_get = start_hash
599         while share_hash_to_get != root:
600             share = self.shares[share_hash_to_get]
601             yield share
602             share_hash_to_get = share.previous_hash
603     
604     def get_best_hash(self):
605         '''
606         Returns hash of item with the most items in its chain
607         '''
608         if not self.heads:
609             return None
610         return max(self.heads, key=self.get_height_and_last)
611     
612     def get_highest_height(self):
613         return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0
614     
615     def get_nth_parent_hash(self, item_hash, n):
616         if n < 0:
617             raise ValueError('n must be >= 0')
618         
619         updates = {}
620         while n:
621             if item_hash not in self.skips:
622                 self.skips[item_hash] = math.geometric(.5), [(1, self.shares[item_hash].previous_hash)]
623             skip_length, skip = self.skips[item_hash]
624             
625             for i in xrange(skip_length):
626                 if i in updates:
627                     n_then, that_hash = updates.pop(i)
628                     x, y = self.skips[that_hash]
629                     assert len(y) == i
630                     y.append((n_then - n, item_hash))
631             
632             for i in xrange(len(skip), skip_length):
633                 updates[i] = n, item_hash
634             
635             for i, (dist, then_hash) in enumerate(reversed(skip)):
636                 if dist <= n:
637                     break
638             else:
639                 raise AssertionError()
640             
641             n -= dist
642             item_hash = then_hash
643         
644         return item_hash
645     
646     def get_nth_parent2(self, item_hash, n):
647         x = item_hash
648         for i in xrange(n):
649             x = self.shares[item_hash].previous_hash
650         return x
651
652 class FakeShare(object):
653     def __init__(self, **kwargs):
654         self.__dict__.update(kwargs)
655
656 if __name__ == '__main__':
657     
658     t = Tracker()
659     
660     for i in xrange(100):
661         t.add(FakeShare(i, i - 1 if i > 0 else None))
662     
663     t.remove(99)
664     
665     print "HEADS", t.heads
666     print "TAILS", t.tails
667     
668     import random
669     
670     while True:
671         print
672         print '-'*30
673         print
674         t = Tracker()
675         for i in xrange(random.randrange(100)):
676             x = random.choice(list(t.shares) + [None])
677             print i, '->', x
678             t.add(FakeShare(i, x))
679         while t.shares:
680             x = random.choice(list(t.shares))
681             print "DEL", x, t.__dict__
682             try:
683                 t.remove(x)
684             except NotImplementedError:
685                 print "aborted; not implemented"
686         import time
687         time.sleep(.1)
688         print "HEADS", t.heads
689         print "TAILS", t.tails
690     
691     #for share_hash, share in sorted(t.shares.iteritems()):
692     #    print share_hash, share.previous_hash, t.heads.get(share_hash), t.tails.get(share_hash)
693     
694     import sys;sys.exit()
695     
696     print t.get_nth_parent_hash(9000, 5000)
697     print t.get_nth_parent_hash(9001, 412)
698     #print t.get_nth_parent_hash(90, 51)
699     
700     for share_hash in sorted(t.shares):
701         print str(share_hash).rjust(4),
702         x = t.skips.get(share_hash, None)
703         if x is not None:
704             print str(x[0]).rjust(4),
705             for a in x[1]:
706                 print str(a).rjust(10),
707         print
708
709 # network definitions
710
711 class Mainnet(object):
712     BITCOIN_P2P_PREFIX = 'f9beb4d9'.decode('hex')
713     BITCOIN_P2P_PORT = 8333
714     BITCOIN_ADDRESS_VERSION = 0
715
716 class Testnet(object):
717     BITCOIN_P2P_PREFIX = 'fabfb5da'.decode('hex')
718     BITCOIN_P2P_PORT = 18333
719     BITCOIN_ADDRESS_VERSION = 111