working ... sorry for bad log messages\!
[p2pool.git] / p2pool / bitcoin / data.py
index 59a7f8d..5d302af 100644 (file)
@@ -3,9 +3,10 @@ from __future__ import division
 import struct
 import StringIO
 import hashlib
+import warnings
 
 from . import base58
-from p2pool.util import bases
+from p2pool.util import bases, expiring_dict, math
 
 class EarlyEnd(Exception):
     pass
@@ -16,6 +17,14 @@ class LateEnd(Exception):
 class Type(object):
     # the same data can have only one unpacked representation, but multiple packed binary representations
     
+    #def __hash__(self):
+    #    return hash(tuple(self.__dict__.items()))
+    
+    #def __eq__(self, other):
+    #    if not isinstance(other, Type):
+    #        raise NotImplementedError()
+    #    return self.__dict__ == other.__dict__
+    
     def _unpack(self, data):
         f = StringIO.StringIO(data)
         
@@ -28,7 +37,11 @@ class Type(object):
     
     def unpack(self, data):
         obj = self._unpack(data)
-        assert self._unpack(self._pack(obj)) == obj
+        
+        data2 = self._pack(obj)
+        if data2 != data:
+            assert self._unpack(data2) == obj
+        
         return obj
     
     def _pack(self, obj):
@@ -43,6 +56,7 @@ class Type(object):
     def pack(self, obj):
         data = self._pack(obj)
         assert self._unpack(data) == obj
+                
         return data
     
     
@@ -54,10 +68,10 @@ class Type(object):
         
     
     def hash160(self, obj):
-        return ripemdsha(self.pack(obj))
+        return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
     
     def hash256(self, obj):
-        return doublesha(self.pack(obj))
+        return HashType().unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
 
 class VarIntType(Type):
     def read(self, file):
@@ -141,6 +155,8 @@ class HashType(Type):
     def write(self, file, item):
         if item >= 2**256:
             raise ValueError("invalid hash value")
+        if item != 0 and item < 2**160:
+            warnings.warn("very low hash value - maybe you meant to use ShortHashType?")
         file.write(('%064x' % (item,)).decode('hex')[::-1])
 
 class ShortHashType(Type):
@@ -194,13 +210,9 @@ class IPV6AddressType(Type):
             raise EarlyEnd()
         if data[:12] != '00000000000000000000ffff'.decode('hex'):
             raise ValueError("ipv6 addresses not supported yet")
-        return '::ffff:' + '.'.join(str(ord(x)) for x in data[12:])
+        return '.'.join(str(ord(x)) for x in data[12:])
     
     def write(self, file, item):
-        prefix = '::ffff:'
-        if not item.startswith(prefix):
-            raise ValueError("ipv6 addresses not supported yet")
-        item = item[len(prefix):]
         bits = map(int, item.split('.'))
         if len(bits) != 4:
             raise ValueError("invalid address: %r" % (bits,))
@@ -240,6 +252,64 @@ class ChecksummedType(Type):
         file.write(data)
         file.write(hashlib.sha256(hashlib.sha256(data).digest()).digest()[:4])
 
