fixed 5cde217 causing assertion failures in debug mode
[p2pool.git] / p2pool / util / pack.py
1 import binascii
2 import struct
3
4 import p2pool
5 from p2pool.util import memoize
6
7 class EarlyEnd(Exception):
8     pass
9
10 class LateEnd(Exception):
11     pass
12
13 def read((data, pos), length):
14     data2 = data[pos:pos + length]
15     if len(data2) != length:
16         raise EarlyEnd()
17     return data2, (data, pos + length)
18
19 def size((data, pos)):
20     return len(data) - pos
21
22 class Type(object):
23     __slots__ = []
24     
25     def __hash__(self):
26         rval = getattr(self, '_hash', None)
27         if rval is None:
28             try:
29                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
30             except:
31                 print self.__dict__
32                 raise
33         return rval
34     
35     def __eq__(self, other):
36         return type(other) is type(self) and other.__dict__ == self.__dict__
37     
38     def __ne__(self, other):
39         return not (self == other)
40     
41     def _unpack(self, data, ignore_trailing=False):
42         obj, (data2, pos) = self.read((data, 0))
43         
44         assert data2 is data
45         
46         if pos != len(data) and not ignore_trailing:
47             raise LateEnd()
48         
49         return obj
50     
51     def _pack(self, obj):
52         f = self.write(None, obj)
53         
54         res = []
55         while f is not None:
56             res.append(f[1])
57             f = f[0]
58         res.reverse()
59         return ''.join(res)
60     
61     
62     def unpack(self, data, ignore_trailing=False):
63         obj = self._unpack(data, ignore_trailing)
64         
65         if p2pool.DEBUG:
66             packed = self._pack(obj)
67             good = data.startswith(packed) if ignore_trailing else data == packed
68             if not good:
69                 raise AssertionError()
70         
71         return obj
72     
73     def pack(self, obj):
74         data = self._pack(obj)
75         
76         if p2pool.DEBUG:
77             if self._unpack(data) != obj:
78                 raise AssertionError((self._unpack(data), obj))
79         
80         return data
81     
82     def packed_size(self, obj):
83         if hasattr(obj, '_packed_size') and obj._packed_size is not None:
84             type_obj, packed_size = obj._packed_size
85             if type_obj is self:
86                 return packed_size
87         
88         packed_size = len(self.pack(obj))
89         
90         if hasattr(obj, '_packed_size'):
91             obj._packed_size = self, packed_size
92         
93         return packed_size
94
95 class VarIntType(Type):
96     def read(self, file):
97         data, file = read(file, 1)
98         first = ord(data)
99         if first < 0xfd:
100             return first, file
101         if first == 0xfd:
102             desc, length, minimum = '<H', 2, 0xfd
103         elif first == 0xfe:
104             desc, length, minimum = '<I', 4, 2**16
105         elif first == 0xff:
106             desc, length, minimum = '<Q', 8, 2**32
107         else:
108             raise AssertionError()
109         data2, file = read(file, length)
110         res, = struct.unpack(desc, data2)
111         if res < minimum:
112             raise AssertionError('VarInt not canonically packed')
113         return res, file
114     
115     def write(self, file, item):
116         if item < 0xfd:
117             return file, struct.pack('<B', item)
118         elif item <= 0xffff:
119             return file, struct.pack('<BH', 0xfd, item)
120         elif item <= 0xffffffff:
121             return file, struct.pack('<BI', 0xfe, item)
122         elif item <= 0xffffffffffffffff:
123             return file, struct.pack('<BQ', 0xff, item)
124         else:
125             raise ValueError('int too large for varint')
126
127 class VarStrType(Type):
128     _inner_size = VarIntType()
129     
130     def read(self, file):
131         length, file = self._inner_size.read(file)
132         return read(file, length)
133     
134     def write(self, file, item):
135         return self._inner_size.write(file, len(item)), item
136
137 class EnumType(Type):
138     def __init__(self, inner, pack_to_unpack):
139         self.inner = inner
140         self.pack_to_unpack = pack_to_unpack
141         
142         self.unpack_to_pack = {}
143         for k, v in pack_to_unpack.iteritems():
144             if v in self.unpack_to_pack:
145                 raise ValueError('duplicate value in pack_to_unpack')
146             self.unpack_to_pack[v] = k
147     
148     def read(self, file):
149         data, file = self.inner.read(file)
150         if data not in self.pack_to_unpack:
151             raise ValueError('enum data (%r) not in pack_to_unpack (%r)' % (data, self.pack_to_unpack))
152         return self.pack_to_unpack[data], file
153     
154     def write(self, file, item):
155         if item not in self.unpack_to_pack:
156             raise ValueError('enum item (%r) not in unpack_to_pack (%r)' % (item, self.unpack_to_pack))
157         return self.inner.write(file, self.unpack_to_pack[item])
158
159 class ListType(Type):
160     _inner_size = VarIntType()
161     
162     def __init__(self, type, mul=1):
163         self.type = type
164         self.mul = mul
165     
166     def read(self, file):
167         length, file = self._inner_size.read(file)
168         length *= self.mul
169         res = [None]*length
170         for i in xrange(length):
171             res[i], file = self.type.read(file)
172         return res, file
173     
174     def write(self, file, item):
175         assert len(item) % self.mul == 0
176         file = self._inner_size.write(file, len(item)//self.mul)
177         for subitem in item:
178             file = self.type.write(file, subitem)
179         return file
180
181 class StructType(Type):
182     __slots__ = 'desc length'.split(' ')
183     
184     def __init__(self, desc):
185         self.desc = desc
186         self.length = struct.calcsize(self.desc)
187     
188     def read(self, file):
189         data, file = read(file, self.length)
190         return struct.unpack(self.desc, data)[0], file
191     
192     def write(self, file, item):
193         return file, struct.pack(self.desc, item)
194
195 @memoize.fast_memoize_multiple_args
196 class IntType(Type):
197     __slots__ = 'bytes step format_str max'.split(' ')
198     
199     def __new__(cls, bits, endianness='little'):
200         assert bits % 8 == 0
201         assert endianness in ['little', 'big']
202         if bits in [8, 16, 32, 64]:
203             return StructType(('<' if endianness == 'little' else '>') + {8: 'B', 16: 'H', 32: 'I', 64: 'Q'}[bits])
204         else:
205             return Type.__new__(cls, bits, endianness)
206     
207     def __init__(self, bits, endianness='little'):
208         assert bits % 8 == 0
209         assert endianness in ['little', 'big']
210         self.bytes = bits//8
211         self.step = -1 if endianness == 'little' else 1
212         self.format_str = '%%0%ix' % (2*self.bytes)
213         self.max = 2**bits
214     
215     def read(self, file, b2a_hex=binascii.b2a_hex):
216         if self.bytes == 0:
217             return 0, file
218         data, file = read(file, self.bytes)
219         return int(b2a_hex(data[::self.step]), 16), file
220     
221     def write(self, file, item, a2b_hex=binascii.a2b_hex):
222         if self.bytes == 0:
223             return file
224         if not 0 <= item < self.max:
225             raise ValueError('invalid int value - %r' % (item,))
226         return file, a2b_hex(self.format_str % (item,))[::self.step]
227
228 class IPV6AddressType(Type):
229     def read(self, file):
230         data, file = read(file, 16)
231         if data[:12] == '00000000000000000000ffff'.decode('hex'):
232             return '.'.join(str(ord(x)) for x in data[12:]), file
233         return ':'.join(data[i*2:(i+1)*2].encode('hex') for i in xrange(8)), file
234     
235     def write(self, file, item):
236         if ':' in item:
237             data = ''.join(item.replace(':', '')).decode('hex')
238         else:
239             bits = map(int, item.split('.'))
240             if len(bits) != 4:
241                 raise ValueError('invalid address: %r' % (bits,))
242             data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
243         assert len(data) == 16, len(data)
244         return file, data
245
246 _record_types = {}
247
248 def get_record(fields):
249     fields = tuple(sorted(fields))
250     if 'keys' in fields or '_packed_size' in fields:
251         raise ValueError()
252     if fields not in _record_types:
253         class _Record(object):
254             __slots__ = fields + ('_packed_size',)
255             def __init__(self):
256                 self._packed_size = None
257             def __repr__(self):
258                 return repr(dict(self))
259             def __getitem__(self, key):
260                 return getattr(self, key)
261             def __setitem__(self, key, value):
262                 setattr(self, key, value)
263             #def __iter__(self):
264             #    for field in fields:
265             #        yield field, getattr(self, field)
266             def keys(self):
267                 return fields
268             def get(self, key, default=None):
269                 return getattr(self, key, default)
270             def __eq__(self, other):
271                 if isinstance(other, dict):
272                     return dict(self) == other
273                 elif isinstance(other, _Record):
274                     for k in fields:
275                         if getattr(self, k) != getattr(other, k):
276                             return False
277                     return True
278                 elif other is None:
279                     return False
280                 raise TypeError()
281             def __ne__(self, other):
282                 return not (self == other)
283         _record_types[fields] = _Record
284     return _record_types[fields]
285
286 class ComposedType(Type):
287     def __init__(self, fields):
288         self.fields = list(fields)
289         self.field_names = set(k for k, v in fields)
290         self.record_type = get_record(k for k, v in self.fields)
291     
292     def read(self, file):
293         item = self.record_type()
294         for key, type_ in self.fields:
295             item[key], file = type_.read(file)
296         return item, file
297     
298     def write(self, file, item):
299         assert set(item.keys()) == self.field_names, (set(item.keys()) - self.field_names, self.field_names - set(item.keys()))
300         for key, type_ in self.fields:
301             file = type_.write(file, item[key])
302         return file
303
304 class PossiblyNoneType(Type):
305     def __init__(self, none_value, inner):
306         self.none_value = none_value
307         self.inner = inner
308     
309     def read(self, file):
310         value, file = self.inner.read(file)
311         return None if value == self.none_value else value, file
312     
313     def write(self, file, item):
314         if item == self.none_value:
315             raise ValueError('none_value used')
316         return self.inner.write(file, self.none_value if item is None else item)
317
318 class FixedStrType(Type):
319     def __init__(self, length):
320         self.length = length
321     
322     def read(self, file):
323         return read(file, self.length)
324     
325     def write(self, file, item):
326         if len(item) != self.length:
327             raise ValueError('incorrect length item!')
328         return file, item