added --debug, few small fixes
[p2pool.git] / p2pool / bitcoin / data.py
1 from __future__ import division
2
3 import struct
4 import hashlib
5 import itertools
6 import warnings
7
8 from . import base58
9 from p2pool.util import bases, math, skiplist, intern2
10 import p2pool
11
12 class EarlyEnd(Exception):
13     pass
14
15 class LateEnd(Exception):
16     pass
17
18 def read((data, pos), length):
19     data2 = data[pos:pos + length]
20     if len(data2) != length:
21         raise EarlyEnd()
22     return data2, (data, pos + length)
23
24 def size((data, pos)):
25     return len(data) - pos
26
27 class Type(object):
28     # the same data can have only one unpacked representation, but multiple packed binary representations
29     
30     #def __hash__(self):
31     #    return hash(tuple(self.__dict__.items()))
32     
33     #def __eq__(self, other):
34     #    if not isinstance(other, Type):
35     #        raise NotImplementedError()
36     #    return self.__dict__ == other.__dict__
37     
38     def _unpack(self, data):
39         obj, (data2, pos) = self.read((data, 0))
40         
41         assert data2 is data
42         
43         if pos != len(data):
44             raise LateEnd()
45         
46         return obj
47     
48     def _pack(self, obj):
49         f = self.write(None, obj)
50         
51         res = []
52         while f is not None:
53             res.append(f[1])
54             f = f[0]
55         res.reverse()
56         return ''.join(res)
57     
58     
59     def unpack(self, data):
60         obj = self._unpack(data)
61         
62         if p2pool.DEBUG:
63             data2 = self._pack(obj)
64             if data2 != data:
65                 if self._unpack(data2) != obj:
66                     raise AssertionError()
67         
68         return obj
69     
70     def pack(self, obj):
71         data = self._pack(obj)
72         
73         if p2pool.DEBUG:
74             if self._unpack(data) != obj:
75                 raise AssertionError()
76         
77         return data
78     
79     
80     def pack_base58(self, obj):
81         return base58.base58_encode(self.pack(obj))
82     
83     def unpack_base58(self, base58_data):
84         return self.unpack(base58.base58_decode(base58_data))
85     
86     
87     def hash160(self, obj):
88         return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
89     
90     def hash256(self, obj):
91         return HashType().unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
92
93 class VarIntType(Type):
94     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
95     def read(self, file):
96         data, file = read(file, 1)
97         first = ord(data)
98         if first < 0xfd:
99             return first, file
100         elif first == 0xfd:
101             desc, length = '<H', 2
102         elif first == 0xfe:
103             desc, length = '<I', 4
104         elif first == 0xff:
105             desc, length = '<Q', 8
106         else:
107             raise AssertionError()
108         data, file = read(file, length)
109         return struct.unpack(desc, data)[0], file
110     
111     def write(self, file, item):
112         if item < 0xfd:
113             file = file, struct.pack('<B', item)
114         elif item <= 0xffff:
115             file = file, struct.pack('<BH', 0xfd, item)
116         elif item <= 0xffffffff:
117             file = file, struct.pack('<BI', 0xfe, item)
118         elif item <= 0xffffffffffffffff:
119             file = file, struct.pack('<BQ', 0xff, item)
120         else:
121             raise ValueError('int too large for varint')
122         return file
123
124 class VarStrType(Type):
125     _inner_size = VarIntType()
126     
127     def read(self, file):
128         length, file = self._inner_size.read(file)
129         return read(file, length)
130     
131     def write(self, file, item):
132         return self._inner_size.write(file, len(item)), item
133
134 class FixedStrType(Type):
135     def __init__(self, length):
136         self.length = length
137     
138     def read(self, file):
139         return read(file, self.length)
140     
141     def write(self, file, item):
142         if len(item) != self.length:
143             raise ValueError('incorrect length item!')
144         return file, item
145
146 class EnumType(Type):
147     def __init__(self, inner, values):
148         self.inner = inner
149         self.values = values
150         
151         self.keys = {}
152         for k, v in values.iteritems():
153             if v in self.keys:
154                 raise ValueError('duplicate value in values')
155             self.keys[v] = k
156     
157     def read(self, file):
158         data, file = self.inner.read(file)
159         return self.keys[data], file
160     
161     def write(self, file, item):
162         return self.inner.write(file, self.values[item])
163
164 class HashType(Type):
165     def read(self, file):
166         data, file = read(file, 256//8)
167         return int(data[::-1].encode('hex'), 16), file
168     
169     def write(self, file, item):
170         if not 0 <= item < 2**256:
171             raise ValueError('invalid hash value - %r' % (item,))
172         if item != 0 and item < 2**160:
173             warnings.warn('very low hash value - maybe you meant to use ShortHashType? %x' % (item,))
174         return file, ('%064x' % (item,)).decode('hex')[::-1]
175
176 class ShortHashType(Type):
177     def read(self, file):
178         data, file = read(file, 160//8)
179         return int(data[::-1].encode('hex'), 16), file
180     
181     def write(self, file, item):
182         if not 0 <= item < 2**160:
183             raise ValueError('invalid hash value - %r' % (item,))
184         return file, ('%040x' % (item,)).decode('hex')[::-1]
185
186 class ListType(Type):
187     _inner_size = VarIntType()
188     
189     def __init__(self, type):
190         self.type = type
191     
192     def read(self, file):
193         length, file = self._inner_size.read(file)
194         res = []
195         for i in xrange(length):
196             item, file = self.type.read(file)
197             res.append(item)
198         return res, file
199     
200     def write(self, file, item):
201         file = self._inner_size.write(file, len(item))
202         for subitem in item:
203             file = self.type.write(file, subitem)
204         return file
205
206 class StructType(Type):
207     def __init__(self, desc):
208         self.desc = desc
209         self.length = struct.calcsize(self.desc)
210     
211     def read(self, file):
212         data, file = read(file, self.length)
213         res, = struct.unpack(self.desc, data)
214         return res, file
215     
216     def write(self, file, item):
217         data = struct.pack(self.desc, item)
218         if struct.unpack(self.desc, data)[0] != item:
219             # special test because struct doesn't error on some overflows
220             raise ValueError('''item didn't survive pack cycle (%r)''' % (item,))
221         return file, data
222
223 class IPV6AddressType(Type):
224     def read(self, file):
225         data, file = read(file, 16)
226         if data[:12] != '00000000000000000000ffff'.decode('hex'):
227             raise ValueError('ipv6 addresses not supported yet')
228         return '.'.join(str(ord(x)) for x in data[12:]), file
229     
230     def write(self, file, item):
231         bits = map(int, item.split('.'))
232         if len(bits) != 4:
233             raise ValueError('invalid address: %r' % (bits,))
234         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
235         assert len(data) == 16, len(data)
236         return file, data
237
238 class ComposedType(Type):
239     def __init__(self, fields):
240         self.fields = fields
241     
242     def read(self, file):
243         item = {}
244         for key, type_ in self.fields:
245             item[key], file = type_.read(file)
246         return item, file
247     
248     def write(self, file, item):
249         for key, type_ in self.fields:
250             file = type_.write(file, item[key])
251         return file
252
253 class ChecksummedType(Type):
254     def __init__(self, inner):
255         self.inner = inner
256     
257     def read(self, file):
258         obj, file = self.inner.read(file)
259         data = self.inner.pack(obj)
260         
261         checksum, file = read(file, 4)
262         if checksum != hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]:
263             raise ValueError('invalid checksum')
264         
265         return obj, file
266     
267     def write(self, file, item):
268         data = self.inner.pack(item)
269         return (file, data), hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]
270
271 class FloatingIntegerType(Type):
272     # redundancy doesn't matter here because bitcoin checks binary bits against its own computed bits
273     # so it will always be encoded 'normally' in blocks (they way bitcoin does it)
274     _inner = StructType('<I')
275     
276     def read(self, file):
277         bits, file = self._inner.read(file)
278         target = self._bits_to_target(bits)
279         if p2pool.DEBUG:
280             if self._target_to_bits(target) != bits:
281                 raise ValueError('bits in non-canonical form')
282         target = intern2.intern2(target)
283         return target, file
284     
285     def write(self, file, item):
286         return self._inner.write(file, self._target_to_bits(item))
287     
288     def truncate_to(self, x):
289         return self._bits_to_target(self._target_to_bits(x, _check=False))
290     
291     def _bits_to_target(self, bits2):
292         target = math.shift_left(bits2 & 0x00ffffff, 8 * ((bits2 >> 24) - 3))
293         if p2pool.DEBUG:
294             assert target == self._bits_to_target1(struct.pack('<I', bits2))
295             assert self._target_to_bits(target, _check=False) == bits2
296         return target
297     
298     def _bits_to_target1(self, bits):
299         bits = bits[::-1]
300         length = ord(bits[0])
301         return bases.string_to_natural((bits[1:] + '\0'*length)[:length])
302     
303     def _target_to_bits(self, target, _check=True):
304         n = bases.natural_to_string(target)
305         if n and ord(n[0]) >= 128:
306             n = '\x00' + n
307         bits2 = (chr(len(n)) + (n + 3*chr(0))[:3])[::-1]
308         bits = struct.unpack('<I', bits2)[0]
309         if _check:
310             if self._bits_to_target(bits) != target:
311                 raise ValueError(repr((target, self._bits_to_target(bits, _check=False))))
312         return bits
313
314 class PossiblyNone(Type):
315     def __init__(self, none_value, inner):
316         self.none_value = none_value
317         self.inner = inner
318     
319     def read(self, file):
320         value, file = self.inner.read(file)
321         return None if value == self.none_value else value, file
322     
323     def write(self, file, item):
324         if item == self.none_value:
325             raise ValueError('none_value used')
326         return self.inner.write(file, self.none_value if item is None else item)
327
328 address_type = ComposedType([
329     ('services', StructType('<Q')),
330     ('address', IPV6AddressType()),
331     ('port', StructType('>H')),
332 ])
333
334 tx_type = ComposedType([
335     ('version', StructType('<I')),
336     ('tx_ins', ListType(ComposedType([
337         ('previous_output', PossiblyNone(dict(hash=0, index=2**32 - 1), ComposedType([
338             ('hash', HashType()),
339             ('index', StructType('<I')),
340         ]))),
341         ('script', VarStrType()),
342         ('sequence', PossiblyNone(2**32 - 1, StructType('<I'))),
343     ]))),
344     ('tx_outs', ListType(ComposedType([
345         ('value', StructType('<Q')),
346         ('script', VarStrType()),
347     ]))),
348     ('lock_time', StructType('<I')),
349 ])
350
351 block_header_type = ComposedType([
352     ('version', StructType('<I')),
353     ('previous_block', PossiblyNone(0, HashType())),
354     ('merkle_root', HashType()),
355     ('timestamp', StructType('<I')),
356     ('target', FloatingIntegerType()),
357     ('nonce', StructType('<I')),
358 ])
359
360 block_type = ComposedType([
361     ('header', block_header_type),
362     ('txs', ListType(tx_type)),
363 ])
364
365
366 merkle_record_type = ComposedType([
367     ('left', HashType()),
368     ('right', HashType()),
369 ])
370
371 def merkle_hash(tx_list):
372     if not tx_list:
373         return 0
374     hash_list = map(tx_type.hash256, tx_list)
375     while len(hash_list) > 1:
376         hash_list = [merkle_record_type.hash256(dict(left=left, right=left if right is None else right))
377             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
378     return hash_list[0]
379
380 def target_to_average_attempts(target):
381     return 2**256//(target + 1)
382
383 # tx
384
385 def tx_get_sigop_count(tx):
386     return sum(script.get_sigop_count(txin['script']) for txin in tx['tx_ins']) + sum(script.get_sigop_count(txout['script']) for txout in tx['tx_outs'])
387
388 # human addresses
389
390 human_address_type = ChecksummedType(ComposedType([
391     ('version', StructType('<B')),
392     ('pubkey_hash', ShortHashType()),
393 ]))
394
395 pubkey_type = FixedStrType(65)
396
397 def pubkey_hash_to_address(pubkey_hash, net):
398     return human_address_type.pack_base58(dict(version=net.BITCOIN_ADDRESS_VERSION, pubkey_hash=pubkey_hash))
399
400 def pubkey_to_address(pubkey, net):
401     return pubkey_hash_to_address(pubkey_type.hash160(pubkey), net)
402
403 def address_to_pubkey_hash(address, net):
404     x = human_address_type.unpack_base58(address)
405     if x['version'] != net.BITCOIN_ADDRESS_VERSION:
406         raise ValueError('address not for this net!')
407     return x['pubkey_hash']
408
409 # transactions
410
411 def pubkey_to_script2(pubkey):
412     return ('\x41' + pubkey_type.pack(pubkey)) + '\xac'
413
414 def pubkey_hash_to_script2(pubkey_hash):
415     return '\x76\xa9' + ('\x14' + ShortHashType().pack(pubkey_hash)) + '\x88\xac'
416
417 # linked list tracker
418
419 class Tracker(object):
420     def __init__(self):
421         self.shares = {} # hash -> share
422         self.ids = {} # hash -> (id, height)
423         self.reverse_shares = {} # previous_hash -> set of share_hashes
424         
425         self.heads = {} # head hash -> tail_hash
426         self.tails = {} # tail hash -> set of head hashes
427         
428         self.heights = {} # share_hash -> height_to, other_share_hash
429         
430         self.id_generator = itertools.count()
431         self.tails_by_id = {}
432         
433         self.get_nth_parent_hash = skiplist.DistanceSkipList(self)
434     
435     def add(self, share):
436         assert not isinstance(share, (int, long, type(None)))
437         if share.hash in self.shares:
438             return # XXX raise exception?
439         
440         parent_id = self.ids.get(share.previous_hash, None)
441         children_ids = set(self.ids.get(share2_hash) for share2_hash in self.reverse_shares.get(share.hash, set()))
442         infos = set()
443         if parent_id is not None:
444             infos.add((parent_id[0], parent_id[1] + 1))
445         for child_id in children_ids:
446             infos.add((child_id[0], child_id[1] - 1))
447         if not infos:
448             infos.add((self.id_generator.next(), 0))
449         chosen = min(infos)
450         self.ids[share.hash] = chosen
451         
452         self.shares[share.hash] = share
453         self.reverse_shares.setdefault(share.previous_hash, set()).add(share.hash)
454         
455         if share.hash in self.tails:
456             heads = self.tails.pop(share.hash)
457         else:
458             heads = set([share.hash])
459         
460         if share.previous_hash in self.heads:
461             tail = self.heads.pop(share.previous_hash)
462         else:
463             #dist, tail = self.get_height_and_last(share.previous_hash) # XXX this should be moved out of the critical area even though it shouldn't matter
464             tail = share.previous_hash
465             while tail in self.shares:
466                 tail = self.shares[tail].previous_hash
467         
468         self.tails.setdefault(tail, set()).update(heads)
469         if share.previous_hash in self.tails[tail]:
470             self.tails[tail].remove(share.previous_hash)
471         
472         for head in heads:
473             self.heads[head] = tail
474     
475     def test(self):
476         t = Tracker()
477         for s in self.shares.itervalues():
478             t.add(s)
479         
480         assert self.shares == t.shares, (self.shares, t.shares)
481         assert self.reverse_shares == t.reverse_shares, (self.reverse_shares, t.reverse_shares)
482         assert self.heads == t.heads, (self.heads, t.heads)
483         assert self.tails == t.tails, (self.tails, t.tails)
484     
485     def remove(self, share_hash):
486         assert isinstance(share_hash, (int, long, type(None)))
487         if share_hash not in self.shares:
488             raise KeyError()
489         share = self.shares[share_hash]
490         del share_hash
491         
492         if share.hash in self.heads and share.previous_hash in self.tails:
493             tail = self.heads.pop(share.hash)
494             self.tails[tail].remove(share.hash)
495             if not self.tails[share.previous_hash]:
496                 self.tails.pop(share.previous_hash)
497         elif share.hash in self.heads:
498             tail = self.heads.pop(share.hash)
499             self.tails[tail].remove(share.hash)
500             if self.reverse_shares[share.previous_hash] != set([share.hash]):
501                 pass # has sibling
502             else:
503                 self.tails[tail].add(share.previous_hash)
504                 self.heads[share.previous_hash] = tail
505         elif share.previous_hash in self.tails:
506             raise NotImplementedError() # will break other things..
507             heads = self.tails[share.previous_hash]
508             if len(self.reverse_shares[share.previous_hash]) > 1:
509                 raise NotImplementedError()
510             else:
511                 del self.tails[share.previous_hash]
512                 for head in heads:
513                     self.heads[head] = share.hash
514                 self.tails[share.hash] = set(heads)
515         else:
516             raise NotImplementedError()
517         
518         '''
519         height, tail = self.get_height_and_last(share.hash)
520         
521         if share.hash in self.heads:
522             my_heads = set([share.hash])
523         elif share.previous_hash in self.tails:
524             my_heads = self.tails[share.previous_hash]
525         else:
526             some_heads = self.tails[tail]
527             some_heads_heights = dict((that_head, self.get_height_and_last(that_head)[0]) for that_head in some_heads)
528             my_heads = set(that_head for that_head in some_heads
529                 if some_heads_heights[that_head] > height and
530                 self.get_nth_parent_hash(that_head, some_heads_heights[that_head] - height) == share.hash)
531         
532         if share.previous_hash != tail:
533             self.heads[share.previous_hash] = tail
534         
535         for head in my_heads:
536             if head != share.hash:
537                 self.heads[head] = share.hash
538             else:
539                 self.heads.pop(head)
540         
541         if share.hash in self.heads:
542             self.heads.pop(share.hash)
543         
544         
545         self.tails[tail].difference_update(my_heads)
546         if share.previous_hash != tail:
547             self.tails[tail].add(share.previous_hash)
548         if not self.tails[tail]:
549             self.tails.pop(tail)
550         if my_heads != set([share.hash]):
551             self.tails[share.hash] = set(my_heads) - set([share.hash])
552         '''
553         
554         self.shares.pop(share.hash)
555         self.reverse_shares[share.previous_hash].remove(share.hash)
556         if not self.reverse_shares[share.previous_hash]:
557             self.reverse_shares.pop(share.previous_hash)
558         
559         #assert self.test() is None
560     
561     def get_height(self, share_hash):
562         height, work, last = self.get_height_work_and_last(share_hash)
563         return height
564     
565     def get_work(self, share_hash):
566         height, work, last = self.get_height_work_and_last(share_hash)
567         return work
568     
569     def get_height_and_last(self, share_hash):
570         height, work, last = self.get_height_work_and_last(share_hash)
571         return height, last
572     
573     def get_height_work_and_last(self, share_hash):
574         assert isinstance(share_hash, (int, long, type(None)))
575         orig = share_hash
576         height = 0
577         work = 0
578         updates = []
579         while True:
580             if share_hash is None or share_hash not in self.shares:
581                 break
582             updates.append((share_hash, height, work))
583             if share_hash in self.heights:
584                 height_inc, share_hash, work_inc = self.heights[share_hash]
585             else:
586                 height_inc, share_hash, work_inc = 1, self.shares[share_hash].previous_hash, target_to_average_attempts(self.shares[share_hash].target)
587             height += height_inc
588             work += work_inc
589         for update_hash, height_then, work_then in updates:
590             self.heights[update_hash] = height - height_then, share_hash, work - work_then
591         return height, work, share_hash
592     
593     def get_chain_known(self, start_hash):
594         assert isinstance(start_hash, (int, long, type(None)))
595         '''
596         Chain starting with item of hash I{start_hash} of items that this Tracker contains
597         '''
598         item_hash_to_get = start_hash
599         while True:
600             if item_hash_to_get not in self.shares:
601                 break
602             share = self.shares[item_hash_to_get]
603             assert not isinstance(share, long)
604             yield share
605             item_hash_to_get = share.previous_hash
606     
607     def get_chain_to_root(self, start_hash, root=None):
608         assert isinstance(start_hash, (int, long, type(None)))
609         assert isinstance(root, (int, long, type(None)))
610         '''
611         Chain of hashes starting with share_hash of shares to the root (doesn't include root)
612         Raises an error if one is missing
613         '''
614         share_hash_to_get = start_hash
615         while share_hash_to_get != root:
616             share = self.shares[share_hash_to_get]
617             yield share
618             share_hash_to_get = share.previous_hash
619     
620     def get_best_hash(self):
621         '''
622         Returns hash of item with the most items in its chain
623         '''
624         if not self.heads:
625             return None
626         return max(self.heads, key=self.get_height_and_last)
627     
628     def get_highest_height(self):
629         return max(self.get_height_and_last(head)[0] for head in self.heads) if self.heads else 0
630
631 class FakeShare(object):
632     def __init__(self, **kwargs):
633         self.__dict__.update(kwargs)
634
635 if __name__ == '__main__':
636     
637     t = Tracker()
638     
639     for i in xrange(10000):
640         t.add(FakeShare(hash=i, previous_hash=i - 1 if i > 0 else None))
641     
642     #t.remove(99)
643     
644     print "HEADS", t.heads
645     print "TAILS", t.tails
646     
647     import random
648     
649     while False:
650         print
651         print '-'*30
652         print
653         t = Tracker()
654         for i in xrange(random.randrange(100)):
655             x = random.choice(list(t.shares) + [None])
656             print i, '->', x
657             t.add(FakeShare(i, x))
658         while t.shares:
659             x = random.choice(list(t.shares))
660             print "DEL", x, t.__dict__
661             try:
662                 t.remove(x)
663             except NotImplementedError:
664                 print "aborted; not implemented"
665         import time
666         time.sleep(.1)
667         print "HEADS", t.heads
668         print "TAILS", t.tails
669     
670     #for share_hash, share in sorted(t.shares.iteritems()):
671     #    print share_hash, share.previous_hash, t.heads.get(share_hash), t.tails.get(share_hash)
672     
673     #import sys;sys.exit()
674     
675     print t.get_nth_parent_hash(9000, 5000)
676     print t.get_nth_parent_hash(9001, 412)
677     #print t.get_nth_parent_hash(90, 51)
678     
679     for share_hash in sorted(t.shares):
680         print str(share_hash).rjust(4),
681         x = t.skips.get(share_hash, None)
682         if x is not None:
683             print str(x[0]).rjust(4),
684             for a in x[1]:
685                 print str(a).rjust(10),
686         print
687
688 # network definitions
689
690 class Mainnet(object):
691     BITCOIN_P2P_PREFIX = 'f9beb4d9'.decode('hex')
692     BITCOIN_P2P_PORT = 8333
693     BITCOIN_ADDRESS_VERSION = 0
694
695 class Testnet(object):
696     BITCOIN_P2P_PREFIX = 'fabfb5da'.decode('hex')
697     BITCOIN_P2P_PORT = 18333
698     BITCOIN_ADDRESS_VERSION = 111