+class FloatingIntegerType(Type):
+    def read(self, file):
+        data = FixedStrType(4).read(file)
+        target = self._bits_to_target(data)
+        if self._target_to_bits(target) != data:
+            raise ValueError("bits in non-canonical form")
+        return target
+    
+    def write(self, file, item):
+        FixedStrType(4).write(file, self._target_to_bits(item))
+    
+    def truncate_to(self, x):
+        return self._bits_to_target(self._target_to_bits(x, _check=False))
+        
+    def _bits_to_target(self, bits, _check=True):
+        assert len(bits) == 4, repr(bits)
+        target1 = self._bits_to_target1(bits)
+        target2 = self._bits_to_target2(bits)
+        if target1 != target2:
+            raise ValueError()
+        if _check:
+            if self._target_to_bits(target1, _check=False) != bits:
+                raise ValueError()
+        return target1
+    
+    def _bits_to_target1(self, bits):
+        bits = bits[::-1]
+        length = ord(bits[0])
+        return bases.string_to_natural((bits[1:] + "\0"*length)[:length])
+
+    def _bits_to_target2(self, bits):
+        bits = struct.unpack("<I", bits)[0]
+        return math.shift_left(bits & 0x00ffffff, 8 * ((bits >> 24) - 3))
+
+    def _target_to_bits(self, target, _check=True):
+        n = bases.natural_to_string(target)
+        if n and ord(n[0]) >= 128:
+            n = "\x00" + n
+        bits = (chr(len(n)) + (n + 3*chr(0))[:3])[::-1]
+        if _check:
+            if self._bits_to_target(bits, _check=False) != target:
+                raise ValueError(repr((target, self._bits_to_target(bits, _check=False))))
+        return bits
+
+class PossiblyNone(Type):
+    def __init__(self, none_value, inner):
+        self.none_value = none_value
+        self.inner = inner
+    
+    def read(self, file):
+        value = self.inner.read(file)
+        return None if value == self.none_value else value
+    
+    def write(self, file, item):
+        if item == self.none_value:
+            raise ValueError("none_value used")
+        self.inner.write(file, self.none_value if item is None else item)
+
 address_type = ComposedType([
     ('services', StructType('<Q')),
     ('address', IPV6AddressType()),
@@ -249,12 +319,12 @@ address_type = ComposedType([
 tx_type = ComposedType([
     ('version', StructType('<I')),
     ('tx_ins', ListType(ComposedType([
-        ('previous_output', ComposedType([
+        ('previous_output', PossiblyNone(dict(hash=0, index=2**32 - 1), ComposedType([
             ('hash', HashType()),
             ('index', StructType('<I')),
-        ])),
+        ]))),
         ('script', VarStrType()),
-        ('sequence', StructType('<I')),
+        ('sequence', PossiblyNone(2**32 - 1, StructType('<I'))),
     ]))),
     ('tx_outs', ListType(ComposedType([
         ('value', StructType('<Q')),
@@ -265,10 +335,10 @@ tx_type = ComposedType([
 
 block_header_type = ComposedType([
     ('version', StructType('<I')),
-    ('previous_block', HashType()),
+    ('previous_block', PossiblyNone(0, HashType())),
     ('merkle_root', HashType()),
     ('timestamp', StructType('<I')),
-    ('bits', FixedStrType(4)),
+    ('target', FloatingIntegerType()),
     ('nonce', StructType('<I')),
 ])
 
@@ -277,11 +347,6 @@ block_type = ComposedType([
     ('txs', ListType(tx_type)),
 ])
 
-def doublesha(data):
-    return HashType().unpack(hashlib.sha256(hashlib.sha256(data).digest()).digest())
-
-def ripemdsha(data):
-    return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(data).digest()).digest())
 
 merkle_record_type = ComposedType([
     ('left', HashType()),
@@ -289,50 +354,14 @@ merkle_record_type = ComposedType([
 ])
 
 def merkle_hash(tx_list):
-    hash_list = map(tx_hash, tx_list)
+    if not tx_list:
+        return 0
+    hash_list = map(tx_type.hash256, tx_list)
     while len(hash_list) > 1:
-        hash_list = [doublesha(merkle_record_type.pack(dict(left=left, right=left if right is None else right)))
+        hash_list = [merkle_record_type.hash256(dict(left=left, right=left if right is None else right))
             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
     return hash_list[0]
 
-def tx_hash(tx):
-    return doublesha(tx_type.pack(tx))
-
-def block_hash(header):
-    return doublesha(block_header_type.pack(header))
-
-def shift_left(n, m):
-    # python: :(
-    if m < 0:
-        return n >> -m
-    return n << m
-
-def bits_to_target(bits):
-    bits = bits[::-1]
-    length = ord(bits[0])
-    return bases.string_to_natural((bits[1:] + "\0"*length)[:length])
-
-def old_bits_to_target(bits):
-    return shift_left(bits & 0x00ffffff, 8 * ((bits >> 24) - 3))
-
-def about_equal(a, b):
-    if a == b: return True
-    return abs(a-b)/((abs(a)+abs(b))/2) < .01
-
-def compress_target_to_bits(target): # loses precision
-    print
-    print "t", target
-    n = bases.natural_to_string(target)
-    print "n", n.encode('hex')
-    bits = chr(len(n)) + n[:3].ljust(3, '\0')
-    bits = bits[::-1]
-    print "bits", bits.encode('hex')
-    print "new", bits_to_target(bits)
-    print "old", old_bits_to_target(struct.unpack("<I", bits)[0])
-    assert about_equal(bits_to_target(bits), target), (bits_to_target(bits), target)
-    assert about_equal(old_bits_to_target(struct.unpack("<I", bits)[0]), target), (old_bits_to_target(struct.unpack("<I", bits)[0]), target)
-    return bits
-
 def target_to_average_attempts(target):
     return 2**256//(target + 1)
 
@@ -364,4 +393,3 @@ class Testnet(object):
     BITCOIN_P2P_PREFIX = 'fabfb5da'.decode('hex')
     BITCOIN_P2P_PORT = 18333
     BITCOIN_ADDRESS_VERSION = 111
-