to reduce memory usage, keep tx pointers in interleaved list instead of each in its...
[p2pool.git] / p2pool / util / pack.py
index c2d0ca1..c768258 100644 (file)
@@ -2,6 +2,7 @@ import binascii
 import struct
 
 import p2pool
+from p2pool.util import memoize
 
 class EarlyEnd(Exception):
     pass
@@ -63,7 +64,7 @@ class Type(object):
         
         if p2pool.DEBUG:
             if self._pack(obj) != data:
-                    raise AssertionError()
+                raise AssertionError()
         
         return obj
     
@@ -75,6 +76,19 @@ class Type(object):
                 raise AssertionError((self._unpack(data), obj))
         
         return data
+    
+    def packed_size(self, obj):
+        if hasattr(obj, '_packed_size') and obj._packed_size is not None:
+            type_obj, packed_size = obj._packed_size
+            if type_obj is self:
+                return packed_size
+        
+        packed_size = len(self.pack(obj))
+        
+        if hasattr(obj, '_packed_size'):
+            obj._packed_size = self, packed_size
+        
+        return packed_size
 
 class VarIntType(Type):
     def read(self, file):
@@ -119,44 +133,45 @@ class VarStrType(Type):
         return self._inner_size.write(file, len(item)), item
 
 class EnumType(Type):
-    def __init__(self, inner, values):
+    def __init__(self, inner, pack_to_unpack):
         self.inner = inner
-        self.values = values
+        self.pack_to_unpack = pack_to_unpack
         
-        keys = {}
-        for k, v in values.iteritems():
-            if v in keys:
-                raise ValueError('duplicate value in values')
-            keys[v] = k
-        self.keys = keys
+        self.unpack_to_pack = {}
+        for k, v in pack_to_unpack.iteritems():
+            if v in self.unpack_to_pack:
+                raise ValueError('duplicate value in pack_to_unpack')
+            self.unpack_to_pack[v] = k
     
     def read(self, file):
         data, file = self.inner.read(file)
-        if data not in self.keys:
-            raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
-        return self.keys[data], file
+        if data not in self.pack_to_unpack:
+            raise ValueError('enum data (%r) not in pack_to_unpack (%r)' % (data, self.pack_to_unpack))
+        return self.pack_to_unpack[data], file
     
     def write(self, file, item):
-        if item not in self.values:
-            raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
-        return self.inner.write(file, self.values[item])
+        if item not in self.unpack_to_pack:
+            raise ValueError('enum item (%r) not in unpack_to_pack (%r)' % (item, self.unpack_to_pack))
+        return self.inner.write(file, self.unpack_to_pack[item])
 
 class ListType(Type):
     _inner_size = VarIntType()
     
-    def __init__(self, type):
+    def __init__(self, type, mul=1):
         self.type = type
+        self.mul = mul
     
     def read(self, file):
         length, file = self._inner_size.read(file)
-        res = []
+        length *= self.mul
+        res = [None]*length
         for i in xrange(length):
-            item, file = self.type.read(file)
-            res.append(item)
+            res[i], file = self.type.read(file)
         return res, file
     
     def write(self, file, item):
-        file = self._inner_size.write(file, len(item))
+        assert len(item) % self.mul == 0
+        file = self._inner_size.write(file, len(item)//self.mul)
         for subitem in item:
             file = self.type.write(file, subitem)
         return file
@@ -175,6 +190,7 @@ class StructType(Type):
     def write(self, file, item):
         return file, struct.pack(self.desc, item)
 
+@memoize.fast_memoize_multiple_args
 class IntType(Type):
     __slots__ = 'bytes step format_str max'.split(' ')
     
@@ -195,10 +211,14 @@ class IntType(Type):
         self.max = 2**bits
     
     def read(self, file, b2a_hex=binascii.b2a_hex):
+        if self.bytes == 0:
+            return 0, file
         data, file = read(file, self.bytes)
         return int(b2a_hex(data[::self.step]), 16), file
     
     def write(self, file, item, a2b_hex=binascii.a2b_hex):
+        if self.bytes == 0:
+            return file
         if not 0 <= item < self.max:
             raise ValueError('invalid int value - %r' % (item,))
         return file, a2b_hex(self.format_str % (item,))[::self.step]
@@ -206,15 +226,18 @@ class IntType(Type):
 class IPV6AddressType(Type):
     def read(self, file):
         data, file = read(file, 16)
-        if data[:12] != '00000000000000000000ffff'.decode('hex'):
-            raise ValueError('ipv6 addresses not supported yet')
-        return '.'.join(str(ord(x)) for x in data[12:]), file
+        if data[:12] == '00000000000000000000ffff'.decode('hex'):
+            return '.'.join(str(ord(x)) for x in data[12:]), file
+        return ':'.join(data[i*2:(i+1)*2].encode('hex') for i in xrange(8)), file
     
     def write(self, file, item):
-        bits = map(int, item.split('.'))
-        if len(bits) != 4:
-            raise ValueError('invalid address: %r' % (bits,))
-        data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
+        if ':' in item:
+            data = ''.join(item.replace(':', '')).decode('hex')
+        else:
+            bits = map(int, item.split('.'))
+            if len(bits) != 4:
+                raise ValueError('invalid address: %r' % (bits,))
+            data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
         assert len(data) == 16, len(data)
         return file, data
 
@@ -222,11 +245,13 @@ _record_types = {}
 
 def get_record(fields):
     fields = tuple(sorted(fields))
-    if 'keys' in fields:
+    if 'keys' in fields or '_packed_size' in fields:
         raise ValueError()
     if fields not in _record_types:
         class _Record(object):
-            __slots__ = fields
+            __slots__ = fields + ('_packed_size',)
+            def __init__(self):
+                self._packed_size = None
             def __repr__(self):
                 return repr(dict(self))
             def __getitem__(self, key):
@@ -234,36 +259,42 @@ def get_record(fields):
             def __setitem__(self, key, value):
                 setattr(self, key, value)
             #def __iter__(self):
-            #    for field in self.__slots__:
+            #    for field in fields:
             #        yield field, getattr(self, field)
             def keys(self):
-                return self.__slots__
+                return fields
             def get(self, key, default=None):
                 return getattr(self, key, default)
             def __eq__(self, other):
                 if isinstance(other, dict):
                     return dict(self) == other
                 elif isinstance(other, _Record):
-                    return all(self[k] == other[k] for k in self.keys())
+                    for k in fields:
+                        if getattr(self, k) != getattr(other, k):
+                            return False
+                    return True
+                elif other is None:
+                    return False
                 raise TypeError()
             def __ne__(self, other):
                 return not (self == other)
         _record_types[fields] = _Record
-    return _record_types[fields]()
+    return _record_types[fields]
 
 class ComposedType(Type):
     def __init__(self, fields):
-        self.fields = tuple(fields)
+        self.fields = list(fields)
         self.field_names = set(k for k, v in fields)
+        self.record_type = get_record(k for k, v in self.fields)
     
     def read(self, file):
-        item = get_record(k for k, v in self.fields)
+        item = self.record_type()
         for key, type_ in self.fields:
             item[key], file = type_.read(file)
         return item, file
     
     def write(self, file, item):
-        assert set(item.keys()) == self.field_names
+        assert set(item.keys()) == self.field_names, (set(item.keys()) - self.field_names, self.field_names - set(item.keys()))
         for key, type_ in self.fields:
             file = type_.write(file, item[key])
         return file