removed outdated comment in util.pack
[p2pool.git] / p2pool / util / pack.py
1 import binascii
2 import struct
3
4 import p2pool
5
6 class EarlyEnd(Exception):
7     pass
8
9 class LateEnd(Exception):
10     pass
11
12 def read((data, pos), length):
13     data2 = data[pos:pos + length]
14     if len(data2) != length:
15         raise EarlyEnd()
16     return data2, (data, pos + length)
17
18 def size((data, pos)):
19     return len(data) - pos
20
21 class Type(object):
22     __slots__ = []
23     
24     def __hash__(self):
25         rval = getattr(self, '_hash', None)
26         if rval is None:
27             try:
28                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
29             except:
30                 print self.__dict__
31                 raise
32         return rval
33     
34     def __eq__(self, other):
35         return type(other) is type(self) and other.__dict__ == self.__dict__
36     
37     def __ne__(self, other):
38         return not (self == other)
39     
40     def _unpack(self, data):
41         obj, (data2, pos) = self.read((data, 0))
42         
43         assert data2 is data
44         
45         if pos != len(data):
46             raise LateEnd()
47         
48         return obj
49     
50     def _pack(self, obj):
51         f = self.write(None, obj)
52         
53         res = []
54         while f is not None:
55             res.append(f[1])
56             f = f[0]
57         res.reverse()
58         return ''.join(res)
59     
60     
61     def unpack(self, data):
62         obj = self._unpack(data)
63         
64         if p2pool.DEBUG:
65             data2 = self._pack(obj)
66             if data2 != data:
67                 if self._unpack(data2) != obj:
68                     raise AssertionError()
69         
70         return obj
71     
72     def pack(self, obj):
73         data = self._pack(obj)
74         
75         if p2pool.DEBUG:
76             if self._unpack(data) != obj:
77                 raise AssertionError((self._unpack(data), obj))
78         
79         return data
80
81 class VarIntType(Type):
82     def read(self, file):
83         data, file = read(file, 1)
84         first = ord(data)
85         if first < 0xfd:
86             return first, file
87         if first == 0xfd:
88             desc, length, minimum = '<H', 2, 0xfd
89         elif first == 0xfe:
90             desc, length, minimum = '<I', 4, 2**16
91         elif first == 0xff:
92             desc, length, minimum = '<Q', 8, 2**32
93         else:
94             raise AssertionError()
95         data2, file = read(file, length)
96         res, = struct.unpack(desc, data2)
97         if res < minimum:
98             raise AssertionError('VarInt not canonically packed')
99         return res, file
100     
101     def write(self, file, item):
102         if item < 0xfd:
103             return file, struct.pack('<B', item)
104         elif item <= 0xffff:
105             return file, struct.pack('<BH', 0xfd, item)
106         elif item <= 0xffffffff:
107             return file, struct.pack('<BI', 0xfe, item)
108         elif item <= 0xffffffffffffffff:
109             return file, struct.pack('<BQ', 0xff, item)
110         else:
111             raise ValueError('int too large for varint')
112
113 class VarStrType(Type):
114     _inner_size = VarIntType()
115     
116     def read(self, file):
117         length, file = self._inner_size.read(file)
118         return read(file, length)
119     
120     def write(self, file, item):
121         return self._inner_size.write(file, len(item)), item
122
123 class EnumType(Type):
124     def __init__(self, inner, values):
125         self.inner = inner
126         self.values = values
127         
128         keys = {}
129         for k, v in values.iteritems():
130             if v in keys:
131                 raise ValueError('duplicate value in values')
132             keys[v] = k
133         self.keys = keys
134     
135     def read(self, file):
136         data, file = self.inner.read(file)
137         if data not in self.keys:
138             raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
139         return self.keys[data], file
140     
141     def write(self, file, item):
142         if item not in self.values:
143             raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
144         return self.inner.write(file, self.values[item])
145
146 class ListType(Type):
147     _inner_size = VarIntType()
148     
149     def __init__(self, type):
150         self.type = type
151     
152     def read(self, file):
153         length, file = self._inner_size.read(file)
154         res = []
155         for i in xrange(length):
156             item, file = self.type.read(file)
157             res.append(item)
158         return res, file
159     
160     def write(self, file, item):
161         file = self._inner_size.write(file, len(item))
162         for subitem in item:
163             file = self.type.write(file, subitem)
164         return file
165
166 class StructType(Type):
167     __slots__ = 'desc length'.split(' ')
168     
169     def __init__(self, desc):
170         self.desc = desc
171         self.length = struct.calcsize(self.desc)
172     
173     def read(self, file):
174         data, file = read(file, self.length)
175         return struct.unpack(self.desc, data)[0], file
176     
177     def write(self, file, item):
178         return file, struct.pack(self.desc, item)
179
180 class IntType(Type):
181     __slots__ = 'bytes step format_str max'.split(' ')
182     
183     def __new__(cls, bits, endianness='little'):
184         assert bits % 8 == 0
185         assert endianness in ['little', 'big']
186         if bits in [8, 16, 32, 64]:
187             return StructType(('<' if endianness == 'little' else '>') + {8: 'B', 16: 'H', 32: 'I', 64: 'Q'}[bits])
188         else:
189             return Type.__new__(cls, bits, endianness)
190     
191     def __init__(self, bits, endianness='little'):
192         assert bits % 8 == 0
193         assert endianness in ['little', 'big']
194         self.bytes = bits//8
195         self.step = -1 if endianness == 'little' else 1
196         self.format_str = '%%0%ix' % (2*self.bytes)
197         self.max = 2**bits
198     
199     def read(self, file, b2a_hex=binascii.b2a_hex):
200         data, file = read(file, self.bytes)
201         return int(b2a_hex(data[::self.step]), 16), file
202     
203     def write(self, file, item, a2b_hex=binascii.a2b_hex):
204         if not 0 <= item < self.max:
205             raise ValueError('invalid int value - %r' % (item,))
206         return file, a2b_hex(self.format_str % (item,))[::self.step]
207
208 class IPV6AddressType(Type):
209     def read(self, file):
210         data, file = read(file, 16)
211         if data[:12] != '00000000000000000000ffff'.decode('hex'):
212             raise ValueError('ipv6 addresses not supported yet')
213         return '.'.join(str(ord(x)) for x in data[12:]), file
214     
215     def write(self, file, item):
216         bits = map(int, item.split('.'))
217         if len(bits) != 4:
218             raise ValueError('invalid address: %r' % (bits,))
219         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
220         assert len(data) == 16, len(data)
221         return file, data
222
223 _record_types = {}
224
225 def get_record(fields):
226     fields = tuple(sorted(fields))
227     if 'keys' in fields:
228         raise ValueError()
229     if fields not in _record_types:
230         class _Record(object):
231             __slots__ = fields
232             def __repr__(self):
233                 return repr(dict(self))
234             def __getitem__(self, key):
235                 return getattr(self, key)
236             def __setitem__(self, key, value):
237                 setattr(self, key, value)
238             #def __iter__(self):
239             #    for field in self.__slots__:
240             #        yield field, getattr(self, field)
241             def keys(self):
242                 return self.__slots__
243             def __eq__(self, other):
244                 if isinstance(other, dict):
245                     return dict(self) == other
246                 elif isinstance(other, _Record):
247                     return all(self[k] == other[k] for k in self.keys())
248                 raise TypeError()
249             def __ne__(self, other):
250                 return not (self == other)
251         _record_types[fields] = _Record
252     return _record_types[fields]()
253
254 class ComposedType(Type):
255     def __init__(self, fields):
256         self.fields = tuple(fields)
257     
258     def read(self, file):
259         item = get_record(k for k, v in self.fields)
260         for key, type_ in self.fields:
261             item[key], file = type_.read(file)
262         return item, file
263     
264     def write(self, file, item):
265         for key, type_ in self.fields:
266             file = type_.write(file, item[key])
267         return file
268
269 class PossiblyNoneType(Type):
270     def __init__(self, none_value, inner):
271         self.none_value = none_value
272         self.inner = inner
273     
274     def read(self, file):
275         value, file = self.inner.read(file)
276         return None if value == self.none_value else value, file
277     
278     def write(self, file, item):
279         if item == self.none_value:
280             raise ValueError('none_value used')
281         return self.inner.write(file, self.none_value if item is None else item)