import struct
import p2pool
+from p2pool.util import memoize
class EarlyEnd(Exception):
pass
if p2pool.DEBUG:
if self._pack(obj) != data:
- raise AssertionError()
+ raise AssertionError()
return obj
raise AssertionError((self._unpack(data), obj))
return data
+
+ def packed_size(self, obj):
+ if hasattr(obj, '_packed_size') and obj._packed_size is not None:
+ type_obj, packed_size = obj._packed_size
+ if type_obj is self:
+ return packed_size
+
+ packed_size = len(self.pack(obj))
+
+ if hasattr(obj, '_packed_size'):
+ obj._packed_size = self, packed_size
+
+ return packed_size
class VarIntType(Type):
def read(self, file):
return self._inner_size.write(file, len(item)), item
class EnumType(Type):
- def __init__(self, inner, values):
+ def __init__(self, inner, pack_to_unpack):
self.inner = inner
- self.values = values
+ self.pack_to_unpack = pack_to_unpack
- keys = {}
- for k, v in values.iteritems():
- if v in keys:
- raise ValueError('duplicate value in values')
- keys[v] = k
- self.keys = keys
+ self.unpack_to_pack = {}
+ for k, v in pack_to_unpack.iteritems():
+ if v in self.unpack_to_pack:
+ raise ValueError('duplicate value in pack_to_unpack')
+ self.unpack_to_pack[v] = k
def read(self, file):
data, file = self.inner.read(file)
- if data not in self.keys:
- raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
- return self.keys[data], file
+ if data not in self.pack_to_unpack:
+ raise ValueError('enum data (%r) not in pack_to_unpack (%r)' % (data, self.pack_to_unpack))
+ return self.pack_to_unpack[data], file
def write(self, file, item):
- if item not in self.values:
- raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
- return self.inner.write(file, self.values[item])
+ if item not in self.unpack_to_pack:
+ raise ValueError('enum item (%r) not in unpack_to_pack (%r)' % (item, self.unpack_to_pack))
+ return self.inner.write(file, self.unpack_to_pack[item])
class ListType(Type):
_inner_size = VarIntType()
def read(self, file):
length, file = self._inner_size.read(file)
- res = []
+ res = [None]*length
for i in xrange(length):
- item, file = self.type.read(file)
- res.append(item)
+ res[i], file = self.type.read(file)
return res, file
def write(self, file, item):
def write(self, file, item):
return file, struct.pack(self.desc, item)
+@memoize.fast_memoize_multiple_args
class IntType(Type):
__slots__ = 'bytes step format_str max'.split(' ')
self.max = 2**bits
def read(self, file, b2a_hex=binascii.b2a_hex):
+ if self.bytes == 0:
+ return 0, file
data, file = read(file, self.bytes)
return int(b2a_hex(data[::self.step]), 16), file
def write(self, file, item, a2b_hex=binascii.a2b_hex):
+ if self.bytes == 0:
+ return file
if not 0 <= item < self.max:
raise ValueError('invalid int value - %r' % (item,))
return file, a2b_hex(self.format_str % (item,))[::self.step]
class IPV6AddressType(Type):
def read(self, file):
data, file = read(file, 16)
- if data[:12] != '00000000000000000000ffff'.decode('hex'):
- raise ValueError('ipv6 addresses not supported yet')
- return '.'.join(str(ord(x)) for x in data[12:]), file
+ if data[:12] == '00000000000000000000ffff'.decode('hex'):
+ return '.'.join(str(ord(x)) for x in data[12:]), file
+ return ':'.join(data[i*2:(i+1)*2].encode('hex') for i in xrange(8)), file
def write(self, file, item):
- bits = map(int, item.split('.'))
- if len(bits) != 4:
- raise ValueError('invalid address: %r' % (bits,))
- data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
+ if ':' in item:
+ data = ''.join(item.replace(':', '')).decode('hex')
+ else:
+ bits = map(int, item.split('.'))
+ if len(bits) != 4:
+ raise ValueError('invalid address: %r' % (bits,))
+ data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
assert len(data) == 16, len(data)
return file, data
def get_record(fields):
fields = tuple(sorted(fields))
- if 'keys' in fields:
+ if 'keys' in fields or '_packed_size' in fields:
raise ValueError()
if fields not in _record_types:
class _Record(object):
- __slots__ = fields
+ __slots__ = fields + ('_packed_size',)
+ def __init__(self):
+ self._packed_size = None
def __repr__(self):
return repr(dict(self))
def __getitem__(self, key):
def __setitem__(self, key, value):
setattr(self, key, value)
#def __iter__(self):
- # for field in self.__slots__:
+ # for field in fields:
# yield field, getattr(self, field)
def keys(self):
- return self.__slots__
+ return fields
def get(self, key, default=None):
return getattr(self, key, default)
def __eq__(self, other):
if isinstance(other, dict):
return dict(self) == other
elif isinstance(other, _Record):
- return all(self[k] == other[k] for k in self.keys())
+ for k in fields:
+ if getattr(self, k) != getattr(other, k):
+ return False
+ return True
+ elif other is None:
+ return False
raise TypeError()
def __ne__(self, other):
return not (self == other)
_record_types[fields] = _Record
- return _record_types[fields]()
+ return _record_types[fields]
class ComposedType(Type):
def __init__(self, fields):
- self.fields = tuple(fields)
+ self.fields = list(fields)
self.field_names = set(k for k, v in fields)
+ self.record_type = get_record(k for k, v in self.fields)
def read(self, file):
- item = get_record(k for k, v in self.fields)
+ item = self.record_type()
for key, type_ in self.fields:
item[key], file = type_.read(file)
return item, file
def write(self, file, item):
- assert set(item.keys()) == self.field_names
+ assert set(item.keys()) == self.field_names, (set(item.keys()) - self.field_names, self.field_names - set(item.keys()))
for key, type_ in self.fields:
file = type_.write(file, item[key])
return file