moved generic data types to util.pack
[p2pool.git] / p2pool / util / pack.py
1 import binascii
2 import hashlib
3 import struct
4
5 from p2pool.bitcoin import base58
6 import p2pool
7
8 class EarlyEnd(Exception):
9     pass
10
11 class LateEnd(Exception):
12     pass
13
14 def read((data, pos), length):
15     data2 = data[pos:pos + length]
16     if len(data2) != length:
17         raise EarlyEnd()
18     return data2, (data, pos + length)
19
20 def size((data, pos)):
21     return len(data) - pos
22
23 class Type(object):
24     __slots__ = []
25     
26     # the same data can have only one unpacked representation, but multiple packed binary representations
27     
28     def __hash__(self):
29         rval = getattr(self, '_hash', None)
30         if rval is None:
31             try:
32                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
33             except:
34                 print self.__dict__
35                 raise
36         return rval
37     
38     def __eq__(self, other):
39         return type(other) is type(self) and other.__dict__ == self.__dict__
40     
41     def __ne__(self, other):
42         return not (self == other)
43     
44     def _unpack(self, data):
45         obj, (data2, pos) = self.read((data, 0))
46         
47         assert data2 is data
48         
49         if pos != len(data):
50             raise LateEnd()
51         
52         return obj
53     
54     def _pack(self, obj):
55         f = self.write(None, obj)
56         
57         res = []
58         while f is not None:
59             res.append(f[1])
60             f = f[0]
61         res.reverse()
62         return ''.join(res)
63     
64     
65     def unpack(self, data):
66         obj = self._unpack(data)
67         
68         if p2pool.DEBUG:
69             data2 = self._pack(obj)
70             if data2 != data:
71                 if self._unpack(data2) != obj:
72                     raise AssertionError()
73         
74         return obj
75     
76     def pack(self, obj):
77         data = self._pack(obj)
78         
79         if p2pool.DEBUG:
80             if self._unpack(data) != obj:
81                 raise AssertionError((self._unpack(data), obj))
82         
83         return data
84     
85     
86     def pack_base58(self, obj):
87         return base58.encode(self.pack(obj))
88     
89     def unpack_base58(self, base58_data):
90         return self.unpack(base58.decode(base58_data))
91     
92     
93     def hash160(self, obj):
94         return IntType(160).unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
95     
96     def hash256(self, obj):
97         return IntType(256).unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
98     
99     def scrypt(self, obj):
100         import ltc_scrypt
101         return IntType(256).unpack(ltc_scrypt.getPoWHash(self.pack(obj)))
102
103 class VarIntType(Type):
104     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
105     def read(self, file):
106         data, file = read(file, 1)
107         first = ord(data)
108         if first < 0xfd:
109             return first, file
110         elif first == 0xfd:
111             desc, length = '<H', 2
112         elif first == 0xfe:
113             desc, length = '<I', 4
114         elif first == 0xff:
115             desc, length = '<Q', 8
116         else:
117             raise AssertionError()
118         data, file = read(file, length)
119         return struct.unpack(desc, data)[0], file
120     
121     def write(self, file, item):
122         if item < 0xfd:
123             file = file, struct.pack('<B', item)
124         elif item <= 0xffff:
125             file = file, struct.pack('<BH', 0xfd, item)
126         elif item <= 0xffffffff:
127             file = file, struct.pack('<BI', 0xfe, item)
128         elif item <= 0xffffffffffffffff:
129             file = file, struct.pack('<BQ', 0xff, item)
130         else:
131             raise ValueError('int too large for varint')
132         return file
133
134 class VarStrType(Type):
135     _inner_size = VarIntType()
136     
137     def read(self, file):
138         length, file = self._inner_size.read(file)
139         return read(file, length)
140     
141     def write(self, file, item):
142         return self._inner_size.write(file, len(item)), item
143
144 class PassthruType(Type):
145     def read(self, file):
146         return read(file, size(file))
147     
148     def write(self, file, item):
149         return file, item
150
151 class EnumType(Type):
152     def __init__(self, inner, values):
153         self.inner = inner
154         self.values = values
155         
156         keys = {}
157         for k, v in values.iteritems():
158             if v in keys:
159                 raise ValueError('duplicate value in values')
160             keys[v] = k
161         self.keys = keys
162     
163     def read(self, file):
164         data, file = self.inner.read(file)
165         if data not in self.keys:
166             raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
167         return self.keys[data], file
168     
169     def write(self, file, item):
170         if item not in self.values:
171             raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
172         return self.inner.write(file, self.values[item])
173
174 class ListType(Type):
175     _inner_size = VarIntType()
176     
177     def __init__(self, type):
178         self.type = type
179     
180     def read(self, file):
181         length, file = self._inner_size.read(file)
182         res = []
183         for i in xrange(length):
184             item, file = self.type.read(file)
185             res.append(item)
186         return res, file
187     
188     def write(self, file, item):
189         file = self._inner_size.write(file, len(item))
190         for subitem in item:
191             file = self.type.write(file, subitem)
192         return file
193
194 class StructType(Type):
195     __slots__ = 'desc length'.split(' ')
196     
197     def __init__(self, desc):
198         self.desc = desc
199         self.length = struct.calcsize(self.desc)
200     
201     def read(self, file):
202         data, file = read(file, self.length)
203         return struct.unpack(self.desc, data)[0], file
204     
205     def write(self, file, item):
206         return file, struct.pack(self.desc, item)
207
208 class IntType(Type):
209     __slots__ = 'bytes step format_str max'.split(' ')
210     
211     def __new__(cls, bits, endianness='little'):
212         assert bits % 8 == 0
213         assert endianness in ['little', 'big']
214         if bits in [8, 16, 32, 64]:
215             return StructType(('<' if endianness == 'little' else '>') + {8: 'B', 16: 'H', 32: 'I', 64: 'Q'}[bits])
216         else:
217             return Type.__new__(cls, bits, endianness)
218     
219     def __init__(self, bits, endianness='little'):
220         assert bits % 8 == 0
221         assert endianness in ['little', 'big']
222         self.bytes = bits//8
223         self.step = -1 if endianness == 'little' else 1
224         self.format_str = '%%0%ix' % (2*self.bytes)
225         self.max = 2**bits
226     
227     def read(self, file, b2a_hex=binascii.b2a_hex):
228         data, file = read(file, self.bytes)
229         return int(b2a_hex(data[::self.step]), 16), file
230     
231     def write(self, file, item, a2b_hex=binascii.a2b_hex):
232         if not 0 <= item < self.max:
233             raise ValueError('invalid int value - %r' % (item,))
234         return file, a2b_hex(self.format_str % (item,))[::self.step]
235
236 class IPV6AddressType(Type):
237     def read(self, file):
238         data, file = read(file, 16)
239         if data[:12] != '00000000000000000000ffff'.decode('hex'):
240             raise ValueError('ipv6 addresses not supported yet')
241         return '.'.join(str(ord(x)) for x in data[12:]), file
242     
243     def write(self, file, item):
244         bits = map(int, item.split('.'))
245         if len(bits) != 4:
246             raise ValueError('invalid address: %r' % (bits,))
247         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
248         assert len(data) == 16, len(data)
249         return file, data
250
251 _record_types = {}
252
253 def get_record(fields):
254     fields = tuple(sorted(fields))
255     if 'keys' in fields:
256         raise ValueError()
257     if fields not in _record_types:
258         class _Record(object):
259             __slots__ = fields
260             def __repr__(self):
261                 return repr(dict(self))
262             def __getitem__(self, key):
263                 return getattr(self, key)
264             def __setitem__(self, key, value):
265                 setattr(self, key, value)
266             #def __iter__(self):
267             #    for field in self.__slots__:
268             #        yield field, getattr(self, field)
269             def keys(self):
270                 return self.__slots__
271             def __eq__(self, other):
272                 if isinstance(other, dict):
273                     return dict(self) == other
274                 elif isinstance(other, _Record):
275                     return all(self[k] == other[k] for k in self.keys())
276                 raise TypeError()
277             def __ne__(self, other):
278                 return not (self == other)
279         _record_types[fields] = _Record
280     return _record_types[fields]()
281
282 class ComposedType(Type):
283     def __init__(self, fields):
284         self.fields = tuple(fields)
285     
286     def read(self, file):
287         item = get_record(k for k, v in self.fields)
288         for key, type_ in self.fields:
289             item[key], file = type_.read(file)
290         return item, file
291     
292     def write(self, file, item):
293         for key, type_ in self.fields:
294             file = type_.write(file, item[key])
295         return file
296
297 class PossiblyNoneType(Type):
298     def __init__(self, none_value, inner):
299         self.none_value = none_value
300         self.inner = inner
301     
302     def read(self, file):
303         value, file = self.inner.read(file)
304         return None if value == self.none_value else value, file
305     
306     def write(self, file, item):
307         if item == self.none_value:
308             raise ValueError('none_value used')
309         return self.inner.write(file, self.none_value if item is None else item)