allow for variably-sized pubkeys
[p2pool.git] / p2pool / bitcoin / data.py
1 from __future__ import division
2
3 import hashlib
4 import struct
5
6 from . import base58
7 from p2pool.util import bases, math, expiring_dict, memoize, slush
8 import p2pool
9
10 class EarlyEnd(Exception):
11     pass
12
13 class LateEnd(Exception):
14     pass
15
16 def read((data, pos), length):
17     data2 = data[pos:pos + length]
18     if len(data2) != length:
19         raise EarlyEnd()
20     return data2, (data, pos + length)
21
22 def size((data, pos)):
23     return len(data) - pos
24
25 class Type(object):
26     # the same data can have only one unpacked representation, but multiple packed binary representations
27     
28     def __hash__(self):
29         rval = getattr(self, '_hash', None)
30         if rval is None:
31             try:
32                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
33             except:
34                 print self.__dict__
35                 raise
36         return rval
37     
38     def __eq__(self, other):
39         return type(other) is type(self) and other.__dict__ == self.__dict__
40     
41     def __ne__(self, other):
42         return not (self == other)
43     
44     def _unpack(self, data):
45         obj, (data2, pos) = self.read((data, 0))
46         
47         assert data2 is data
48         
49         if pos != len(data):
50             raise LateEnd()
51         
52         return obj
53     
54     def _pack(self, obj):
55         f = self.write(None, obj)
56         
57         res = []
58         while f is not None:
59             res.append(f[1])
60             f = f[0]
61         res.reverse()
62         return ''.join(res)
63     
64     
65     def unpack(self, data):
66         obj = self._unpack(data)
67         
68         if p2pool.DEBUG:
69             data2 = self._pack(obj)
70             if data2 != data:
71                 if self._unpack(data2) != obj:
72                     raise AssertionError()
73         
74         return obj
75     
76     def pack2(self, obj):
77         data = self._pack(obj)
78         
79         if p2pool.DEBUG:
80             if self._unpack(data) != obj:
81                 raise AssertionError((self._unpack(data), obj))
82         
83         return data
84     
85     _backing = None
86     
87     @classmethod
88     def enable_caching(cls):
89         assert cls._backing is None
90         cls._backing = expiring_dict.ExpiringDict(100)
91         cls._pre_pack2 = cls.pack2
92         cls.pack2 = memoize.memoize_with_backing(cls._backing, [cls.unpack])(cls.pack2)
93         cls._pre_unpack = cls.unpack
94         cls.unpack = memoize.memoize_with_backing(cls._backing)(cls.unpack) # doesn't have an inverse
95     
96     @classmethod
97     def disable_caching(cls):
98         assert cls._backing is not None
99         cls._backing.stop()
100         cls._backing = None
101         cls.pack2 = cls._pre_pack2
102         del cls._pre_pack2
103         cls.unpack = cls._pre_unpack
104         del cls._pre_unpack
105     
106     def pack(self, obj):
107         return self.pack2(slush.immutify(obj))
108     
109     
110     def pack_base58(self, obj):
111         return base58.encode(self.pack(obj))
112     
113     def unpack_base58(self, base58_data):
114         return self.unpack(base58.decode(base58_data))
115     
116     
117     def hash160(self, obj):
118         return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
119     
120     def hash256(self, obj):
121         return HashType().unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
122     
123     def scrypt(self, obj):
124         import ltc_scrypt
125         return HashType().unpack(ltc_scrypt.getPoWHash(self.pack(obj)))
126
127 class VarIntType(Type):
128     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
129     def read(self, file):
130         data, file = read(file, 1)
131         first = ord(data)
132         if first < 0xfd:
133             return first, file
134         elif first == 0xfd:
135             desc, length = '<H', 2
136         elif first == 0xfe:
137             desc, length = '<I', 4
138         elif first == 0xff:
139             desc, length = '<Q', 8
140         else:
141             raise AssertionError()
142         data, file = read(file, length)
143         return struct.unpack(desc, data)[0], file
144     
145     def write(self, file, item):
146         if item < 0xfd:
147             file = file, struct.pack('<B', item)
148         elif item <= 0xffff:
149             file = file, struct.pack('<BH', 0xfd, item)
150         elif item <= 0xffffffff:
151             file = file, struct.pack('<BI', 0xfe, item)
152         elif item <= 0xffffffffffffffff:
153             file = file, struct.pack('<BQ', 0xff, item)
154         else:
155             raise ValueError('int too large for varint')
156         return file
157
158 class VarStrType(Type):
159     _inner_size = VarIntType()
160     
161     def read(self, file):
162         length, file = self._inner_size.read(file)
163         return read(file, length)
164     
165     def write(self, file, item):
166         return self._inner_size.write(file, len(item)), item
167
168 class FixedStrType(Type):
169     def __init__(self, length):
170         self.length = length
171     
172     def read(self, file):
173         return read(file, self.length)
174     
175     def write(self, file, item):
176         if len(item) != self.length:
177             raise ValueError('incorrect length item!')
178         return file, item
179
180 class PassthruType(Type):
181     def read(self, file):
182         return read(file, size(file))
183     
184     def write(self, file, item):
185         return file, item
186
187 class EnumType(Type):
188     def __init__(self, inner, values):
189         self.inner = inner
190         self.values = slush.frozendict(values)
191         
192         keys = {}
193         for k, v in values.iteritems():
194             if v in keys:
195                 raise ValueError('duplicate value in values')
196             keys[v] = k
197         self.keys = slush.frozendict(keys)
198     
199     def read(self, file):
200         data, file = self.inner.read(file)
201         if data not in self.keys:
202             raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
203         return self.keys[data], file
204     
205     def write(self, file, item):
206         if item not in self.values:
207             raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
208         return self.inner.write(file, self.values[item])
209
210 class HashType(Type):
211     def read(self, file):
212         data, file = read(file, 256//8)
213         return int(data[::-1].encode('hex'), 16), file
214     
215     def write(self, file, item):
216         if not 0 <= item < 2**256:
217             raise ValueError('invalid hash value - %r' % (item,))
218         if item != 0 and item < 2**160:
219             print 'Very low hash value - maybe you meant to use ShortHashType? %x' % (item,)
220         return file, ('%064x' % (item,)).decode('hex')[::-1]
221
222 class ShortHashType(Type):
223     def read(self, file):
224         data, file = read(file, 160//8)
225         return int(data[::-1].encode('hex'), 16), file
226     
227     def write(self, file, item):
228         if not 0 <= item < 2**160:
229             raise ValueError('invalid hash value - %r' % (item,))
230         return file, ('%040x' % (item,)).decode('hex')[::-1]
231
232 class ListType(Type):
233     _inner_size = VarIntType()
234     
235     def __init__(self, type):
236         self.type = type
237     
238     def read(self, file):
239         length, file = self._inner_size.read(file)
240         res = []
241         for i in xrange(length):
242             item, file = self.type.read(file)
243             res.append(item)
244         return res, file
245     
246     def write(self, file, item):
247         file = self._inner_size.write(file, len(item))
248         for subitem in item:
249             file = self.type.write(file, subitem)
250         return file
251
252 class StructType(Type):
253     def __init__(self, desc):
254         self.desc = desc
255         self.length = struct.calcsize(self.desc)
256     
257     def read(self, file):
258         data, file = read(file, self.length)
259         res, = struct.unpack(self.desc, data)
260         return res, file
261     
262     def write(self, file, item):
263         data = struct.pack(self.desc, item)
264         if struct.unpack(self.desc, data)[0] != item:
265             # special test because struct doesn't error on some overflows
266             raise ValueError('''item didn't survive pack cycle (%r)''' % (item,))
267         return file, data
268
269 class IPV6AddressType(Type):
270     def read(self, file):
271         data, file = read(file, 16)
272         if data[:12] != '00000000000000000000ffff'.decode('hex'):
273             raise ValueError('ipv6 addresses not supported yet')
274         return '.'.join(str(ord(x)) for x in data[12:]), file
275     
276     def write(self, file, item):
277         bits = map(int, item.split('.'))
278         if len(bits) != 4:
279             raise ValueError('invalid address: %r' % (bits,))
280         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
281         assert len(data) == 16, len(data)
282         return file, data
283
284 _record_types = {}
285
286 def get_record(fields):
287     fields = tuple(sorted(fields))
288     if 'keys' in fields:
289         raise ValueError()
290     if fields not in _record_types:
291         class _Record(object):
292             __slots__ = fields
293             def __repr__(self):
294                 return repr(dict(self))
295             def __getitem__(self, key):
296                 return getattr(self, key)
297             def __setitem__(self, key, value):
298                 setattr(self, key, value)
299             #def __iter__(self):
300             #    for field in self.__slots__:
301             #        yield field, getattr(self, field)
302             def keys(self):
303                 return self.__slots__
304             def __eq__(self, other):
305                 if isinstance(other, dict):
306                     return dict(self) == other
307                 elif isinstance(other, _Record):
308                     return all(self[k] == other[k] for k in self.keys())
309                 raise TypeError()
310             def __ne__(self, other):
311                 return not (self == other)
312         _record_types[fields] = _Record
313     return _record_types[fields]()
314
315 class ComposedType(Type):
316     def __init__(self, fields):
317         self.fields = tuple(fields)
318     
319     def read(self, file):
320         item = get_record(k for k, v in self.fields)
321         for key, type_ in self.fields:
322             item[key], file = type_.read(file)
323         return item, file
324     
325     def write(self, file, item):
326         for key, type_ in self.fields:
327             file = type_.write(file, item[key])
328         return file
329
330 class ChecksummedType(Type):
331     def __init__(self, inner):
332         self.inner = inner
333     
334     def read(self, file):
335         obj, file = self.inner.read(file)
336         data = self.inner.pack(obj)
337         
338         checksum, file = read(file, 4)
339         if checksum != hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]:
340             raise ValueError('invalid checksum')
341         
342         return obj, file
343     
344     def write(self, file, item):
345         data = self.inner.pack(item)
346         return (file, data), hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]
347
348 class FloatingInteger(object):
349     __slots__ = ['bits', '_target']
350     
351     @classmethod
352     def from_target_upper_bound(cls, target):
353         n = bases.natural_to_string(target)
354         if n and ord(n[0]) >= 128:
355             n = '\x00' + n
356         bits2 = (chr(len(n)) + (n + 3*chr(0))[:3])[::-1]
357         bits = struct.unpack('<I', bits2)[0]
358         return cls(bits)
359     
360     def __init__(self, bits, target=None):
361         self.bits = bits
362         self._target = None
363         if target is not None and self.target != target:
364             raise ValueError('target does not match')
365     
366     @property
367     def target(self):
368         res = self._target
369         if res is None:
370             res = self._target = math.shift_left(self.bits & 0x00ffffff, 8 * ((self.bits >> 24) - 3))
371         return res
372     
373     def __hash__(self):
374         return hash(self.bits)
375     
376     def __eq__(self, other):
377         return self.bits == other.bits
378     
379     def __ne__(self, other):
380         return not (self == other)
381     
382     def __cmp__(self, other):
383         assert False
384     
385     def __repr__(self):
386         return 'FloatingInteger(bits=%s, target=%s)' % (hex(self.bits), hex(self.target))
387
388 class FloatingIntegerType(Type):
389     _inner = StructType('<I')
390     
391     def read(self, file):
392         bits, file = self._inner.read(file)
393         return FloatingInteger(bits), file
394     
395     def write(self, file, item):
396         return self._inner.write(file, item.bits)
397
398 class PossiblyNoneType(Type):
399     def __init__(self, none_value, inner):
400         self.none_value = none_value
401         self.inner = inner
402     
403     def read(self, file):
404         value, file = self.inner.read(file)
405         return None if value == self.none_value else value, file
406     
407     def write(self, file, item):
408         if item == self.none_value:
409             raise ValueError('none_value used')
410         return self.inner.write(file, self.none_value if item is None else item)
411
412 address_type = ComposedType([
413     ('services', StructType('<Q')),
414     ('address', IPV6AddressType()),
415     ('port', StructType('>H')),
416 ])
417
418 tx_type = ComposedType([
419     ('version', StructType('<I')),
420     ('tx_ins', ListType(ComposedType([
421         ('previous_output', PossiblyNoneType(slush.frozendict(hash=0, index=2**32 - 1), ComposedType([
422             ('hash', HashType()),
423             ('index', StructType('<I')),
424         ]))),
425         ('script', VarStrType()),
426         ('sequence', PossiblyNoneType(2**32 - 1, StructType('<I'))),
427     ]))),
428     ('tx_outs', ListType(ComposedType([
429         ('value', StructType('<Q')),
430         ('script', VarStrType()),
431     ]))),
432     ('lock_time', StructType('<I')),
433 ])
434
435 merkle_branch_type = ListType(HashType())
436
437 merkle_tx_type = ComposedType([
438     ('tx', tx_type),
439     ('block_hash', HashType()),
440     ('merkle_branch', merkle_branch_type),
441     ('index', StructType('<i')),
442 ])
443
444 block_header_type = ComposedType([
445     ('version', StructType('<I')),
446     ('previous_block', PossiblyNoneType(0, HashType())),
447     ('merkle_root', HashType()),
448     ('timestamp', StructType('<I')),
449     ('bits', FloatingIntegerType()),
450     ('nonce', StructType('<I')),
451 ])
452
453 block_type = ComposedType([
454     ('header', block_header_type),
455     ('txs', ListType(tx_type)),
456 ])
457
458 aux_pow_type = ComposedType([
459     ('merkle_tx', merkle_tx_type),
460     ('merkle_branch', merkle_branch_type),
461     ('index', StructType('<i')),
462     ('parent_block_header', block_header_type),
463 ])
464
465
466 merkle_record_type = ComposedType([
467     ('left', HashType()),
468     ('right', HashType()),
469 ])
470
471 def merkle_hash(hashes):
472     if not hashes:
473         return 0
474     hash_list = list(hashes)
475     while len(hash_list) > 1:
476         hash_list = [merkle_record_type.hash256(dict(left=left, right=left if right is None else right))
477             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
478     return hash_list[0]
479
480 def calculate_merkle_branch(hashes, index):
481     # XXX optimize this
482     
483     hash_list = [(h, i == index, []) for i, h in enumerate(hashes)]
484     
485     while len(hash_list) > 1:
486         hash_list = [
487             (
488                 merkle_record_type.hash256(dict(left=left, right=right)),
489                 left_f or right_f,
490                 (left_l if left_f else right_l) + [dict(side=1, hash=right) if left_f else dict(side=0, hash=left)],
491             )
492             for (left, left_f, left_l), (right, right_f, right_l) in
493                 zip(hash_list[::2], hash_list[1::2] + [hash_list[::2][-1]])
494         ]
495     
496     res = [x['hash'] for x in hash_list[0][2]]
497     
498     assert hash_list[0][1]
499     assert check_merkle_branch(hashes[index], index, res) == hash_list[0][0]
500     assert index == sum(k*2**i for i, k in enumerate([1-x['side'] for x in hash_list[0][2]]))
501     
502     return res
503
504 def check_merkle_branch(tip_hash, index, merkle_branch):
505     return reduce(lambda c, (i, h): merkle_record_type.hash256(
506         dict(left=h, right=c) if 2**i & index else
507         dict(left=c, right=h)
508     ), enumerate(merkle_branch), tip_hash)
509
510 def target_to_average_attempts(target):
511     return 2**256//(target + 1)
512
513 def target_to_difficulty(target):
514     return (0xffff0000 * 2**(256-64) + 1)/(target + 1)
515
516 # tx
517
518 def tx_get_sigop_count(tx):
519     return sum(script.get_sigop_count(txin['script']) for txin in tx['tx_ins']) + sum(script.get_sigop_count(txout['script']) for txout in tx['tx_outs'])
520
521 # human addresses
522
523 human_address_type = ChecksummedType(ComposedType([
524     ('version', StructType('<B')),
525     ('pubkey_hash', ShortHashType()),
526 ]))
527
528 pubkey_type = PassthruType()
529
530 def pubkey_hash_to_address(pubkey_hash, net):
531     return human_address_type.pack_base58(dict(version=net.ADDRESS_VERSION, pubkey_hash=pubkey_hash))
532
533 def pubkey_to_address(pubkey, net):
534     return pubkey_hash_to_address(pubkey_type.hash160(pubkey), net)
535
536 def address_to_pubkey_hash(address, net):
537     x = human_address_type.unpack_base58(address)
538     if x['version'] != net.ADDRESS_VERSION:
539         raise ValueError('address not for this net!')
540     return x['pubkey_hash']
541
542 # transactions
543
544 def pubkey_to_script2(pubkey):
545     return ('\x41' + pubkey_type.pack(pubkey)) + '\xac'
546
547 def pubkey_hash_to_script2(pubkey_hash):
548     return '\x76\xa9' + ('\x14' + ShortHashType().pack(pubkey_hash)) + '\x88\xac'
549
550 def script2_to_address(script2, net):
551     try:
552         pubkey = script2[1:-1]
553         script2_test = pubkey_to_script2(pubkey)
554     except:
555         pass
556     else:
557         if script2_test == script2:
558             return pubkey_to_address(pubkey, net)
559     
560     try:
561         pubkey_hash = ShortHashType().unpack(script2[3:-2])
562         script2_test2 = pubkey_hash_to_script2(pubkey_hash)
563     except:
564         pass
565     else:
566         if script2_test2 == script2:
567             return pubkey_hash_to_address(pubkey_hash, net)
568
569 def script2_to_human(script2, net):
570     try:
571         pubkey = script2[1:-1]
572         script2_test = pubkey_to_script2(pubkey)
573     except:
574         pass
575     else:
576         if script2_test == script2:
577             return 'Pubkey. Address: %s' % (pubkey_to_address(pubkey, net),)
578     
579     try:
580         pubkey_hash = ShortHashType().unpack(script2[3:-2])
581         script2_test2 = pubkey_hash_to_script2(pubkey_hash)
582     except:
583         pass
584     else:
585         if script2_test2 == script2:
586             return 'Address. Address: %s' % (pubkey_hash_to_address(pubkey_hash, net),)
587     
588     return 'Unknown. Script: %s'  % (script2.encode('hex'),)