optimization
[p2pool.git] / p2pool / bitcoin / data.py
1 from __future__ import division
2
3 import struct
4 import hashlib
5 import itertools
6 import warnings
7
8 from . import base58
9 from p2pool.util import bases, math
10
11 class EarlyEnd(Exception):
12     pass
13
14 class LateEnd(Exception):
15     pass
16
17 def read((data, pos), length):
18     data2 = data[pos:pos + length]
19     if len(data2) != length:
20         raise EarlyEnd()
21     return data2, (data, pos + length)
22
23 class Type(object):
24     # the same data can have only one unpacked representation, but multiple packed binary representations
25     
26     #def __hash__(self):
27     #    return hash(tuple(self.__dict__.items()))
28     
29     #def __eq__(self, other):
30     #    if not isinstance(other, Type):
31     #        raise NotImplementedError()
32     #    return self.__dict__ == other.__dict__
33     
34     def _unpack(self, data):
35         obj, (data2, pos) = self.read((data, 0))
36         
37         assert data2 is data
38         
39         if pos != len(data):
40             raise LateEnd()
41         
42         return obj
43     
44     def _pack(self, obj):
45         f = self.write(None, obj)
46         
47         res = []
48         while f is not None:
49             res.append(f[1])
50             f = f[0]
51         res.reverse()
52         return ''.join(res)
53     
54     
55     def unpack(self, data):
56         obj = self._unpack(data)
57         
58         if __debug__:
59             data2 = self._pack(obj)
60             if data2 != data:
61                 assert self._unpack(data2) == obj
62         
63         return obj
64     
65     def pack(self, obj):
66         data = self._pack(obj)
67         
68         assert self._unpack(data) == obj
69         
70         return data
71     
72     
73     def pack_base58(self, obj):
74         return base58.base58_encode(self.pack(obj))
75     
76     def unpack_base58(self, base58_data):
77         return self.unpack(base58.base58_decode(base58_data))
78     
79     
80     def hash160(self, obj):
81         return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
82     
83     def hash256(self, obj):
84         return HashType().unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
85
86 class VarIntType(Type):
87     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
88     def read(self, file):
89         data, file = read(file, 1)
90         first = ord(data)
91         if first < 0xfd:
92             return first, file
93         elif first == 0xfd:
94             desc, length = '<H', 2
95         elif first == 0xfe:
96             desc, length = '<I', 4
97         elif first == 0xff:
98             desc, length = '<Q', 8
99         else:
100             raise AssertionError()
101         data, file = read(file, length)
102         return struct.unpack(desc, data)[0], file
103     
104     def write(self, file, item):
105         if item < 0xfd:
106             file = file, struct.pack('<B', item)
107         elif item <= 0xffff:
108             file = file, struct.pack('<BH', 0xfd, item)
109         elif item <= 0xffffffff:
110             file = file, struct.pack('<BI', 0xfe, item)
111         elif item <= 0xffffffffffffffff:
112             file = file, struct.pack('<BQ', 0xff, item)
113         else:
114             raise ValueError('int too large for varint')
115         return file
116
117 class VarStrType(Type):
118     _inner_size = VarIntType()
119     
120     def read(self, file):
121         length, file = self._inner_size.read(file)
122         return read(file, length)
123     
124     def write(self, file, item):
125         return self._inner_size.write(file, len(item)), item
126
127 class FixedStrType(Type):
128     def __init__(self, length):
129         self.length = length
130     
131     def read(self, file):
132         return read(file, self.length)
133     
134     def write(self, file, item):
135         if len(item) != self.length:
136             raise ValueError('incorrect length item!')
137         return file, item
138
139 class EnumType(Type):
140     def __init__(self, inner, values):
141         self.inner = inner
142         self.values = values
143         
144         self.keys = {}
145         for k, v in values.iteritems():
146             if v in self.keys:
147                 raise ValueError('duplicate value in values')
148             self.keys[v] = k
149     
150     def read(self, file):
151         data, file = self.inner.read(file)
152         return self.keys[data], file
153     
154     def write(self, file, item):
155         return self.inner.write(file, self.values[item])
156
157 class HashType(Type):
158     def read(self, file):
159         data, file = read(file, 256//8)
160         return int(data[::-1].encode('hex'), 16), file
161     
162     def write(self, file, item):
163         if not 0 <= item < 2**256:
164             raise ValueError('invalid hash value - %r' % (item,))
165         if item != 0 and item < 2**160:
166             warnings.warn('very low hash value - maybe you meant to use ShortHashType? %x' % (item,))
167         return file, ('%064x' % (item,)).decode('hex')[::-1]
168
169 class ShortHashType(Type):
170     def read(self, file):
171         data, file = read(file, 160//8)
172         return int(data[::-1].encode('hex'), 16), file
173     
174     def write(self, file, item):
175         if not 0 <= item < 2**160:
176             raise ValueError('invalid hash value - %r' % (item,))
177         return file, ('%040x' % (item,)).decode('hex')[::-1]
178
179 class ListType(Type):
180     _inner_size = VarIntType()
181     
182     def __init__(self, type):
183         self.type = type
184     
185     def read(self, file):
186         length, file = self._inner_size.read(file)
187         res = []
188         for i in xrange(length):
189             item, file = self.type.read(file)
190             res.append(item)
191         return res, file
192     
193     def write(self, file, item):
194         file = self._inner_size.write(file, len(item))
195         for subitem in item:
196             file = self.type.write(file, subitem)
197         return file
198
199 class StructType(Type):
200     def __init__(self, desc):
201         self.desc = desc
202         self.length = struct.calcsize(self.desc)
203     
204     def read(self, file):
205         data, file = read(file, self.length)
206         res, = struct.unpack(self.desc, data)
207         return res, file
208     
209     def write(self, file, item):
210         data = struct.pack(self.desc, item)
211         if struct.unpack(self.desc, data)[0] != item:
212             # special test because struct doesn't error on some overflows
213             raise ValueError('''item didn't survive pack cycle (%r)''' % (item,))
214         return file, data
215
216 class IPV6AddressType(Type):
217     def read(self, file):
218         data, file = read(file, 16)
219         if data[:12] != '00000000000000000000ffff'.decode('hex'):
220             raise ValueError('ipv6 addresses not supported yet')
221         return '.'.join(str(ord(x)) for x in data[12:]), file
222     
223     def write(self, file, item):
224         bits = map(int, item.split('.'))
225         if len(bits) != 4:
226             raise ValueError('invalid address: %r' % (bits,))
227         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
228         assert len(data) == 16, len(data)
229         return file, data
230
231 class ComposedType(Type):
232     def __init__(self, fields):
233         self.fields = fields
234     
235     def read(self, file):
236         item = {}
237         for key, type_ in self.fields:
238             item[key], file = type_.read(file)
239         return item, file
240     
241     def write(self, file, item):
242         for key, type_ in self.fields:
243             file = type_.write(file, item[key])
244         return file
245
246 class ChecksummedType(Type):
247     def __init__(self, inner):
248         self.inner = inner
249     
250     def read(self, file):
251         obj, file = self.inner.read(file)
252         data = self.inner.pack(obj)
253         
254         checksum, file = read(file, 4)
255         if checksum != hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]:
256             raise ValueError('invalid checksum')
257         
258         return obj, file
259     
260     def write(self, file, item):
261         data = self.inner.pack(item)
262         return (file, data), hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]
263
264 class FloatingIntegerType(Type):
265     # redundancy doesn't matter here because bitcoin checks binary bits against its own computed bits
266     # so it will always be encoded 'normally' in blocks (they way bitcoin does it)
267     _inner = StructType('<I')
268     
269     def read(self, file):
270         bits, file = self._inner.read(file)
271         target = self._bits_to_target(bits)
272         if __debug__:
273             if self._target_to_bits(target) != bits:
274                 raise ValueError('bits in non-canonical form')
275         return target, file
276     
277     def write(self, file, item):
278         return self._inner.write(file, self._target_to_bits(item))
279     
280     def truncate_to(self, x):
281         return self._bits_to_target(self._target_to_bits(x, _check=False))
282     
283     def _bits_to_target(self, bits2):
284         target = math.shift_left(bits2 & 0x00ffffff, 8 * ((bits2 >> 24) - 3))
285         assert target == self._bits_to_target1(struct.pack('<I', bits2))
286         assert self._target_to_bits(target, _check=False) == bits2
287         return target
288     
289     def _bits_to_target1(self, bits):
290         bits = bits[::-1]
291         length = ord(bits[0])
292         return bases.string_to_natural((bits[1:] + '\0'*length)[:length])
293     
294     def _target_to_bits(self, target, _check=True):
295         n = bases.natural_to_string(target)
296         if n and ord(n[0]) >= 128:
297             n = '\x00' + n
298         bits2 = (chr(len(n)) + (n + 3*chr(0))[:3])[::-1]
299         bits = struct.unpack('<I', bits2)[0]
300         if _check:
301             if self._bits_to_target(bits) != target:
302                 raise ValueError(repr((target, self._bits_to_target(bits, _check=False))))
303         return bits
304
305 class PossiblyNone(Type):
306     def __init__(self, none_value, inner):
307         self.none_value = none_value
308         self.inner = inner
309     
310     def read(self, file):
311         value, file = self.inner.read(file)
312         return None if value == self.none_value else value, file
313     
314     def write(self, file, item):
315         if item == self.none_value:
316             raise ValueError('none_value used')
317         return self.inner.write(file, self.none_value if item is None else item)
318
319 address_type = ComposedType([
320     ('services', StructType('<Q')),
321     ('address', IPV6AddressType()),
322     ('port', StructType('>H')),
323 ])
324
325 tx_type = ComposedType([
326     ('version', StructType('<I')),
327     ('tx_ins', ListType(ComposedType([
328         ('previous_output', PossiblyNone(dict(hash=0, index=2**32 - 1), ComposedType([
329             ('hash', HashType()),
330             ('index', StructType('<I')),
331         ]))),
332         ('script', VarStrType()),
333         ('sequence', PossiblyNone(2**32 - 1, StructType('<I'))),
334     ]))),
335     ('tx_outs', ListType(ComposedType([
336         ('value', StructType('<Q')),
337         ('script', VarStrType()),
338     ]))),
339     ('lock_time', StructType('<I')),
340 ])
341
342 block_header_type = ComposedType([
343     ('version', StructType('<I')),
344     ('previous_block', PossiblyNone(0, HashType())),
345     ('merkle_root', HashType()),
346     ('timestamp', StructType('<I')),
347     ('target', FloatingIntegerType()),
348     ('nonce', StructType('<I')),
349 ])
350
351 block_type = ComposedType([
352     ('header', block_header_type),
353     ('txs', ListType(tx_type)),
354 ])
355
356
357 merkle_record_type = ComposedType([
358     ('left', HashType()),
359     ('right', HashType()),
360 ])
361
362 def merkle_hash(tx_list):
363     if not tx_list:
364         return 0
365     hash_list = map(tx_type.hash256, tx_list)
366     while len(hash_list) > 1:
367         hash_list = [merkle_record_type.hash256(dict(left=left, right=left if right is None else right))
368             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
369     return hash_list[0]
370
371 def target_to_average_attempts(target):
372     return 2**256//(target + 1)
373
374 # human addresses
375
376 human_address_type = ChecksummedType(ComposedType([
377     ('version', StructType('<B')),
378     ('pubkey_hash', ShortHashType()),
379 ]))
380
381 pubkey_type = FixedStrType(65)
382
383 def pubkey_hash_to_address(pubkey_hash, net):
384     return human_address_type.pack_base58(dict(version=net.BITCOIN_ADDRESS_VERSION, pubkey_hash=pubkey_hash))
385
386 def pubkey_to_address(pubkey, net):
387     return pubkey_hash_to_address(pubkey_type.hash160(pubkey), net)
388
389 def address_to_pubkey_hash(address, net):
390     x = human_address_type.unpack_base58(address)
391     if x['version'] != net.BITCOIN_ADDRESS_VERSION:
392         raise ValueError('address not for this net!')
393     return x['pubkey_hash']
394
395 # transactions
396
397 def pubkey_to_script2(pubkey):
398     return ('\x41' + pubkey_type.pack(pubkey)) + '\xac'
399
400 def pubkey_hash_to_script2(pubkey_hash):
401     return '\x76\xa9' + ('\x14' + ShortHashType().pack(pubkey_hash)) + '\x88\xac'
402
403 # linked list tracker
404
405 class Tracker(object):
406     def __init__(self):
407         self.shares = {} # hash -> share
408         self.ids = {} # hash -> (id, height)
409         self.reverse_shares = {} # previous_hash -> set of share_hashes
410         
411         self.heads = {} # head hash -> tail_hash
412         self.tails = {} # tail hash -> set of head hashes
413         self.heights = {} # share_hash -> height_to, other_share_hash
414         self.skips = {} # share_hash -> skip list
415         
416         self.id_generator = itertools.count()
417         self.tails_by_id = {}
418     
419     def add(self, share):
420         assert not isinstance(share, (int, long, type(None)))
421         if share.hash in self.shares:
422             return # XXX raise exception?
423         
424         '''
425         parent_id = self.ids.get(share.previous_hash, None)
426         children_ids = set(self.ids.get(share2_hash) for share2_hash in self.reverse_shares.get(share.hash, set()))
427         infos = set()
428         if parent_id is not None:
429             infos.add((parent_id[0], parent_id[1] + 1))
430         for child_id in children_ids:
431             infos.add((child_id[0], child_id[1] - 1))
432         if not infos:
433             infos.add((self.id_generator.next(), 0))
434         chosen = min(infos)
435         '''
436         
437         self.shares[share.hash] = share
438         self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
439         
440         if share.hash in self.tails:
441             heads = self.tails.pop(share.hash)
442         else:
443             heads = set([share.hash])
444         
445         if share.previous_hash in self.heads:
446             tail = self.heads.pop(share.previous_hash)
447         else:
448             #dist, tail = self.get_height_and_last(share.previous_hash) # XXX this should be moved out of the critical area even though it shouldn't matter
449             tail = share.previous_hash
450             while tail in self.shares:
451                 tail = self.shares[tail].previous_hash
452         
453         self.tails.setdefault(tail, set()).update(heads)
454         if share.previous_hash in self.tails[tail]:
455             self.tails[tail].remove(share.previous_hash)
456         
457         for head in heads:
458             self.heads[head] = tail
459     
460     def test(self):
461         t = Tracker()
462         for s in self.shares.itervalues():
463             t.add(s)
464         
465         assert self.shares == t.shares, (self.shares, t.shares)
466         assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
467         assert self.heads == t.heads, (self.heads, t.heads)
468         assert self.tails == t.tails, (self.tails, t.tails)
469     
470     def remove(self, share_hash):
471         assert isinstance(share_hash, (int, long, type(None)))
472         if share_hash not in self.shares:
473             raise KeyError()
474         share = self.shares[share_hash]
475         del share_hash
476         
477         if share.hash in self.heads and share.previous_hash in self.tails:
478             tail = self.heads.pop(share.hash)
479             self.tails[tail].remove(share.hash)
480             if not self.tails[share.previous_hash]:
481                 self.tails.pop(share.previous_hash)
482         elif share.hash in self.heads:
483             tail = self.heads.pop(share.hash)
484             self.tails[tail].remove(share.hash)
485             if self.reverse_shares[share.previous_hash] != set([share.hash]):
486                 pass # has sibling
487             else:
488                 self.tails[tail].add(share.previous_hash)
489                 self.heads[share.previous_hash] = tail
490         elif share.previous_hash in self.tails:
491             raise NotImplementedError() # will break other things..
492             heads = self.tails[share.previous_hash]
493             if len(self.reverse_shares[share.previous_hash]) > 1:
494                 raise NotImplementedError()
495             else:
496                 del self.tails[share.previous_hash]
497                 for head in heads:
498                     self.heads[head] = share.hash
499                 self.tails[share.hash] = set(heads)
500         else:
501             raise NotImplementedError()
502         
503         '''
504         height, tail = self.get_height_and_last(share.hash)
505         
506         if share.hash in self.heads:
507             my_heads = set([share.hash])
508         elif share.previous_hash in self.tails:
509             my_heads = self.tails[share.previous_hash]
510         else:
511             some_heads = self.tails[tail]
512             some_heads_heights = dict((that_head, self.get_height_and_last(that_head)[0]) for that_head in some_heads)
513             my_heads = set(that_head for that_head in some_heads
514                 if some_heads_heights[that_head] > height and
515                 self.get_nth_parent_hash(that_head, some_heads_heights[that_head] - height) == share.hash)
516         
517         if share.previous_hash != tail:
518             self.heads[share.previous_hash] = tail
519         
520         for head in my_heads:
521             if head != share.hash:
522                 self.heads[head] = share.hash
523             else:
524                 self.heads.pop(head)
525         
526         if share.hash in self.heads:
527             self.heads.pop(share.hash)
528         
529         
530         self.tails[tail].difference_update(my_heads)
531         if share.previous_hash != tail:
532             self.tails[tail].add(share.previous_hash)
533         if not self.tails[tail]:
534             self.tails.pop(tail)
535         if my_heads != set([share.hash]):
536             self.tails[share.hash] = set(my_heads) - set([share.hash])
537         '''
538         
539         self.shares.pop(share.hash)
540         self.reverse_shares[share.previous_hash].remove(share.hash)
541         if not self.reverse_shares[share.previous_hash]:
542             self.reverse_shares.pop(share.previous_hash)
543         
544         assert self.test() is None
545     
546     def get_height_and_last(self, share_hash):
547         assert isinstance(share_hash, (int, long, type(None)))
548         orig = share_hash
549         height = 0
550         updates = []
551         while True:
552             if share_hash is None or share_hash not in self.shares:
553                 break
554             updates.append((share_hash, height))
555             if share_hash in self.heights:
556                 height_inc, share_hash = self.heights[share_hash]
557             else:
558                 height_inc, share_hash = 1, self.shares[share_hash].previous_hash
559             height += height_inc
560         for update_hash, height_then in updates:
561             self.heights[update_hash] = height - height_then, share_hash
562         #assert (height, share_hash) == self.get_height_and_last2(orig), ((height, share_hash), self.get_height_and_last2(orig))
563         return height, share_hash
564     
565     def get_height_and_last2(self, share_hash):
566         assert isinstance(share_hash, (int, long, type(None)))
567         height = 0
568         while True:
569             if share_hash not in self.shares:
570                 break
571             share_hash = self.shares[share_hash].previous_hash
572             height += 1
573         return height, share_hash
574     
575     def get_chain_known(self, start_hash):
576         assert isinstance(start_hash, (int, long, type(None)))
577         '''
578         Chain starting with item of hash I{start_hash} of items that this Tracker contains
579         '''
580         item_hash_to_get = start_hash
581         while True:
582             if item_hash_to_get not in self.shares:
583                 break
584             share = self.shares[item_hash_to_get]
585             assert not isinstance(share, long)
586             yield share
587             item_hash_to_get = share.previous_hash
588     
589     def get_chain_to_root(self, start_hash, root=None):
590         assert isinstance(start_hash, (int, long, type(None)))
591         assert isinstance(root, (int, long, type(None)))
592         '''
593         Chain of hashes starting with share_hash of shares to the root (doesn't include root)
594         Raises an error if one is missing
595         '''
596         share_hash_to_get = start_hash
597         while share_hash_to_get != root:
598             share = self.shares[share_hash_to_get]
599             yield share
600             share_hash_to_get = share.previous_hash
601     
602     def get_best_hash(self):
603         '''
604         Returns hash of item with the most items in its chain
605         '''
606         if not self.heads:
607             return None
608         return max(self.heads, key=self.get_height_and_last)
609     
610     def get_highest_height(self):
611         return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0
612     
613     def get_nth_parent_hash(self, item_hash, n):
614         if n < 0:
615             raise ValueError('n must be >= 0')
616         
617         updates = {}
618         while n:
619             if item_hash not in self.skips:
620                 self.skips[item_hash] = math.geometric(.5), [(1, self.shares[item_hash].previous_hash)]
621             skip_length, skip = self.skips[item_hash]
622             
623             for i in xrange(skip_length):
624                 if i in updates:
625                     n_then, that_hash = updates.pop(i)
626                     x, y = self.skips[that_hash]
627                     assert len(y) == i
628                     y.append((n_then - n, item_hash))
629             
630             for i in xrange(len(skip), skip_length):
631                 updates[i] = n, item_hash
632             
633             for i, (dist, then_hash) in enumerate(reversed(skip)):
634                 if dist <= n:
635                     break
636             else:
637                 raise AssertionError()
638             
639             n -= dist
640             item_hash = then_hash
641         
642         return item_hash
643     
644     def get_nth_parent2(self, item_hash, n):
645         x = item_hash
646         for i in xrange(n):
647             x = self.shares[item_hash].previous_hash
648         return x
649     
650     def distance_up_to_branch(self, item_hash, max_dist=None):
651         while True:
652             if a:
653                 pass
654
655 if __name__ == '__main__':
656     class FakeShare(object):
657         def __init__(self, hash, previous_hash):
658             self.hash = hash
659             self.previous_hash = previous_hash
660     
661     t = Tracker()
662     
663     for i in xrange(100):
664         t.add(FakeShare(i, i - 1 if i > 0 else None))
665     
666     t.remove(99)
667     
668     print "HEADS", t.heads
669     print "TAILS", t.tails
670     
671     import random
672     
673     while True:
674         print
675         print '-'*30
676         print
677         t = Tracker()
678         for i in xrange(random.randrange(100)):
679             x = random.choice(list(t.shares) + [None])
680             print i, '->', x
681             t.add(FakeShare(i, x))
682         while t.shares:
683             x = random.choice(list(t.shares))
684             print "DEL", x, t.__dict__
685             try:
686                 t.remove(x)
687             except NotImplementedError:
688                 print "aborted; not implemented"
689         import time
690         time.sleep(.1)
691         print "HEADS", t.heads
692         print "TAILS", t.tails
693     
694     #for share_hash, share in sorted(t.shares.iteritems()):
695     #    print share_hash, share.previous_hash, t.heads.get(share_hash), t.tails.get(share_hash)
696     
697     import sys;sys.exit()
698     
699     print t.get_nth_parent_hash(9000, 5000)
700     print t.get_nth_parent_hash(9001, 412)
701     #print t.get_nth_parent_hash(90, 51)
702     
703     for share_hash in sorted(t.shares):
704         print str(share_hash).rjust(4),
705         x = t.skips.get(share_hash, None)
706         if x is not None:
707             print str(x[0]).rjust(4),
708             for a in x[1]:
709                 print str(a).rjust(10),
710         print
711
712 # network definitions
713
714 class Mainnet(object):
715     BITCOIN_P2P_PREFIX = 'f9beb4d9'.decode('hex')
716     BITCOIN_P2P_PORT = 8333
717     BITCOIN_ADDRESS_VERSION = 0
718
719 class Testnet(object):
720     BITCOIN_P2P_PREFIX = 'fabfb5da'.decode('hex')
721     BITCOIN_P2P_PORT = 18333
722     BITCOIN_ADDRESS_VERSION = 111