import binascii
-import hashlib
import struct
-from p2pool.bitcoin import base58
import p2pool
+from p2pool.util import memoize
class EarlyEnd(Exception):
pass
class Type(object):
__slots__ = []
- # the same data can have only one unpacked representation, but multiple packed binary representations
-
def __hash__(self):
rval = getattr(self, '_hash', None)
if rval is None:
def __ne__(self, other):
return not (self == other)
- def _unpack(self, data):
+ def _unpack(self, data, ignore_trailing=False):
obj, (data2, pos) = self.read((data, 0))
assert data2 is data
- if pos != len(data):
+ if pos != len(data) and not ignore_trailing:
raise LateEnd()
return obj
return ''.join(res)
- def unpack(self, data):
- obj = self._unpack(data)
+ def unpack(self, data, ignore_trailing=False):
+ obj = self._unpack(data, ignore_trailing)
if p2pool.DEBUG:
- data2 = self._pack(obj)
- if data2 != data:
- if self._unpack(data2) != obj:
- raise AssertionError()
+ packed = self._pack(obj)
+ good = data.startswith(packed) if ignore_trailing else data == packed
+ if not good:
+ raise AssertionError()
return obj
return data
-
- def pack_base58(self, obj):
- return base58.encode(self.pack(obj))
-
- def unpack_base58(self, base58_data):
- return self.unpack(base58.decode(base58_data))
-
-
- def hash160(self, obj):
- return IntType(160).unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
-
- def hash256(self, obj):
- return IntType(256).unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
-
- def scrypt(self, obj):
- import ltc_scrypt
- return IntType(256).unpack(ltc_scrypt.getPoWHash(self.pack(obj)))
+ def packed_size(self, obj):
+ if hasattr(obj, '_packed_size') and obj._packed_size is not None:
+ type_obj, packed_size = obj._packed_size
+ if type_obj is self:
+ return packed_size
+
+ packed_size = len(self.pack(obj))
+
+ if hasattr(obj, '_packed_size'):
+ obj._packed_size = self, packed_size
+
+ return packed_size
class VarIntType(Type):
- # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
def read(self, file):
data, file = read(file, 1)
first = ord(data)
if first < 0xfd:
return first, file
- elif first == 0xfd:
- desc, length = '<H', 2
+ if first == 0xfd:
+ desc, length, minimum = '<H', 2, 0xfd
elif first == 0xfe:
- desc, length = '<I', 4
+ desc, length, minimum = '<I', 4, 2**16
elif first == 0xff:
- desc, length = '<Q', 8
+ desc, length, minimum = '<Q', 8, 2**32
else:
raise AssertionError()
- data, file = read(file, length)
- return struct.unpack(desc, data)[0], file
+ data2, file = read(file, length)
+ res, = struct.unpack(desc, data2)
+ if res < minimum:
+ raise AssertionError('VarInt not canonically packed')
+ return res, file
def write(self, file, item):
if item < 0xfd:
- file = file, struct.pack('<B', item)
+ return file, struct.pack('<B', item)
elif item <= 0xffff:
- file = file, struct.pack('<BH', 0xfd, item)
+ return file, struct.pack('<BH', 0xfd, item)
elif item <= 0xffffffff:
- file = file, struct.pack('<BI', 0xfe, item)
+ return file, struct.pack('<BI', 0xfe, item)
elif item <= 0xffffffffffffffff:
- file = file, struct.pack('<BQ', 0xff, item)
+ return file, struct.pack('<BQ', 0xff, item)
else:
raise ValueError('int too large for varint')
- return file
class VarStrType(Type):
_inner_size = VarIntType()
def write(self, file, item):
return self._inner_size.write(file, len(item)), item
-class PassthruType(Type):
- def read(self, file):
- return read(file, size(file))
-
- def write(self, file, item):
- return file, item
-
class EnumType(Type):
- def __init__(self, inner, values):
+ def __init__(self, inner, pack_to_unpack):
self.inner = inner
- self.values = values
+ self.pack_to_unpack = pack_to_unpack
- keys = {}
- for k, v in values.iteritems():
- if v in keys:
- raise ValueError('duplicate value in values')
- keys[v] = k
- self.keys = keys
+ self.unpack_to_pack = {}
+ for k, v in pack_to_unpack.iteritems():
+ if v in self.unpack_to_pack:
+ raise ValueError('duplicate value in pack_to_unpack')
+ self.unpack_to_pack[v] = k
def read(self, file):
data, file = self.inner.read(file)
- if data not in self.keys:
- raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
- return self.keys[data], file
+ if data not in self.pack_to_unpack:
+ raise ValueError('enum data (%r) not in pack_to_unpack (%r)' % (data, self.pack_to_unpack))
+ return self.pack_to_unpack[data], file
def write(self, file, item):
- if item not in self.values:
- raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
- return self.inner.write(file, self.values[item])
+ if item not in self.unpack_to_pack:
+ raise ValueError('enum item (%r) not in unpack_to_pack (%r)' % (item, self.unpack_to_pack))
+ return self.inner.write(file, self.unpack_to_pack[item])
class ListType(Type):
_inner_size = VarIntType()
- def __init__(self, type):
+ def __init__(self, type, mul=1):
self.type = type
+ self.mul = mul
def read(self, file):
length, file = self._inner_size.read(file)
- res = []
+ length *= self.mul
+ res = [None]*length
for i in xrange(length):
- item, file = self.type.read(file)
- res.append(item)
+ res[i], file = self.type.read(file)
return res, file
def write(self, file, item):
- file = self._inner_size.write(file, len(item))
+ assert len(item) % self.mul == 0
+ file = self._inner_size.write(file, len(item)//self.mul)
for subitem in item:
file = self.type.write(file, subitem)
return file
def write(self, file, item):
return file, struct.pack(self.desc, item)
+@memoize.fast_memoize_multiple_args
class IntType(Type):
__slots__ = 'bytes step format_str max'.split(' ')
self.max = 2**bits
def read(self, file, b2a_hex=binascii.b2a_hex):
+ if self.bytes == 0:
+ return 0, file
data, file = read(file, self.bytes)
return int(b2a_hex(data[::self.step]), 16), file
def write(self, file, item, a2b_hex=binascii.a2b_hex):
+ if self.bytes == 0:
+ return file
if not 0 <= item < self.max:
raise ValueError('invalid int value - %r' % (item,))
return file, a2b_hex(self.format_str % (item,))[::self.step]
class IPV6AddressType(Type):
def read(self, file):
data, file = read(file, 16)
- if data[:12] != '00000000000000000000ffff'.decode('hex'):
- raise ValueError('ipv6 addresses not supported yet')
- return '.'.join(str(ord(x)) for x in data[12:]), file
+ if data[:12] == '00000000000000000000ffff'.decode('hex'):
+ return '.'.join(str(ord(x)) for x in data[12:]), file
+ return ':'.join(data[i*2:(i+1)*2].encode('hex') for i in xrange(8)), file
def write(self, file, item):
- bits = map(int, item.split('.'))
- if len(bits) != 4:
- raise ValueError('invalid address: %r' % (bits,))
- data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
+ if ':' in item:
+ data = ''.join(item.replace(':', '')).decode('hex')
+ else:
+ bits = map(int, item.split('.'))
+ if len(bits) != 4:
+ raise ValueError('invalid address: %r' % (bits,))
+ data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
assert len(data) == 16, len(data)
return file, data
def get_record(fields):
fields = tuple(sorted(fields))
- if 'keys' in fields:
+ if 'keys' in fields or '_packed_size' in fields:
raise ValueError()
if fields not in _record_types:
class _Record(object):
- __slots__ = fields
+ __slots__ = fields + ('_packed_size',)
+ def __init__(self):
+ self._packed_size = None
def __repr__(self):
return repr(dict(self))
def __getitem__(self, key):
def __setitem__(self, key, value):
setattr(self, key, value)
#def __iter__(self):
- # for field in self.__slots__:
+ # for field in fields:
# yield field, getattr(self, field)
def keys(self):
- return self.__slots__
+ return fields
+ def get(self, key, default=None):
+ return getattr(self, key, default)
def __eq__(self, other):
if isinstance(other, dict):
return dict(self) == other
elif isinstance(other, _Record):
- return all(self[k] == other[k] for k in self.keys())
+ for k in fields:
+ if getattr(self, k) != getattr(other, k):
+ return False
+ return True
+ elif other is None:
+ return False
raise TypeError()
def __ne__(self, other):
return not (self == other)
_record_types[fields] = _Record
- return _record_types[fields]()
+ return _record_types[fields]
class ComposedType(Type):
def __init__(self, fields):
- self.fields = tuple(fields)
+ self.fields = list(fields)
+ self.field_names = set(k for k, v in fields)
+ self.record_type = get_record(k for k, v in self.fields)
def read(self, file):
- item = get_record(k for k, v in self.fields)
+ item = self.record_type()
for key, type_ in self.fields:
item[key], file = type_.read(file)
return item, file
def write(self, file, item):
+ assert set(item.keys()) == self.field_names, (set(item.keys()) - self.field_names, self.field_names - set(item.keys()))
for key, type_ in self.fields:
file = type_.write(file, item[key])
return file
if item == self.none_value:
raise ValueError('none_value used')
return self.inner.write(file, self.none_value if item is None else item)
+
+class FixedStrType(Type):
+ def __init__(self, length):
+ self.length = length
+
+ def read(self, file):
+ return read(file, self.length)
+
+ def write(self, file, item):
+ if len(item) != self.length:
+ raise ValueError('incorrect length item!')
+ return file, item