to reduce memory usage, keep tx pointers in interleaved list instead of each in its...
[p2pool.git] / p2pool / util / pack.py
1 import binascii
2 import struct
3
4 import p2pool
5 from p2pool.util import memoize
6
7 class EarlyEnd(Exception):
8     pass
9
10 class LateEnd(Exception):
11     pass
12
13 def read((data, pos), length):
14     data2 = data[pos:pos + length]
15     if len(data2) != length:
16         raise EarlyEnd()
17     return data2, (data, pos + length)
18
19 def size((data, pos)):
20     return len(data) - pos
21
22 class Type(object):
23     __slots__ = []
24     
25     def __hash__(self):
26         rval = getattr(self, '_hash', None)
27         if rval is None:
28             try:
29                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
30             except:
31                 print self.__dict__
32                 raise
33         return rval
34     
35     def __eq__(self, other):
36         return type(other) is type(self) and other.__dict__ == self.__dict__
37     
38     def __ne__(self, other):
39         return not (self == other)
40     
41     def _unpack(self, data):
42         obj, (data2, pos) = self.read((data, 0))
43         
44         assert data2 is data
45         
46         if pos != len(data):
47             raise LateEnd()
48         
49         return obj
50     
51     def _pack(self, obj):
52         f = self.write(None, obj)
53         
54         res = []
55         while f is not None:
56             res.append(f[1])
57             f = f[0]
58         res.reverse()
59         return ''.join(res)
60     
61     
62     def unpack(self, data):
63         obj = self._unpack(data)
64         
65         if p2pool.DEBUG:
66             if self._pack(obj) != data:
67                 raise AssertionError()
68         
69         return obj
70     
71     def pack(self, obj):
72         data = self._pack(obj)
73         
74         if p2pool.DEBUG:
75             if self._unpack(data) != obj:
76                 raise AssertionError((self._unpack(data), obj))
77         
78         return data
79     
80     def packed_size(self, obj):
81         if hasattr(obj, '_packed_size') and obj._packed_size is not None:
82             type_obj, packed_size = obj._packed_size
83             if type_obj is self:
84                 return packed_size
85         
86         packed_size = len(self.pack(obj))
87         
88         if hasattr(obj, '_packed_size'):
89             obj._packed_size = self, packed_size
90         
91         return packed_size
92
93 class VarIntType(Type):
94     def read(self, file):
95         data, file = read(file, 1)
96         first = ord(data)
97         if first < 0xfd:
98             return first, file
99         if first == 0xfd:
100             desc, length, minimum = '<H', 2, 0xfd
101         elif first == 0xfe:
102             desc, length, minimum = '<I', 4, 2**16
103         elif first == 0xff:
104             desc, length, minimum = '<Q', 8, 2**32
105         else:
106             raise AssertionError()
107         data2, file = read(file, length)
108         res, = struct.unpack(desc, data2)
109         if res < minimum:
110             raise AssertionError('VarInt not canonically packed')
111         return res, file
112     
113     def write(self, file, item):
114         if item < 0xfd:
115             return file, struct.pack('<B', item)
116         elif item <= 0xffff:
117             return file, struct.pack('<BH', 0xfd, item)
118         elif item <= 0xffffffff:
119             return file, struct.pack('<BI', 0xfe, item)
120         elif item <= 0xffffffffffffffff:
121             return file, struct.pack('<BQ', 0xff, item)
122         else:
123             raise ValueError('int too large for varint')
124
125 class VarStrType(Type):
126     _inner_size = VarIntType()
127     
128     def read(self, file):
129         length, file = self._inner_size.read(file)
130         return read(file, length)
131     
132     def write(self, file, item):
133         return self._inner_size.write(file, len(item)), item
134
135 class EnumType(Type):
136     def __init__(self, inner, pack_to_unpack):
137         self.inner = inner
138         self.pack_to_unpack = pack_to_unpack
139         
140         self.unpack_to_pack = {}
141         for k, v in pack_to_unpack.iteritems():
142             if v in self.unpack_to_pack:
143                 raise ValueError('duplicate value in pack_to_unpack')
144             self.unpack_to_pack[v] = k
145     
146     def read(self, file):
147         data, file = self.inner.read(file)
148         if data not in self.pack_to_unpack:
149             raise ValueError('enum data (%r) not in pack_to_unpack (%r)' % (data, self.pack_to_unpack))
150         return self.pack_to_unpack[data], file
151     
152     def write(self, file, item):
153         if item not in self.unpack_to_pack:
154             raise ValueError('enum item (%r) not in unpack_to_pack (%r)' % (item, self.unpack_to_pack))
155         return self.inner.write(file, self.unpack_to_pack[item])
156
157 class ListType(Type):
158     _inner_size = VarIntType()
159     
160     def __init__(self, type, mul=1):
161         self.type = type
162         self.mul = mul
163     
164     def read(self, file):
165         length, file = self._inner_size.read(file)
166         length *= self.mul
167         res = [None]*length
168         for i in xrange(length):
169             res[i], file = self.type.read(file)
170         return res, file
171     
172     def write(self, file, item):
173         assert len(item) % self.mul == 0
174         file = self._inner_size.write(file, len(item)//self.mul)
175         for subitem in item:
176             file = self.type.write(file, subitem)
177         return file
178
179 class StructType(Type):
180     __slots__ = 'desc length'.split(' ')
181     
182     def __init__(self, desc):
183         self.desc = desc
184         self.length = struct.calcsize(self.desc)
185     
186     def read(self, file):
187         data, file = read(file, self.length)
188         return struct.unpack(self.desc, data)[0], file
189     
190     def write(self, file, item):
191         return file, struct.pack(self.desc, item)
192
193 @memoize.fast_memoize_multiple_args
194 class IntType(Type):
195     __slots__ = 'bytes step format_str max'.split(' ')
196     
197     def __new__(cls, bits, endianness='little'):
198         assert bits % 8 == 0
199         assert endianness in ['little', 'big']
200         if bits in [8, 16, 32, 64]:
201             return StructType(('<' if endianness == 'little' else '>') + {8: 'B', 16: 'H', 32: 'I', 64: 'Q'}[bits])
202         else:
203             return Type.__new__(cls, bits, endianness)
204     
205     def __init__(self, bits, endianness='little'):
206         assert bits % 8 == 0
207         assert endianness in ['little', 'big']
208         self.bytes = bits//8
209         self.step = -1 if endianness == 'little' else 1
210         self.format_str = '%%0%ix' % (2*self.bytes)
211         self.max = 2**bits
212     
213     def read(self, file, b2a_hex=binascii.b2a_hex):
214         if self.bytes == 0:
215             return 0, file
216         data, file = read(file, self.bytes)
217         return int(b2a_hex(data[::self.step]), 16), file
218     
219     def write(self, file, item, a2b_hex=binascii.a2b_hex):
220         if self.bytes == 0:
221             return file
222         if not 0 <= item < self.max:
223             raise ValueError('invalid int value - %r' % (item,))
224         return file, a2b_hex(self.format_str % (item,))[::self.step]
225
226 class IPV6AddressType(Type):
227     def read(self, file):
228         data, file = read(file, 16)
229         if data[:12] == '00000000000000000000ffff'.decode('hex'):
230             return '.'.join(str(ord(x)) for x in data[12:]), file
231         return ':'.join(data[i*2:(i+1)*2].encode('hex') for i in xrange(8)), file
232     
233     def write(self, file, item):
234         if ':' in item:
235             data = ''.join(item.replace(':', '')).decode('hex')
236         else:
237             bits = map(int, item.split('.'))
238             if len(bits) != 4:
239                 raise ValueError('invalid address: %r' % (bits,))
240             data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
241         assert len(data) == 16, len(data)
242         return file, data
243
244 _record_types = {}
245
246 def get_record(fields):
247     fields = tuple(sorted(fields))
248     if 'keys' in fields or '_packed_size' in fields:
249         raise ValueError()
250     if fields not in _record_types:
251         class _Record(object):
252             __slots__ = fields + ('_packed_size',)
253             def __init__(self):
254                 self._packed_size = None
255             def __repr__(self):
256                 return repr(dict(self))
257             def __getitem__(self, key):
258                 return getattr(self, key)
259             def __setitem__(self, key, value):
260                 setattr(self, key, value)
261             #def __iter__(self):
262             #    for field in fields:
263             #        yield field, getattr(self, field)
264             def keys(self):
265                 return fields
266             def get(self, key, default=None):
267                 return getattr(self, key, default)
268             def __eq__(self, other):
269                 if isinstance(other, dict):
270                     return dict(self) == other
271                 elif isinstance(other, _Record):
272                     for k in fields:
273                         if getattr(self, k) != getattr(other, k):
274                             return False
275                     return True
276                 elif other is None:
277                     return False
278                 raise TypeError()
279             def __ne__(self, other):
280                 return not (self == other)
281         _record_types[fields] = _Record
282     return _record_types[fields]
283
284 class ComposedType(Type):
285     def __init__(self, fields):
286         self.fields = list(fields)
287         self.field_names = set(k for k, v in fields)
288         self.record_type = get_record(k for k, v in self.fields)
289     
290     def read(self, file):
291         item = self.record_type()
292         for key, type_ in self.fields:
293             item[key], file = type_.read(file)
294         return item, file
295     
296     def write(self, file, item):
297         assert set(item.keys()) == self.field_names, (set(item.keys()) - self.field_names, self.field_names - set(item.keys()))
298         for key, type_ in self.fields:
299             file = type_.write(file, item[key])
300         return file
301
302 class PossiblyNoneType(Type):
303     def __init__(self, none_value, inner):
304         self.none_value = none_value
305         self.inner = inner
306     
307     def read(self, file):
308         value, file = self.inner.read(file)
309         return None if value == self.none_value else value, file
310     
311     def write(self, file, item):
312         if item == self.none_value:
313             raise ValueError('none_value used')
314         return self.inner.write(file, self.none_value if item is None else item)
315
316 class FixedStrType(Type):
317     def __init__(self, length):
318         self.length = length
319     
320     def read(self, file):
321         return read(file, self.length)
322     
323     def write(self, file, item):
324         if len(item) != self.length:
325             raise ValueError('incorrect length item!')
326         return file, item