style and import fixes
[p2pool.git] / p2pool / bitcoin / data.py
1 from __future__ import division
2
3 import hashlib
4 import struct
5
6 from . import base58
7 from p2pool.util import bases, math, expiring_dict, memoize, dicts
8 import p2pool
9
10 class EarlyEnd(Exception):
11     pass
12
13 class LateEnd(Exception):
14     pass
15
16 def read((data, pos), length):
17     data2 = data[pos:pos + length]
18     if len(data2) != length:
19         raise EarlyEnd()
20     return data2, (data, pos + length)
21
22 def size((data, pos)):
23     return len(data) - pos
24
25 class Type(object):
26     # the same data can have only one unpacked representation, but multiple packed binary representations
27     
28     def __hash__(self):
29         rval = getattr(self, '_hash', None)
30         if rval is None:
31             try:
32                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
33             except:
34                 print self.__dict__
35                 raise
36         return rval
37     
38     def __eq__(self, other):
39         return type(other) is type(self) and other.__dict__ == self.__dict__
40     
41     def __ne__(self, other):
42         return not (self == other)
43     
44     def _unpack(self, data):
45         obj, (data2, pos) = self.read((data, 0))
46         
47         assert data2 is data
48         
49         if pos != len(data):
50             raise LateEnd()
51         
52         return obj
53     
54     def _pack(self, obj):
55         f = self.write(None, obj)
56         
57         res = []
58         while f is not None:
59             res.append(f[1])
60             f = f[0]
61         res.reverse()
62         return ''.join(res)
63     
64     
65     def unpack(self, data):
66         obj = self._unpack(data)
67         
68         if p2pool.DEBUG:
69             data2 = self._pack(obj)
70             if data2 != data:
71                 if self._unpack(data2) != obj:
72                     raise AssertionError()
73         
74         return obj
75     
76     def pack2(self, obj):
77         data = self._pack(obj)
78         
79         if p2pool.DEBUG:
80             if self._unpack(data) != obj:
81                 raise AssertionError((self._unpack(data), obj))
82         
83         return data
84     
85     _backing = expiring_dict.ExpiringDict(100)
86     pack2 = memoize.memoize_with_backing(_backing, [unpack])(pack2)
87     unpack = memoize.memoize_with_backing(_backing)(unpack) # doesn't have an inverse
88     
89     def pack(self, obj):
90         return self.pack2(dicts.immutify(obj))
91     
92     
93     def pack_base58(self, obj):
94         return base58.base58_encode(self.pack(obj))
95     
96     def unpack_base58(self, base58_data):
97         return self.unpack(base58.base58_decode(base58_data))
98     
99     
100     def hash160(self, obj):
101         return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
102     
103     def hash256(self, obj):
104         return HashType().unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
105     
106     def scrypt(self, obj):
107         import ltc_scrypt
108         return HashType().unpack(ltc_scrypt.getPoWHash(self.pack(obj)))
109
110 class VarIntType(Type):
111     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
112     def read(self, file):
113         data, file = read(file, 1)
114         first = ord(data)
115         if first < 0xfd:
116             return first, file
117         elif first == 0xfd:
118             desc, length = '<H', 2
119         elif first == 0xfe:
120             desc, length = '<I', 4
121         elif first == 0xff:
122             desc, length = '<Q', 8
123         else:
124             raise AssertionError()
125         data, file = read(file, length)
126         return struct.unpack(desc, data)[0], file
127     
128     def write(self, file, item):
129         if item < 0xfd:
130             file = file, struct.pack('<B', item)
131         elif item <= 0xffff:
132             file = file, struct.pack('<BH', 0xfd, item)
133         elif item <= 0xffffffff:
134             file = file, struct.pack('<BI', 0xfe, item)
135         elif item <= 0xffffffffffffffff:
136             file = file, struct.pack('<BQ', 0xff, item)
137         else:
138             raise ValueError('int too large for varint')
139         return file
140
141 class VarStrType(Type):
142     _inner_size = VarIntType()
143     
144     def read(self, file):
145         length, file = self._inner_size.read(file)
146         return read(file, length)
147     
148     def write(self, file, item):
149         return self._inner_size.write(file, len(item)), item
150
151 class FixedStrType(Type):
152     def __init__(self, length):
153         self.length = length
154     
155     def read(self, file):
156         return read(file, self.length)
157     
158     def write(self, file, item):
159         if len(item) != self.length:
160             raise ValueError('incorrect length item!')
161         return file, item
162
163 class EnumType(Type):
164     def __init__(self, inner, values):
165         self.inner = inner
166         self.values = dicts.frozendict(values)
167         
168         keys = {}
169         for k, v in values.iteritems():
170             if v in keys:
171                 raise ValueError('duplicate value in values')
172             keys[v] = k
173         self.keys = dicts.frozendict(keys)
174     
175     def read(self, file):
176         data, file = self.inner.read(file)
177         if data not in self.keys:
178             raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
179         return self.keys[data], file
180     
181     def write(self, file, item):
182         if item not in self.values:
183             raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
184         return self.inner.write(file, self.values[item])
185
186 class HashType(Type):
187     def read(self, file):
188         data, file = read(file, 256//8)
189         return int(data[::-1].encode('hex'), 16), file
190     
191     def write(self, file, item):
192         if not 0 <= item < 2**256:
193             raise ValueError('invalid hash value - %r' % (item,))
194         if item != 0 and item < 2**160:
195             print 'Very low hash value - maybe you meant to use ShortHashType? %x' % (item,)
196         return file, ('%064x' % (item,)).decode('hex')[::-1]
197
198 class ShortHashType(Type):
199     def read(self, file):
200         data, file = read(file, 160//8)
201         return int(data[::-1].encode('hex'), 16), file
202     
203     def write(self, file, item):
204         if not 0 <= item < 2**160:
205             raise ValueError('invalid hash value - %r' % (item,))
206         return file, ('%040x' % (item,)).decode('hex')[::-1]
207
208 class ListType(Type):
209     _inner_size = VarIntType()
210     
211     def __init__(self, type):
212         self.type = type
213     
214     def read(self, file):
215         length, file = self._inner_size.read(file)
216         res = []
217         for i in xrange(length):
218             item, file = self.type.read(file)
219             res.append(item)
220         return res, file
221     
222     def write(self, file, item):
223         file = self._inner_size.write(file, len(item))
224         for subitem in item:
225             file = self.type.write(file, subitem)
226         return file
227
228 class StructType(Type):
229     def __init__(self, desc):
230         self.desc = desc
231         self.length = struct.calcsize(self.desc)
232     
233     def read(self, file):
234         data, file = read(file, self.length)
235         res, = struct.unpack(self.desc, data)
236         return res, file
237     
238     def write(self, file, item):
239         data = struct.pack(self.desc, item)
240         if struct.unpack(self.desc, data)[0] != item:
241             # special test because struct doesn't error on some overflows
242             raise ValueError('''item didn't survive pack cycle (%r)''' % (item,))
243         return file, data
244
245 class IPV6AddressType(Type):
246     def read(self, file):
247         data, file = read(file, 16)
248         if data[:12] != '00000000000000000000ffff'.decode('hex'):
249             raise ValueError('ipv6 addresses not supported yet')
250         return '.'.join(str(ord(x)) for x in data[12:]), file
251     
252     def write(self, file, item):
253         bits = map(int, item.split('.'))
254         if len(bits) != 4:
255             raise ValueError('invalid address: %r' % (bits,))
256         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
257         assert len(data) == 16, len(data)
258         return file, data
259
260 _record_types = {}
261
262 def get_record(fields):
263     fields = tuple(sorted(fields))
264     if 'keys' in fields:
265         raise ValueError()
266     if fields not in _record_types:
267         class _Record(object):
268             __slots__ = fields
269             def __repr__(self):
270                 return repr(dict(self))
271             def __getitem__(self, key):
272                 return getattr(self, key)
273             def __setitem__(self, key, value):
274                 setattr(self, key, value)
275             #def __iter__(self):
276             #    for field in self.__slots__:
277             #        yield field, getattr(self, field)
278             def keys(self):
279                 return self.__slots__
280             def __eq__(self, other):
281                 if isinstance(other, dict):
282                     return dict(self) == other
283                 elif isinstance(other, _Record):
284                     return all(self[k] == other[k] for k in self.keys())
285                 raise TypeError()
286             def __ne__(self, other):
287                 return not (self == other)
288         _record_types[fields] = _Record
289     return _record_types[fields]()
290
291 class ComposedType(Type):
292     def __init__(self, fields):
293         self.fields = tuple(fields)
294     
295     def read(self, file):
296         item = get_record(k for k, v in self.fields)
297         for key, type_ in self.fields:
298             item[key], file = type_.read(file)
299         return item, file
300     
301     def write(self, file, item):
302         for key, type_ in self.fields:
303             file = type_.write(file, item[key])
304         return file
305
306 class ChecksummedType(Type):
307     def __init__(self, inner):
308         self.inner = inner
309     
310     def read(self, file):
311         obj, file = self.inner.read(file)
312         data = self.inner.pack(obj)
313         
314         checksum, file = read(file, 4)
315         if checksum != hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]:
316             raise ValueError('invalid checksum')
317         
318         return obj, file
319     
320     def write(self, file, item):
321         data = self.inner.pack(item)
322         return (file, data), hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4]
323
324 class FloatingInteger(object):
325     __slots__ = ['_bits']
326     
327     @classmethod
328     def from_target_upper_bound(cls, target):
329         n = bases.natural_to_string(target)
330         if n and ord(n[0]) >= 128:
331             n = '\x00' + n
332         bits2 = (chr(len(n)) + (n + 3*chr(0))[:3])[::-1]
333         bits = struct.unpack('<I', bits2)[0]
334         return cls(bits)
335     
336     def __init__(self, bits):
337         self._bits = bits
338     
339     @property
340     def _value(self):
341         return math.shift_left(self._bits & 0x00ffffff, 8 * ((self._bits >> 24) - 3))
342     
343     def __hash__(self):
344         return hash(self._value)
345     
346     def __cmp__(self, other):
347         if isinstance(other, FloatingInteger):
348             return cmp(self._value, other._value)
349         elif isinstance(other, (int, long)):
350             return cmp(self._value, other)
351         else:
352             raise NotImplementedError(other)
353     
354     def __int__(self):
355         return self._value
356     
357     def __repr__(self):
358         return 'FloatingInteger(bits=%s (%x))' % (hex(self._bits), self)
359     
360     def __add__(self, other):
361         if isinstance(other, (int, long)):
362             return self._value + other
363         raise NotImplementedError()
364     __radd__ = __add__
365     def __mul__(self, other):
366         if isinstance(other, (int, long)):
367             return self._value * other
368         raise NotImplementedError()
369     __rmul__ = __mul__
370     def __truediv__(self, other):
371         if isinstance(other, (int, long)):
372             return self._value / other
373         raise NotImplementedError()
374     def __floordiv__(self, other):
375         if isinstance(other, (int, long)):
376             return self._value // other
377         raise NotImplementedError()
378     __div__ = __truediv__
379     def __rtruediv__(self, other):
380         if isinstance(other, (int, long)):
381             return other / self._value
382         raise NotImplementedError()
383     def __rfloordiv__(self, other):
384         if isinstance(other, (int, long)):
385             return other // self._value
386         raise NotImplementedError()
387     __rdiv__ = __rtruediv__
388
389 class FloatingIntegerType(Type):
390     _inner = StructType('<I')
391     
392     def read(self, file):
393         bits, file = self._inner.read(file)
394         return FloatingInteger(bits), file
395     
396     def write(self, file, item):
397         return self._inner.write(file, item._bits)
398
399 class PossiblyNoneType(Type):
400     def __init__(self, none_value, inner):
401         self.none_value = none_value
402         self.inner = inner
403     
404     def read(self, file):
405         value, file = self.inner.read(file)
406         return None if value == self.none_value else value, file
407     
408     def write(self, file, item):
409         if item == self.none_value:
410             raise ValueError('none_value used')
411         return self.inner.write(file, self.none_value if item is None else item)
412
413 address_type = ComposedType([
414     ('services', StructType('<Q')),
415     ('address', IPV6AddressType()),
416     ('port', StructType('>H')),
417 ])
418
419 tx_type = ComposedType([
420     ('version', StructType('<I')),
421     ('tx_ins', ListType(ComposedType([
422         ('previous_output', PossiblyNoneType(dicts.frozendict(hash=0, index=2**32 - 1), ComposedType([
423             ('hash', HashType()),
424             ('index', StructType('<I')),
425         ]))),
426         ('script', VarStrType()),
427         ('sequence', PossiblyNoneType(2**32 - 1, StructType('<I'))),
428     ]))),
429     ('tx_outs', ListType(ComposedType([
430         ('value', StructType('<Q')),
431         ('script', VarStrType()),
432     ]))),
433     ('lock_time', StructType('<I')),
434 ])
435
436 merkle_branch_type = ListType(HashType())
437
438 merkle_tx_type = ComposedType([
439     ('tx', tx_type),
440     ('block_hash', HashType()),
441     ('merkle_branch', merkle_branch_type),
442     ('index', StructType('<i')),
443 ])
444
445 block_header_type = ComposedType([
446     ('version', StructType('<I')),
447     ('previous_block', PossiblyNoneType(0, HashType())),
448     ('merkle_root', HashType()),
449     ('timestamp', StructType('<I')),
450     ('target', FloatingIntegerType()),
451     ('nonce', StructType('<I')),
452 ])
453
454 block_type = ComposedType([
455     ('header', block_header_type),
456     ('txs', ListType(tx_type)),
457 ])
458
459 aux_pow_type = ComposedType([
460     ('merkle_tx', merkle_tx_type),
461     ('merkle_branch', merkle_branch_type),
462     ('index', StructType('<i')),
463     ('parent_block_header', block_header_type),
464 ])
465
466
467 merkle_record_type = ComposedType([
468     ('left', HashType()),
469     ('right', HashType()),
470 ])
471
472 def merkle_hash(tx_list):
473     if not tx_list:
474         return 0
475     hash_list = map(tx_type.hash256, tx_list)
476     while len(hash_list) > 1:
477         hash_list = [merkle_record_type.hash256(dict(left=left, right=left if right is None else right))
478             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
479     return hash_list[0]
480
481 def calculate_merkle_branch(txs, index):
482     # XXX optimize this
483     
484     hash_list = [(tx_type.hash256(tx), i == index, []) for i, tx in enumerate(txs)]
485     
486     while len(hash_list) > 1:
487         hash_list = [
488             (
489                 merkle_record_type.hash256(dict(left=left, right=right)),
490                 left_f or right_f,
491                 (left_l if left_f else right_l) + [dict(side=1, hash=right) if left_f else dict(side=0, hash=left)],
492             )
493             for (left, left_f, left_l), (right, right_f, right_l) in
494                 zip(hash_list[::2], hash_list[1::2] + [hash_list[::2][-1]])
495         ]
496     
497     res = [x['hash'] for x in hash_list[0][2]]
498     
499     assert hash_list[0][1]
500     assert check_merkle_branch(txs[index], index, res) == hash_list[0][0]
501     assert index == sum(k*2**i for i, k in enumerate([1-x['side'] for x in hash_list[0][2]]))
502     
503     return res
504
505 def check_merkle_branch(tx, index, merkle_branch):
506     return reduce(lambda c, (i, h): merkle_record_type.hash256(
507         dict(left=h, right=c) if 2**i & index else
508         dict(left=c, right=h)
509     ), enumerate(merkle_branch), tx_type.hash256(tx))
510
511 def target_to_average_attempts(target):
512     return 2**256//(target + 1)
513
514 def target_to_difficulty(target):
515     return (0xffff0000 * 2**(256-64) + 1)/(target + 1)
516
517 # tx
518
519 def tx_get_sigop_count(tx):
520     return sum(script.get_sigop_count(txin['script']) for txin in tx['tx_ins']) + sum(script.get_sigop_count(txout['script']) for txout in tx['tx_outs'])
521
522 # human addresses
523
524 human_address_type = ChecksummedType(ComposedType([
525     ('version', StructType('<B')),
526     ('pubkey_hash', ShortHashType()),
527 ]))
528
529 pubkey_type = FixedStrType(65)
530
531 def pubkey_hash_to_address(pubkey_hash, net):
532     return human_address_type.pack_base58(dict(version=net.BITCOIN_ADDRESS_VERSION, pubkey_hash=pubkey_hash))
533
534 def pubkey_to_address(pubkey, net):
535     return pubkey_hash_to_address(pubkey_type.hash160(pubkey), net)
536
537 def address_to_pubkey_hash(address, net):
538     x = human_address_type.unpack_base58(address)
539     if x['version'] != net.BITCOIN_ADDRESS_VERSION:
540         raise ValueError('address not for this net!')
541     return x['pubkey_hash']
542
543 # transactions
544
545 def pubkey_to_script2(pubkey):
546     return ('\x41' + pubkey_type.pack(pubkey)) + '\xac'
547
548 def pubkey_hash_to_script2(pubkey_hash):
549     return '\x76\xa9' + ('\x14' + ShortHashType().pack(pubkey_hash)) + '\x88\xac'
550
551 def script2_to_human(script2, net):
552     try:
553         pubkey = script2[1:-1]
554         script2_test = pubkey_to_script2(pubkey)
555     except:
556         pass
557     else:
558         if script2_test == script2:
559             return 'Pubkey. Address: %s' % (pubkey_to_address(pubkey, net),)
560     
561     try:
562         pubkey_hash = ShortHashType().unpack(script2[3:-2])
563         script2_test2 = pubkey_hash_to_script2(pubkey_hash)
564     except:
565         pass
566     else:
567         if script2_test2 == script2:
568             return 'Address. Address: %s' % (pubkey_hash_to_address(pubkey_hash, net),)
569     
570     return 'Unknown. Script: %s'  % (script2.encode('hex'),)