unified HashType, ShortHashType, and StructType into IntType
[p2pool.git] / p2pool / bitcoin / data.py
1 from __future__ import division
2
3 import hashlib
4 import struct
5
6 from . import base58
7 from p2pool.util import bases, math
8 import p2pool
9
10 class EarlyEnd(Exception):
11     pass
12
13 class LateEnd(Exception):
14     pass
15
16 def read((data, pos), length):
17     data2 = data[pos:pos + length]
18     if len(data2) != length:
19         raise EarlyEnd()
20     return data2, (data, pos + length)
21
22 def size((data, pos)):
23     return len(data) - pos
24
25 class Type(object):
26     # the same data can have only one unpacked representation, but multiple packed binary representations
27     
28     def __hash__(self):
29         rval = getattr(self, '_hash', None)
30         if rval is None:
31             try:
32                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
33             except:
34                 print self.__dict__
35                 raise
36         return rval
37     
38     def __eq__(self, other):
39         return type(other) is type(self) and other.__dict__ == self.__dict__
40     
41     def __ne__(self, other):
42         return not (self == other)
43     
44     def _unpack(self, data):
45         obj, (data2, pos) = self.read((data, 0))
46         
47         assert data2 is data
48         
49         if pos != len(data):
50             raise LateEnd()
51         
52         return obj
53     
54     def _pack(self, obj):
55         f = self.write(None, obj)
56         
57         res = []
58         while f is not None:
59             res.append(f[1])
60             f = f[0]
61         res.reverse()
62         return ''.join(res)
63     
64     
65     def unpack(self, data):
66         obj = self._unpack(data)
67         
68         if p2pool.DEBUG:
69             data2 = self._pack(obj)
70             if data2 != data:
71                 if self._unpack(data2) != obj:
72                     raise AssertionError()
73         
74         return obj
75     
76     def pack(self, obj):
77         data = self._pack(obj)
78         
79         if p2pool.DEBUG:
80             if self._unpack(data) != obj:
81                 raise AssertionError((self._unpack(data), obj))
82         
83         return data
84     
85     
86     def pack_base58(self, obj):
87         return base58.encode(self.pack(obj))
88     
89     def unpack_base58(self, base58_data):
90         return self.unpack(base58.decode(base58_data))
91     
92     
93     def hash160(self, obj):
94         return IntType(160).unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
95     
96     def hash256(self, obj):
97         return IntType(256).unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
98     
99     def scrypt(self, obj):
100         import ltc_scrypt
101         return IntType(256).unpack(ltc_scrypt.getPoWHash(self.pack(obj)))
102
103 class VarIntType(Type):
104     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
105     def read(self, file):
106         data, file = read(file, 1)
107         first = ord(data)
108         if first < 0xfd:
109             return first, file
110         elif first == 0xfd:
111             desc, length = '<H', 2
112         elif first == 0xfe:
113             desc, length = '<I', 4
114         elif first == 0xff:
115             desc, length = '<Q', 8
116         else:
117             raise AssertionError()
118         data, file = read(file, length)
119         return struct.unpack(desc, data)[0], file
120     
121     def write(self, file, item):
122         if item < 0xfd:
123             file = file, struct.pack('<B', item)
124         elif item <= 0xffff:
125             file = file, struct.pack('<BH', 0xfd, item)
126         elif item <= 0xffffffff:
127             file = file, struct.pack('<BI', 0xfe, item)
128         elif item <= 0xffffffffffffffff:
129             file = file, struct.pack('<BQ', 0xff, item)
130         else:
131             raise ValueError('int too large for varint')
132         return file
133
134 class VarStrType(Type):
135     _inner_size = VarIntType()
136     
137     def read(self, file):
138         length, file = self._inner_size.read(file)
139         return read(file, length)
140     
141     def write(self, file, item):
142         return self._inner_size.write(file, len(item)), item
143
144 class PassthruType(Type):
145     def read(self, file):
146         return read(file, size(file))
147     
148     def write(self, file, item):
149         return file, item
150
151 class EnumType(Type):
152     def __init__(self, inner, values):
153         self.inner = inner
154         self.values = values
155         
156         keys = {}
157         for k, v in values.iteritems():
158             if v in keys:
159                 raise ValueError('duplicate value in values')
160             keys[v] = k
161         self.keys = keys
162     
163     def read(self, file):
164         data, file = self.inner.read(file)
165         if data not in self.keys:
166             raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
167         return self.keys[data], file
168     
169     def write(self, file, item):
170         if item not in self.values:
171             raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
172         return self.inner.write(file, self.values[item])
173
174 class ListType(Type):
175     _inner_size = VarIntType()
176     
177     def __init__(self, type):
178         self.type = type
179     
180     def read(self, file):
181         length, file = self._inner_size.read(file)
182         res = []
183         for i in xrange(length):
184             item, file = self.type.read(file)
185             res.append(item)
186         return res, file
187     
188     def write(self, file, item):
189         file = self._inner_size.write(file, len(item))
190         for subitem in item:
191             file = self.type.write(file, subitem)
192         return file
193
194 class StructType(Type):
195     def __init__(self, desc):
196         self.desc = desc
197         self.length = struct.calcsize(self.desc)
198     
199     def read(self, file):
200         data, file = read(file, self.length)
201         return struct.unpack(self.desc, data)[0], file
202     
203     def write(self, file, item):
204         return file, struct.pack(self.desc, item)
205
206 class IntType(Type):
207     def __new__(cls, bits, endianness='little'):
208         assert bits % 8 == 0
209         assert endianness in ['little', 'big']
210         if bits in [8, 16, 32, 64]:
211             return StructType(('<' if endianness == 'little' else '>') + {8: 'B', 16: 'H', 32: 'I', 64: 'Q'}[bits])
212         else:
213             return object.__new__(cls, bits, endianness)
214     
215     def __init__(self, bits, endianness='little'):
216         assert bits % 8 == 0
217         assert endianness in ['little', 'big']
218         self.bytes = bits//8
219         self.step = -1 if endianness == 'little' else 1
220     
221     def read(self, file):
222         data, file = read(file, self.bytes)
223         return int(data[::self.step].encode('hex'), 16), file
224     
225     def write(self, file, item):
226         if not 0 <= item < 2**(8*self.bytes):
227             raise ValueError('invalid int value - %r' % (item,))
228         return file, ('%x' % (item,)).zfill(2*self.bytes).decode('hex')[::self.step]
229
230 class IPV6AddressType(Type):
231     def read(self, file):
232         data, file = read(file, 16)
233         if data[:12] != '00000000000000000000ffff'.decode('hex'):
234             raise ValueError('ipv6 addresses not supported yet')
235         return '.'.join(str(ord(x)) for x in data[12:]), file
236     
237     def write(self, file, item):
238         bits = map(int, item.split('.'))
239         if len(bits) != 4:
240             raise ValueError('invalid address: %r' % (bits,))
241         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
242         assert len(data) == 16, len(data)
243         return file, data
244
245 _record_types = {}
246
247 def get_record(fields):
248     fields = tuple(sorted(fields))
249     if 'keys' in fields:
250         raise ValueError()
251     if fields not in _record_types:
252         class _Record(object):
253             __slots__ = fields
254             def __repr__(self):
255                 return repr(dict(self))
256             def __getitem__(self, key):
257                 return getattr(self, key)
258             def __setitem__(self, key, value):
259                 setattr(self, key, value)
260             #def __iter__(self):
261             #    for field in self.__slots__:
262             #        yield field, getattr(self, field)
263             def keys(self):
264                 return self.__slots__
265             def __eq__(self, other):
266                 if isinstance(other, dict):
267                     return dict(self) == other
268                 elif isinstance(other, _Record):
269                     return all(self[k] == other[k] for k in self.keys())
270                 raise TypeError()
271             def __ne__(self, other):
272                 return not (self == other)
273         _record_types[fields] = _Record
274     return _record_types[fields]()
275
276 class ComposedType(Type):
277     def __init__(self, fields):
278         self.fields = tuple(fields)
279     
280     def read(self, file):
281         item = get_record(k for k, v in self.fields)
282         for key, type_ in self.fields:
283             item[key], file = type_.read(file)
284         return item, file
285     
286     def write(self, file, item):
287         for key, type_ in self.fields:
288             file = type_.write(file, item[key])
289         return file
290
291 class ChecksummedType(Type):
292     def __init__(self, inner):
293         self.inner = inner
294     
295     def read(self, file):
296         obj, file = self.inner.read(file)
297         data = self.inner.pack(obj)
298         
299         checksum, file = read(file, 4)
300         if checksum != hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]:
301             raise ValueError('invalid checksum')
302         
303         return obj, file
304     
305     def write(self, file, item):
306         data = self.inner.pack(item)
307         return (file, data), hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]
308
309 class FloatingInteger(object):
310     __slots__ = ['bits', '_target']
311     
312     @classmethod
313     def from_target_upper_bound(cls, target):
314         n = bases.natural_to_string(target)
315         if n and ord(n[0]) >= 128:
316             n = '\x00' + n
317         bits2 = (chr(len(n)) + (n + 3*chr(0))[:3])[::-1]
318         bits = struct.unpack('<I', bits2)[0]
319         return cls(bits)
320     
321     def __init__(self, bits, target=None):
322         self.bits = bits
323         self._target = None
324         if target is not None and self.target != target:
325             raise ValueError('target does not match')
326     
327     @property
328     def target(self):
329         res = self._target
330         if res is None:
331             res = self._target = math.shift_left(self.bits & 0x00ffffff, 8 * ((self.bits >> 24) - 3))
332         return res
333     
334     def __hash__(self):
335         return hash(self.bits)
336     
337     def __eq__(self, other):
338         return self.bits == other.bits
339     
340     def __ne__(self, other):
341         return not (self == other)
342     
343     def __cmp__(self, other):
344         assert False
345     
346     def __repr__(self):
347         return 'FloatingInteger(bits=%s, target=%s)' % (hex(self.bits), hex(self.target))
348
349 class FloatingIntegerType(Type):
350     _inner = IntType(32)
351     
352     def read(self, file):
353         bits, file = self._inner.read(file)
354         return FloatingInteger(bits), file
355     
356     def write(self, file, item):
357         return self._inner.write(file, item.bits)
358
359 class PossiblyNoneType(Type):
360     def __init__(self, none_value, inner):
361         self.none_value = none_value
362         self.inner = inner
363     
364     def read(self, file):
365         value, file = self.inner.read(file)
366         return None if value == self.none_value else value, file
367     
368     def write(self, file, item):
369         if item == self.none_value:
370             raise ValueError('none_value used')
371         return self.inner.write(file, self.none_value if item is None else item)
372
373 address_type = ComposedType([
374     ('services', IntType(64)),
375     ('address', IPV6AddressType()),
376     ('port', IntType(16, 'big')),
377 ])
378
379 tx_type = ComposedType([
380     ('version', IntType(32)),
381     ('tx_ins', ListType(ComposedType([
382         ('previous_output', PossiblyNoneType(dict(hash=0, index=2**32 - 1), ComposedType([
383             ('hash', IntType(256)),
384             ('index', IntType(32)),
385         ]))),
386         ('script', VarStrType()),
387         ('sequence', PossiblyNoneType(2**32 - 1, IntType(32))),
388     ]))),
389     ('tx_outs', ListType(ComposedType([
390         ('value', IntType(64)),
391         ('script', VarStrType()),
392     ]))),
393     ('lock_time', IntType(32)),
394 ])
395
396 merkle_branch_type = ListType(IntType(256))
397
398 merkle_tx_type = ComposedType([
399     ('tx', tx_type),
400     ('block_hash', IntType(256)),
401     ('merkle_branch', merkle_branch_type),
402     ('index', IntType(32)),
403 ])
404
405 block_header_type = ComposedType([
406     ('version', IntType(32)),
407     ('previous_block', PossiblyNoneType(0, IntType(256))),
408     ('merkle_root', IntType(256)),
409     ('timestamp', IntType(32)),
410     ('bits', FloatingIntegerType()),
411     ('nonce', IntType(32)),
412 ])
413
414 block_type = ComposedType([
415     ('header', block_header_type),
416     ('txs', ListType(tx_type)),
417 ])
418
419 aux_pow_type = ComposedType([
420     ('merkle_tx', merkle_tx_type),
421     ('merkle_branch', merkle_branch_type),
422     ('index', IntType(32)),
423     ('parent_block_header', block_header_type),
424 ])
425
426
427 merkle_record_type = ComposedType([
428     ('left', IntType(256)),
429     ('right', IntType(256)),
430 ])
431
432 def merkle_hash(hashes):
433     if not hashes:
434         return 0
435     hash_list = list(hashes)
436     while len(hash_list) > 1:
437         hash_list = [merkle_record_type.hash256(dict(left=left, right=left if right is None else right))
438             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
439     return hash_list[0]
440
441 def calculate_merkle_branch(hashes, index):
442     # XXX optimize this
443     
444     hash_list = [(h, i == index, []) for i, h in enumerate(hashes)]
445     
446     while len(hash_list) > 1:
447         hash_list = [
448             (
449                 merkle_record_type.hash256(dict(left=left, right=right)),
450                 left_f or right_f,
451                 (left_l if left_f else right_l) + [dict(side=1, hash=right) if left_f else dict(side=0, hash=left)],
452             )
453             for (left, left_f, left_l), (right, right_f, right_l) in
454                 zip(hash_list[::2], hash_list[1::2] + [hash_list[::2][-1]])
455         ]
456     
457     res = [x['hash'] for x in hash_list[0][2]]
458     
459     assert hash_list[0][1]
460     assert check_merkle_branch(hashes[index], index, res) == hash_list[0][0]
461     assert index == sum(k*2**i for i, k in enumerate([1-x['side'] for x in hash_list[0][2]]))
462     
463     return res
464
465 def check_merkle_branch(tip_hash, index, merkle_branch):
466     return reduce(lambda c, (i, h): merkle_record_type.hash256(
467         dict(left=h, right=c) if 2**i & index else
468         dict(left=c, right=h)
469     ), enumerate(merkle_branch), tip_hash)
470
471 def target_to_average_attempts(target):
472     return 2**256//(target + 1)
473
474 def target_to_difficulty(target):
475     return (0xffff0000 * 2**(256-64) + 1)/(target + 1)
476
477 # tx
478
479 def tx_get_sigop_count(tx):
480     return sum(script.get_sigop_count(txin['script']) for txin in tx['tx_ins']) + sum(script.get_sigop_count(txout['script']) for txout in tx['tx_outs'])
481
482 # human addresses
483
484 human_address_type = ChecksummedType(ComposedType([
485     ('version', IntType(8)),
486     ('pubkey_hash', IntType(160)),
487 ]))
488
489 pubkey_type = PassthruType()
490
491 def pubkey_hash_to_address(pubkey_hash, net):
492     return human_address_type.pack_base58(dict(version=net.ADDRESS_VERSION, pubkey_hash=pubkey_hash))
493
494 def pubkey_to_address(pubkey, net):
495     return pubkey_hash_to_address(pubkey_type.hash160(pubkey), net)
496
497 def address_to_pubkey_hash(address, net):
498     x = human_address_type.unpack_base58(address)
499     if x['version'] != net.ADDRESS_VERSION:
500         raise ValueError('address not for this net!')
501     return x['pubkey_hash']
502
503 # transactions
504
505 def pubkey_to_script2(pubkey):
506     return ('\x41' + pubkey_type.pack(pubkey)) + '\xac'
507
508 def pubkey_hash_to_script2(pubkey_hash):
509     return '\x76\xa9' + ('\x14' + IntType(160).pack(pubkey_hash)) + '\x88\xac'
510
511 def script2_to_address(script2, net):
512     try:
513         pubkey = script2[1:-1]
514         script2_test = pubkey_to_script2(pubkey)
515     except:
516         pass
517     else:
518         if script2_test == script2:
519             return pubkey_to_address(pubkey, net)
520     
521     try:
522         pubkey_hash = IntType(160).unpack(script2[3:-2])
523         script2_test2 = pubkey_hash_to_script2(pubkey_hash)
524     except:
525         pass
526     else:
527         if script2_test2 == script2:
528             return pubkey_hash_to_address(pubkey_hash, net)
529
530 def script2_to_human(script2, net):
531     try:
532         pubkey = script2[1:-1]
533         script2_test = pubkey_to_script2(pubkey)
534     except:
535         pass
536     else:
537         if script2_test == script2:
538             return 'Pubkey. Address: %s' % (pubkey_to_address(pubkey, net),)
539     
540     try:
541         pubkey_hash = IntType(160).unpack(script2[3:-2])
542         script2_test2 = pubkey_hash_to_script2(pubkey_hash)
543     except:
544         pass
545     else:
546         if script2_test2 == script2:
547             return 'Address. Address: %s' % (pubkey_hash_to_address(pubkey_hash, net),)
548     
549     return 'Unknown. Script: %s'  % (script2.encode('hex'),)