7dc4bb81434fd9798ac6b5af2fe6d8cee3b286ef
[p2pool.git] / bitcoin_p2p.py
1 '''
2 Implementation of Bitcoin's p2p protocol
3 '''
4
5 from __future__ import division
6
7 import hashlib
8 import random
9 import StringIO
10 import socket
11 import struct
12 import time
13 import traceback
14
15 from twisted.internet import defer, protocol, reactor
16
17 import expiring_dict
18 import util
19
20 class EarlyEnd(Exception):
21     pass
22
23 class LateEnd(Exception):
24     pass
25
26 class Type(object):
27     # the same data can have only one unpacked representation, but multiple packed binary representations
28     
29     def _unpack(self, data):
30         f = StringIO.StringIO(data)
31         
32         obj = self.read(f)
33         
34         if f.tell() != len(data):
35             raise LateEnd('underread ' + repr((self, data)))
36         
37         return obj
38     
39     def unpack(self, data):
40         obj = self._unpack(data)
41         assert self._unpack(self._pack(obj)) == obj
42         return obj
43     
44     def _pack(self, obj):
45         f = StringIO.StringIO()
46         
47         self.write(f, obj)
48         
49         data = f.getvalue()
50         
51         return data
52     
53     def pack(self, obj):
54         data = self._pack(obj)
55         assert self._unpack(data) == obj
56         return data
57
58 class VarIntType(Type):
59     def read(self, file):
60         data = file.read(1)
61         if len(data) != 1:
62             raise EarlyEnd()
63         first, = struct.unpack('<B', data)
64         if first == 0xff: desc = '<Q'
65         elif first == 0xfe: desc = '<I'
66         elif first == 0xfd: desc = '<H'
67         else: return first
68         length = struct.calcsize(desc)
69         data = file.read(length)
70         if len(data) != length:
71             raise EarlyEnd()
72         return struct.unpack(desc, data)[0]
73     
74     def write(self, file, item):
75         if item < 0xfd:
76             file.write(struct.pack('<B', item))
77         elif item <= 0xffff:
78             file.write(struct.pack('<BH', 0xfd, item))
79         elif item <= 0xffffffff:
80             file.write(struct.pack('<BI', 0xfe, item))
81         elif item <= 0xffffffffffffffff:
82             file.write(struct.pack('<BQ', 0xff, item))
83         else:
84             raise ValueError('int too large for varint')
85
86 class VarStrType(Type):
87     def read(self, file):
88         length = VarIntType().read(file)
89         res = file.read(length)
90         if len(res) != length:
91             raise EarlyEnd('var str not long enough %r' % ((length, len(res), res),))
92         return res
93     
94     def write(self, file, item):
95         VarIntType().write(file, len(item))
96         file.write(item)
97
98 class FixedStrType(Type):
99     def __init__(self, length):
100         self.length = length
101     
102     def read(self, file):
103         res = file.read(self.length)
104         if len(res) != self.length:
105             raise EarlyEnd('early EOF!')
106         return res
107     
108     def write(self, file, item):
109         if len(item) != self.length:
110             raise ValueError('incorrect length!')
111         file.write(item)
112
113 class EnumType(Type):
114     def __init__(self, inner, values):
115         self.inner = inner
116         self.values = values
117         
118         self.keys = {}
119         for k, v in values.iteritems():
120             if v in self.keys:
121                 raise ValueError('duplicate value in values')
122             self.keys[v] = k
123     
124     def read(self, file):
125         return self.keys[self.inner.read(file)]
126     
127     def write(self, file, item):
128         self.inner.write(file, self.values[item])
129
130 class HashType(Type):
131     def read(self, file):
132         data = file.read(256//8)
133         if len(data) != 256//8:
134             raise EarlyEnd('incorrect length!')
135         return int(data[::-1].encode('hex'), 16)
136     
137     def write(self, file, item):
138         file.write(('%064x' % (item,)).decode('hex')[::-1])
139
140 class ShortHashType(Type):
141     def read(self, file):
142         data = file.read(160//8)
143         if len(data) != 160//8:
144             raise EarlyEnd('incorrect length!')
145         return int(data[::-1].encode('hex'), 16)
146     
147     def write(self, file, item):
148         file.write(('%020x' % (item,)).decode('hex')[::-1])
149
150 class ListType(Type):
151     def __init__(self, type):
152         self.type = type
153     
154     def read(self, file):
155         length = VarIntType().read(file)
156         return [self.type.read(file) for i in xrange(length)]
157     
158     def write(self, file, item):
159         VarIntType().write(file, len(item))
160         for subitem in item:
161             self.type.write(file, subitem)
162
163 class StructType(Type):
164     def __init__(self, desc):
165         self.desc = desc
166         self.length = struct.calcsize(self.desc)
167     
168     def read(self, file):
169         data = file.read(self.length)
170         if len(data) != self.length:
171             raise EarlyEnd()
172         res, = struct.unpack(self.desc, data)
173         return res
174     
175     def write(self, file, item):
176         data = struct.pack(self.desc, item)
177         if struct.unpack(self.desc, data)[0] != item:
178             # special test because struct doesn't error on some overflows
179             raise ValueError("item didn't survive pack cycle (%r)" % (item,))
180         file.write(data)
181
182 class IPV6AddressType(Type):
183     def read(self, file):
184         data = file.read(16)
185         if len(data) != 16:
186             raise EarlyEnd()
187         if data[:12] != '00000000000000000000ffff'.decode('hex'):
188             raise ValueError("ipv6 addresses not supported yet")
189         return '::ffff:' + '.'.join(str(ord(x)) for x in data[12:])
190     
191     def write(self, file, item):
192         prefix = '::ffff:'
193         if not item.startswith(prefix):
194             raise ValueError("ipv6 addresses not supported yet")
195         item = item[len(prefix):]
196         bits = map(int, item.split('.'))
197         if len(bits) != 4:
198             raise ValueError("invalid address: %r" % (bits,))
199         data = '00000000000000000000ffff'.decode('hex') + ''.join(chr(x) for x in bits)
200         assert len(data) == 16, len(data)
201         file.write(data)
202
203 class ComposedType(Type):
204     def __init__(self, fields):
205         self.fields = fields
206     
207     def read(self, file):
208         item = {}
209         for key, type_ in self.fields:
210             item[key] = type_.read(file)
211         return item
212     
213     def write(self, file, item):
214         for key, type_ in self.fields:
215             type_.write(file, item[key])
216
217 address_type = ComposedType([
218     ('services', StructType('<Q')),
219     ('address', IPV6AddressType()),
220     ('port', StructType('>H')),
221 ])
222
223 tx_type = ComposedType([
224     ('version', StructType('<I')),
225     ('tx_ins', ListType(ComposedType([
226         ('previous_output', ComposedType([
227             ('hash', HashType()),
228             ('index', StructType('<I')),
229         ])),
230         ('script', VarStrType()),
231         ('sequence', StructType('<I')),
232     ]))),
233     ('tx_outs', ListType(ComposedType([
234         ('value', StructType('<Q')),
235         ('script', VarStrType()),
236     ]))),
237     ('lock_time', StructType('<I')),
238 ])
239
240 block_header_type = ComposedType([
241     ('version', StructType('<I')),
242     ('previous_block', HashType()),
243     ('merkle_root', HashType()),
244     ('timestamp', StructType('<I')),
245     ('bits', StructType('<I')),
246     ('nonce', StructType('<I')),
247 ])
248
249 block_type = ComposedType([
250     ('header', block_header_type),
251     ('txs', ListType(tx_type)),
252 ])
253
254 def doublesha(data):
255     return HashType().unpack(hashlib.sha256(hashlib.sha256(data).digest()).digest())
256
257 def ripemdsha(data):
258     return ShortHashType().unpack(hashlib.new('ripemd160', hashlib.sha256(data).digest()).digest())
259
260 merkle_record_type = ComposedType([
261     ('left', HashType()),
262     ('right', HashType()),
263 ])
264
265 def merkle_hash(tx_list):
266     hash_list = [doublesha(tx_type.pack(tx)) for tx in tx_list]
267     while len(hash_list) > 1:
268         hash_list = [doublesha(merkle_record_type.pack(dict(left=left, right=left if right is None else right)))
269             for left, right in zip(hash_list[::2], hash_list[1::2] + [None])]
270     return hash_list[0]
271
272 def tx_hash(tx):
273     return doublesha(tx_type.pack(tx))
274
275 def block_hash(header):
276     return doublesha(block_header.pack(header))
277
278 class BaseProtocol(protocol.Protocol):
279     def connectionMade(self):
280         self.dataReceived = util.DataChunker(self.dataReceiver())
281     
282     def dataReceiver(self):
283         while True:
284             start = ''
285             while start != self._prefix:
286                 start = (start + (yield 1))[-len(self._prefix):]
287             
288             command = (yield 12).rstrip('\0')
289             length, = struct.unpack('<I', (yield 4))
290             
291             if self.use_checksum:
292                 checksum = yield 4
293             else:
294                 checksum = None
295             
296             payload = yield length
297             
298             if checksum is not None:
299                 if hashlib.sha256(hashlib.sha256(payload).digest()).digest()[:4] != checksum:
300                     print 'RECV', command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
301                     print 'INVALID HASH'
302                     continue
303             
304             type_ = self.message_types.get(command, None)
305             if type_ is None:
306                 print 'RECV', command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
307                 print 'NO TYPE FOR', repr(command)
308                 continue
309             
310             try:
311                 payload2 = type_.unpack(payload)
312             except:
313                 print 'RECV', command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
314                 traceback.print_exc()
315                 continue
316             
317             handler = getattr(self, 'handle_' + command, None)
318             if handler is None:
319                 print 'RECV', command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
320                 print 'NO HANDLER FOR', command
321                 continue
322             
323             #print 'RECV', command, payload2
324             
325             try:
326                 handler(**payload2)
327             except:
328                 print 'RECV', command, checksum.encode('hex') if checksum is not None else None, repr(payload.encode('hex')), len(payload)
329                 traceback.print_exc()
330                 continue
331     
332     def sendPacket(self, command, payload2={}):
333         payload = self.message_types[command].pack(payload2)
334         if len(command) >= 12:
335             raise ValueError('command too long')
336         if self.use_checksum:
337             checksum = hashlib.sha256(hashlib.sha256(payload).digest()).digest()[:4]
338         else:
339             checksum = ''
340         data = self._prefix + struct.pack('<12sI', command, len(payload)) + checksum + payload
341         self.transport.write(data)
342         #print 'SEND', command, payload2
343     
344     def __getattr__(self, attr):
345         prefix = 'send_'
346         if attr.startswith(prefix):
347             command = attr[len(prefix):]
348             return lambda **payload2: self.sendPacket(command, payload2)
349         #return protocol.Protocol.__getattr__(self, attr)
350         raise AttributeError(attr)
351
352 class Protocol(BaseProtocol):
353     def __init__(self, testnet=False):
354         if testnet:
355             self._prefix = 'fabfb5da'.decode('hex')
356         else:
357             self._prefix = 'f9beb4d9'.decode('hex')
358     
359     version = 0
360     
361     @property
362     def use_checksum(self):
363         return self.version >= 209
364     
365     message_types = {
366         'version': ComposedType([
367             ('version', StructType('<I')),
368             ('services', StructType('<Q')),
369             ('time', StructType('<Q')),
370             ('addr_to', address_type),
371             ('addr_from', address_type),
372             ('nonce', StructType('<Q')),
373             ('sub_version_num', VarStrType()),
374             ('start_height', StructType('<I')),
375         ]),
376         'verack': ComposedType([]),
377         'addr': ComposedType([
378             ('addrs', ListType(ComposedType([
379                 ('timestamp', StructType('<I')),
380                 ('address', address_type),
381             ]))),
382         ]),
383         'inv': ComposedType([
384             ('invs', ListType(ComposedType([
385                 ('type', EnumType(StructType('<I'), {'tx': 1, 'block': 2})),
386                 ('hash', HashType()),
387             ]))),
388         ]),
389         'getdata': ComposedType([
390             ('requests', ListType(ComposedType([
391                 ('type', EnumType(StructType('<I'), {'tx': 1, 'block': 2})),
392                 ('hash', HashType()),
393             ]))),
394         ]),
395         'getblocks': ComposedType([
396             ('version', StructType('<I')),
397             ('have', ListType(HashType())),
398             ('last', HashType()),
399         ]),
400         'getheaders': ComposedType([
401             ('version', StructType('<I')),
402             ('have', ListType(HashType())),
403             ('last', HashType()),
404         ]),
405         'tx': ComposedType([
406             ('tx', tx_type),
407         ]),
408         'block': ComposedType([
409             ('block', block_type),
410         ]),
411         'headers': ComposedType([
412             ('headers', ListType(block_header_type)),
413         ]),
414         'getaddr': ComposedType([]),
415         'checkorder': ComposedType([
416             ('id', HashType()),
417             ('order', FixedStrType(60)), # XXX
418         ]),
419         'submitorder': ComposedType([
420             ('id', HashType()),
421             ('order', FixedStrType(60)), # XXX
422         ]),
423         'reply': ComposedType([
424             ('hash', HashType()),
425             ('reply',  EnumType(StructType('<I'), {'success': 0, 'failure': 1, 'denied': 2})),
426             ('script', VarStrType()),
427         ]),
428         'ping': ComposedType([]),
429         'alert': ComposedType([
430             ('message', VarStrType()),
431             ('signature', VarStrType()),
432         ]),
433     }
434     
435     null_order = '\0'*60
436     
437     def connectionMade(self):
438         BaseProtocol.connectionMade(self)
439         
440         self.send_version(
441             version=32200,
442             services=1,
443             time=int(time.time()),
444             addr_to=dict(
445                 services=1,
446                 address='::ffff:' + self.transport.getPeer().host,
447                 port=self.transport.getPeer().port,
448             ),
449             addr_from=dict(
450                 services=1,
451                 address='::ffff:' + self.transport.getHost().host,
452                 port=self.transport.getHost().port,
453             ),
454             nonce=random.randrange(2**64),
455             sub_version_num='',
456             start_height=0,
457         )
458     
459     def handle_version(self, version, services, time, addr_to, addr_from, nonce, sub_version_num, start_height):
460         #print 'VERSION', locals()
461         self.version_after = version
462         self.send_verack()
463     
464     def handle_verack(self):
465         self.version = self.version_after
466         
467         # connection ready
468         self.check_order = util.GenericDeferrer(2**256, lambda id, order: self.send_checkorder(id=id, order=order))
469         self.submit_order = util.GenericDeferrer(2**256, lambda id, order: self.send_submitorder(id=id, order=order))
470         self.get_block = util.ReplyMatcher(lambda hash: self.send_getdata(requests=[dict(type='block', hash=hash)]))
471         self.get_block_header = util.ReplyMatcher(lambda hash: self.send_getdata(requests=[dict(type='block', hash=hash)]))
472         
473         if hasattr(self.factory, 'resetDelay'):
474             self.factory.resetDelay()
475         if hasattr(self.factory, 'gotConnection'):
476             self.factory.gotConnection(self)
477     
478     def handle_inv(self, invs):
479         for inv in invs:
480             #print 'INV', item['type'], hex(item['hash'])
481             self.send_getdata(requests=[inv])
482     
483     def handle_addr(self, addrs):
484         for addr in addrs:
485             pass#print 'ADDR', addr
486     
487     def handle_reply(self, hash, reply, script):
488         self.check_order.got_response(hash, dict(reply=reply, script=script))
489         self.submit_order.got_response(hash, dict(reply=reply, script=script))
490     
491     def handle_tx(self, tx):
492         #print 'TX', hex(merkle_hash([tx])), tx
493         self.factory.new_tx.happened(tx)
494     
495     def handle_block(self, block):
496         self.get_block.got_response(block_hash(block['header']), block)
497         self.factory.new_block.happened(block)
498     
499     def handle_ping(self):
500         pass
501     
502     def connectionLost(self, reason):
503         if hasattr(self.factory, 'gotConnection'):
504             self.factory.gotConnection(None)
505
506 class ClientFactory(protocol.ReconnectingClientFactory):
507     protocol = Protocol
508     
509     maxDelay = 15
510     
511     def __init__(self, testnet=False):
512         self.testnet = testnet
513         self.conn = util.Variable(None)
514         
515         self.new_block = util.Event()
516         self.new_tx = util.Event()
517     
518     def buildProtocol(self, addr):
519         p = self.protocol(self.testnet)
520         p.factory = self
521         return p
522     
523     def gotConnection(self, conn):
524         self.conn.set(conn)
525         self.conn = conn
526     
527     def getProtocol(self):
528         return self.conn.get_not_none()
529
530 if __name__ == '__main__':
531     factory = ClientFactory()
532     reactor.connectTCP('127.0.0.1', 8333, factory)
533     
534     reactor.run()