X-Git-Url: https://git.novaco.in/?a=blobdiff_plain;f=p2pool%2Futil%2Fpack.py;h=c7682580474da648424287c4f6b2221ce030b725;hb=c95830cd6486b848c9201668720072485b2a8aed;hp=c2d0ca12404b4826c8e97badec214799b9fd3939;hpb=4b08230de03af966d684912cb4764aeec4dc0cc6;p=p2pool.git diff --git a/p2pool/util/pack.py b/p2pool/util/pack.py index c2d0ca1..c768258 100644 --- a/p2pool/util/pack.py +++ b/p2pool/util/pack.py @@ -2,6 +2,7 @@ import binascii import struct import p2pool +from p2pool.util import memoize class EarlyEnd(Exception): pass @@ -63,7 +64,7 @@ class Type(object): if p2pool.DEBUG: if self._pack(obj) != data: - raise AssertionError() + raise AssertionError() return obj @@ -75,6 +76,19 @@ class Type(object): raise AssertionError((self._unpack(data), obj)) return data + + def packed_size(self, obj): + if hasattr(obj, '_packed_size') and obj._packed_size is not None: + type_obj, packed_size = obj._packed_size + if type_obj is self: + return packed_size + + packed_size = len(self.pack(obj)) + + if hasattr(obj, '_packed_size'): + obj._packed_size = self, packed_size + + return packed_size class VarIntType(Type): def read(self, file): @@ -119,44 +133,45 @@ class VarStrType(Type): return self._inner_size.write(file, len(item)), item class EnumType(Type): - def __init__(self, inner, values): + def __init__(self, inner, pack_to_unpack): self.inner = inner - self.values = values + self.pack_to_unpack = pack_to_unpack - keys = {} - for k, v in values.iteritems(): - if v in keys: - raise ValueError('duplicate value in values') - keys[v] = k - self.keys = keys + self.unpack_to_pack = {} + for k, v in pack_to_unpack.iteritems(): + if v in self.unpack_to_pack: + raise ValueError('duplicate value in pack_to_unpack') + self.unpack_to_pack[v] = k def read(self, file): data, file = self.inner.read(file) - if data not in self.keys: - raise ValueError('enum data (%r) not in values (%r)' % (data, self.values)) - return self.keys[data], file + if data not in self.pack_to_unpack: + raise ValueError('enum data (%r) not in pack_to_unpack (%r)' % (data, self.pack_to_unpack)) + return self.pack_to_unpack[data], file def write(self, file, item): - if item not in self.values: - raise ValueError('enum item (%r) not in values (%r)' % (item, self.values)) - return self.inner.write(file, self.values[item]) + if item not in self.unpack_to_pack: + raise ValueError('enum item (%r) not in unpack_to_pack (%r)' % (item, self.unpack_to_pack)) + return self.inner.write(file, self.unpack_to_pack[item]) class ListType(Type): _inner_size = VarIntType() - def __init__(self, type): + def __init__(self, type, mul=1): self.type = type + self.mul = mul def read(self, file): length, file = self._inner_size.read(file) - res = [] + length *= self.mul + res = [None]*length for i in xrange(length): - item, file = self.type.read(file) - res.append(item) + res[i], file = self.type.read(file) return res, file def write(self, file, item): - file = self._inner_size.write(file, len(item)) + assert len(item) % self.mul == 0 + file = self._inner_size.write(file, len(item)//self.mul) for subitem in item: file = self.type.write(file, subitem) return file @@ -175,6 +190,7 @@ class StructType(Type): def write(self, file, item): return file, struct.pack(self.desc, item) +@memoize.fast_memoize_multiple_args class IntType(Type): __slots__ = 'bytes step format_str max'.split(' ') @@ -195,10 +211,14 @@ class IntType(Type): self.max = 2**bits def read(self, file, b2a_hex=binascii.b2a_hex): + if self.bytes == 0: + return 0, file data, file = read(file, self.bytes) return int(b2a_hex(data[::self.step]), 16), file def write(self, file, item, a2b_hex=binascii.a2b_hex): + if self.bytes == 0: + return file if not 0 <= item < self.max: raise ValueError('invalid int value - %r' % (item,)) return file, a2b_hex(self.format_str % (item,))[::self.step] @@ -206,15 +226,18 @@ class IntType(Type): class IPV6AddressType(Type): def read(self, file): data, file = read(file, 16) - if data[:12] != '00000000000000000000ffff'.decode('hex'): - raise ValueError('ipv6 addresses not supported yet') - return '.'.join(str(ord(x)) for x in data[12:]), file + if data[:12] == '00000000000000000000ffff'.decode('hex'): + return '.'.join(str(ord(x)) for x in data[12:]), file + return ':'.join(data[i*2:(i+1)*2].encode('hex') for i in xrange(8)), file def write(self, file, item): - bits = map(int, item.split('.')) - if len(bits) != 4: - raise ValueError('invalid address: %r' % (bits,)) - data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits) + if ':' in item: + data = ''.join(item.replace(':', '')).decode('hex') + else: + bits = map(int, item.split('.')) + if len(bits) != 4: + raise ValueError('invalid address: %r' % (bits,)) + data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits) assert len(data) == 16, len(data) return file, data @@ -222,11 +245,13 @@ _record_types = {} def get_record(fields): fields = tuple(sorted(fields)) - if 'keys' in fields: + if 'keys' in fields or '_packed_size' in fields: raise ValueError() if fields not in _record_types: class _Record(object): - __slots__ = fields + __slots__ = fields + ('_packed_size',) + def __init__(self): + self._packed_size = None def __repr__(self): return repr(dict(self)) def __getitem__(self, key): @@ -234,36 +259,42 @@ def get_record(fields): def __setitem__(self, key, value): setattr(self, key, value) #def __iter__(self): - # for field in self.__slots__: + # for field in fields: # yield field, getattr(self, field) def keys(self): - return self.__slots__ + return fields def get(self, key, default=None): return getattr(self, key, default) def __eq__(self, other): if isinstance(other, dict): return dict(self) == other elif isinstance(other, _Record): - return all(self[k] == other[k] for k in self.keys()) + for k in fields: + if getattr(self, k) != getattr(other, k): + return False + return True + elif other is None: + return False raise TypeError() def __ne__(self, other): return not (self == other) _record_types[fields] = _Record - return _record_types[fields]() + return _record_types[fields] class ComposedType(Type): def __init__(self, fields): - self.fields = tuple(fields) + self.fields = list(fields) self.field_names = set(k for k, v in fields) + self.record_type = get_record(k for k, v in self.fields) def read(self, file): - item = get_record(k for k, v in self.fields) + item = self.record_type() for key, type_ in self.fields: item[key], file = type_.read(file) return item, file def write(self, file, item): - assert set(item.keys()) == self.field_names + assert set(item.keys()) == self.field_names, (set(item.keys()) - self.field_names, self.field_names - set(item.keys())) for key, type_ in self.fields: file = type_.write(file, item[key]) return file