replace task.LoopingCalls with deferral.RobustLoopingCall, which catches errors and...
[p2pool.git] / p2pool / bitcoin / p2p.py
1 '''
2 Implementation of Bitcoin's p2p protocol
3 '''
4
5 import random
6 import sys
7 import time
8
9 from twisted.internet import protocol
10
11 import p2pool
12 from . import data as bitcoin_data
13 from p2pool.util import deferral, p2protocol, pack, variable
14
15 class Protocol(p2protocol.Protocol):
16     def __init__(self, net):
17         p2protocol.Protocol.__init__(self, net.P2P_PREFIX, 1000000, ignore_trailing_payload=True)
18     
19     def connectionMade(self):
20         self.send_version(
21             version=32200,
22             services=1,
23             time=int(time.time()),
24             addr_to=dict(
25                 services=1,
26                 address=self.transport.getPeer().host,
27                 port=self.transport.getPeer().port,
28             ),
29             addr_from=dict(
30                 services=1,
31                 address=self.transport.getHost().host,
32                 port=self.transport.getHost().port,
33             ),
34             nonce=random.randrange(2**64),
35             sub_version_num='/P2Pool:%s/' % (p2pool.__version__,),
36             start_height=0,
37         )
38     
39     message_version = pack.ComposedType([
40         ('version', pack.IntType(32)),
41         ('services', pack.IntType(64)),
42         ('time', pack.IntType(64)),
43         ('addr_to', bitcoin_data.address_type),
44         ('addr_from', bitcoin_data.address_type),
45         ('nonce', pack.IntType(64)),
46         ('sub_version_num', pack.VarStrType()),
47         ('start_height', pack.IntType(32)),
48     ])
49     def handle_version(self, version, services, time, addr_to, addr_from, nonce, sub_version_num, start_height):
50         self.send_verack()
51     
52     message_verack = pack.ComposedType([])
53     def handle_verack(self):
54         self.get_block = deferral.ReplyMatcher(lambda hash: self.send_getdata(requests=[dict(type='block', hash=hash)]))
55         self.get_block_header = deferral.ReplyMatcher(lambda hash: self.send_getheaders(version=1, have=[], last=hash))
56         
57         if hasattr(self.factory, 'resetDelay'):
58             self.factory.resetDelay()
59         if hasattr(self.factory, 'gotConnection'):
60             self.factory.gotConnection(self)
61         
62         self.pinger = deferral.RobustLoopingCall(self.send_ping)
63         self.pinger.start(30)
64     
65     message_inv = pack.ComposedType([
66         ('invs', pack.ListType(pack.ComposedType([
67             ('type', pack.EnumType(pack.IntType(32), {1: 'tx', 2: 'block'})),
68             ('hash', pack.IntType(256)),
69         ]))),
70     ])
71     def handle_inv(self, invs):
72         for inv in invs:
73             if inv['type'] == 'tx':
74                 self.send_getdata(requests=[inv])
75             elif inv['type'] == 'block':
76                 self.factory.new_block.happened(inv['hash'])
77             else:
78                 print 'Unknown inv type', inv
79     
80     message_getdata = pack.ComposedType([
81         ('requests', pack.ListType(pack.ComposedType([
82             ('type', pack.EnumType(pack.IntType(32), {1: 'tx', 2: 'block'})),
83             ('hash', pack.IntType(256)),
84         ]))),
85     ])
86     message_getblocks = pack.ComposedType([
87         ('version', pack.IntType(32)),
88         ('have', pack.ListType(pack.IntType(256))),
89         ('last', pack.PossiblyNoneType(0, pack.IntType(256))),
90     ])
91     message_getheaders = pack.ComposedType([
92         ('version', pack.IntType(32)),
93         ('have', pack.ListType(pack.IntType(256))),
94         ('last', pack.PossiblyNoneType(0, pack.IntType(256))),
95     ])
96     message_getaddr = pack.ComposedType([])
97     
98     message_addr = pack.ComposedType([
99         ('addrs', pack.ListType(pack.ComposedType([
100             ('timestamp', pack.IntType(32)),
101             ('address', bitcoin_data.address_type),
102         ]))),
103     ])
104     def handle_addr(self, addrs):
105         for addr in addrs:
106             pass
107     
108     message_tx = pack.ComposedType([
109         ('tx', bitcoin_data.tx_type),
110     ])
111     def handle_tx(self, tx):
112         self.factory.new_tx.happened(tx)
113     
114     message_block = pack.ComposedType([
115         ('block', bitcoin_data.block_type),
116     ])
117     def handle_block(self, block):
118         block_hash = bitcoin_data.hash256(bitcoin_data.block_header_type.pack(block['header']))
119         self.get_block.got_response(block_hash, block)
120         self.get_block_header.got_response(block_hash, block['header'])
121     
122     message_headers = pack.ComposedType([
123         ('headers', pack.ListType(bitcoin_data.block_type)),
124     ])
125     def handle_headers(self, headers):
126         for header in headers:
127             header = header['header']
128             self.get_block_header.got_response(bitcoin_data.hash256(bitcoin_data.block_header_type.pack(header)), header)
129         self.factory.new_headers.happened([header['header'] for header in headers])
130     
131     message_ping = pack.ComposedType([])
132     def handle_ping(self):
133         pass
134     
135     message_alert = pack.ComposedType([
136         ('message', pack.VarStrType()),
137         ('signature', pack.VarStrType()),
138     ])
139     def handle_alert(self, message, signature):
140         pass # print 'ALERT:', (message, signature)
141     
142     def connectionLost(self, reason):
143         if hasattr(self.factory, 'gotConnection'):
144             self.factory.gotConnection(None)
145         if hasattr(self, 'pinger'):
146             self.pinger.stop()
147         if p2pool.DEBUG:
148             print >>sys.stderr, 'Bitcoin connection lost. Reason:', reason.getErrorMessage()
149
class ClientFactory(protocol.ReconnectingClientFactory):
    '''Reconnecting factory that tracks the current bitcoind connection.

    `conn` holds the Protocol instance whose handshake has completed, or
    None before that and after the connection is lost. The new_block,
    new_tx and new_headers events are fired by Protocol handlers when the
    peer announces a block hash, sends a transaction or sends headers.
    '''
    protocol = Protocol
    
    # Upper bound, in seconds, on the reconnection back-off delay.
    maxDelay = 1
    
    def __init__(self, net):
        self.net = net
        self.conn = variable.Variable(None)
        
        # Events fired by Protocol's message handlers.
        self.new_block, self.new_tx, self.new_headers = variable.Event(), variable.Event(), variable.Event()
    
    def buildProtocol(self, addr):
        '''Create a Protocol for this factory's network and attach the factory to it.'''
        proto = self.protocol(self.net)
        proto.factory = self
        return proto
    
    def gotConnection(self, conn):
        '''Record the current connection (a Protocol, or None when lost).'''
        self.conn.set(conn)
    
    def getProtocol(self):
        '''Return conn.get_not_none().

        This is presumably a Deferred firing with the Protocol once a
        connection is available -- confirm in p2pool.util.variable.
        '''
        return self.conn.get_not_none()