moved base58 to only p2pool.bitcoin.data
[p2pool.git] / p2pool / util / pack.py
1 import binascii
2 import hashlib
3 import struct
4
5 import p2pool
6
7 class EarlyEnd(Exception):
8     pass
9
10 class LateEnd(Exception):
11     pass
12
13 def read((data, pos), length):
14     data2 = data[pos:pos + length]
15     if len(data2) != length:
16         raise EarlyEnd()
17     return data2, (data, pos + length)
18
19 def size((data, pos)):
20     return len(data) - pos
21
22 class Type(object):
23     __slots__ = []
24     
25     # the same data can have only one unpacked representation, but multiple packed binary representations
26     
27     def __hash__(self):
28         rval = getattr(self, '_hash', None)
29         if rval is None:
30             try:
31                 rval = self._hash = hash((type(self), frozenset(self.__dict__.items())))
32             except:
33                 print self.__dict__
34                 raise
35         return rval
36     
37     def __eq__(self, other):
38         return type(other) is type(self) and other.__dict__ == self.__dict__
39     
40     def __ne__(self, other):
41         return not (self == other)
42     
43     def _unpack(self, data):
44         obj, (data2, pos) = self.read((data, 0))
45         
46         assert data2 is data
47         
48         if pos != len(data):
49             raise LateEnd()
50         
51         return obj
52     
53     def _pack(self, obj):
54         f = self.write(None, obj)
55         
56         res = []
57         while f is not None:
58             res.append(f[1])
59             f = f[0]
60         res.reverse()
61         return ''.join(res)
62     
63     
64     def unpack(self, data):
65         obj = self._unpack(data)
66         
67         if p2pool.DEBUG:
68             data2 = self._pack(obj)
69             if data2 != data:
70                 if self._unpack(data2) != obj:
71                     raise AssertionError()
72         
73         return obj
74     
75     def pack(self, obj):
76         data = self._pack(obj)
77         
78         if p2pool.DEBUG:
79             if self._unpack(data) != obj:
80                 raise AssertionError((self._unpack(data), obj))
81         
82         return data
83     
84     
85     def hash160(self, obj):
86         return IntType(160).unpack(hashlib.new('ripemd160', hashlib.sha256(self.pack(obj)).digest()).digest())
87     
88     def hash256(self, obj):
89         return IntType(256).unpack(hashlib.sha256(hashlib.sha256(self.pack(obj)).digest()).digest())
90     
91     def scrypt(self, obj):
92         import ltc_scrypt
93         return IntType(256).unpack(ltc_scrypt.getPoWHash(self.pack(obj)))
94
95 class VarIntType(Type):
96     # redundancy doesn't matter here because bitcoin and p2pool both reencode before hashing
97     def read(self, file):
98         data, file = read(file, 1)
99         first = ord(data)
100         if first < 0xfd:
101             return first, file
102         elif first == 0xfd:
103             desc, length = '<H', 2
104         elif first == 0xfe:
105             desc, length = '<I', 4
106         elif first == 0xff:
107             desc, length = '<Q', 8
108         else:
109             raise AssertionError()
110         data, file = read(file, length)
111         return struct.unpack(desc, data)[0], file
112     
113     def write(self, file, item):
114         if item < 0xfd:
115             file = file, struct.pack('<B', item)
116         elif item <= 0xffff:
117             file = file, struct.pack('<BH', 0xfd, item)
118         elif item <= 0xffffffff:
119             file = file, struct.pack('<BI', 0xfe, item)
120         elif item <= 0xffffffffffffffff:
121             file = file, struct.pack('<BQ', 0xff, item)
122         else:
123             raise ValueError('int too large for varint')
124         return file
125
126 class VarStrType(Type):
127     _inner_size = VarIntType()
128     
129     def read(self, file):
130         length, file = self._inner_size.read(file)
131         return read(file, length)
132     
133     def write(self, file, item):
134         return self._inner_size.write(file, len(item)), item
135
136 class PassthruType(Type):
137     def read(self, file):
138         return read(file, size(file))
139     
140     def write(self, file, item):
141         return file, item
142
143 class EnumType(Type):
144     def __init__(self, inner, values):
145         self.inner = inner
146         self.values = values
147         
148         keys = {}
149         for k, v in values.iteritems():
150             if v in keys:
151                 raise ValueError('duplicate value in values')
152             keys[v] = k
153         self.keys = keys
154     
155     def read(self, file):
156         data, file = self.inner.read(file)
157         if data not in self.keys:
158             raise ValueError('enum data (%r) not in values (%r)' % (data, self.values))
159         return self.keys[data], file
160     
161     def write(self, file, item):
162         if item not in self.values:
163             raise ValueError('enum item (%r) not in values (%r)' % (item, self.values))
164         return self.inner.write(file, self.values[item])
165
166 class ListType(Type):
167     _inner_size = VarIntType()
168     
169     def __init__(self, type):
170         self.type = type
171     
172     def read(self, file):
173         length, file = self._inner_size.read(file)
174         res = []
175         for i in xrange(length):
176             item, file = self.type.read(file)
177             res.append(item)
178         return res, file
179     
180     def write(self, file, item):
181         file = self._inner_size.write(file, len(item))
182         for subitem in item:
183             file = self.type.write(file, subitem)
184         return file
185
186 class StructType(Type):
187     __slots__ = 'desc length'.split(' ')
188     
189     def __init__(self, desc):
190         self.desc = desc
191         self.length = struct.calcsize(self.desc)
192     
193     def read(self, file):
194         data, file = read(file, self.length)
195         return struct.unpack(self.desc, data)[0], file
196     
197     def write(self, file, item):
198         return file, struct.pack(self.desc, item)
199
200 class IntType(Type):
201     __slots__ = 'bytes step format_str max'.split(' ')
202     
203     def __new__(cls, bits, endianness='little'):
204         assert bits % 8 == 0
205         assert endianness in ['little', 'big']
206         if bits in [8, 16, 32, 64]:
207             return StructType(('<' if endianness == 'little' else '>') + {8: 'B', 16: 'H', 32: 'I', 64: 'Q'}[bits])
208         else:
209             return Type.__new__(cls, bits, endianness)
210     
211     def __init__(self, bits, endianness='little'):
212         assert bits % 8 == 0
213         assert endianness in ['little', 'big']
214         self.bytes = bits//8
215         self.step = -1 if endianness == 'little' else 1
216         self.format_str = '%%0%ix' % (2*self.bytes)
217         self.max = 2**bits
218     
219     def read(self, file, b2a_hex=binascii.b2a_hex):
220         data, file = read(file, self.bytes)
221         return int(b2a_hex(data[::self.step]), 16), file
222     
223     def write(self, file, item, a2b_hex=binascii.a2b_hex):
224         if not 0 <= item < self.max:
225             raise ValueError('invalid int value - %r' % (item,))
226         return file, a2b_hex(self.format_str % (item,))[::self.step]
227
228 class IPV6AddressType(Type):
229     def read(self, file):
230         data, file = read(file, 16)
231         if data[:12] != '00000000000000000000ffff'.decode('hex'):
232             raise ValueError('ipv6 addresses not supported yet')
233         return '.'.join(str(ord(x)) for x in data[12:]), file
234     
235     def write(self, file, item):
236         bits = map(int, item.split('.'))
237         if len(bits) != 4:
238             raise ValueError('invalid address: %r' % (bits,))
239         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
240         assert len(data) == 16, len(data)
241         return file, data
242
243 _record_types = {}
244
245 def get_record(fields):
246     fields = tuple(sorted(fields))
247     if 'keys' in fields:
248         raise ValueError()
249     if fields not in _record_types:
250         class _Record(object):
251             __slots__ = fields
252             def __repr__(self):
253                 return repr(dict(self))
254             def __getitem__(self, key):
255                 return getattr(self, key)
256             def __setitem__(self, key, value):
257                 setattr(self, key, value)
258             #def __iter__(self):
259             #    for field in self.__slots__:
260             #        yield field, getattr(self, field)
261             def keys(self):
262                 return self.__slots__
263             def __eq__(self, other):
264                 if isinstance(other, dict):
265                     return dict(self) == other
266                 elif isinstance(other, _Record):
267                     return all(self[k] == other[k] for k in self.keys())
268                 raise TypeError()
269             def __ne__(self, other):
270                 return not (self == other)
271         _record_types[fields] = _Record
272     return _record_types[fields]()
273
274 class ComposedType(Type):
275     def __init__(self, fields):
276         self.fields = tuple(fields)
277     
278     def read(self, file):
279         item = get_record(k for k, v in self.fields)
280         for key, type_ in self.fields:
281             item[key], file = type_.read(file)
282         return item, file
283     
284     def write(self, file, item):
285         for key, type_ in self.fields:
286             file = type_.write(file, item[key])
287         return file
288
289 class PossiblyNoneType(Type):
290     def __init__(self, none_value, inner):
291         self.none_value = none_value
292         self.inner = inner
293     
294     def read(self, file):
295         value, file = self.inner.read(file)
296         return None if value == self.none_value else value, file
297     
298     def write(self, file, item):
299         if item == self.none_value:
300             raise ValueError('none_value used')
301         return self.inner.write(file, self.none_value if item is None else item)