setup.py
[p2pool.git] / p2p.py
diff --git a/p2p.py b/p2p.py
index c0ecd4c..c0fec4e 100644 (file)
--- a/p2p.py
+++ b/p2p.py
@@ -4,43 +4,28 @@ import random
 import time
 import traceback
 
-from twisted.internet import defer, reactor, protocol, task
+from twisted.internet import defer, protocol, reactor
 
 import bitcoin_p2p
 import conv
-import util
 import p2pool
+import util
 
 # mode
 #     0: send hash first (high latency, low bandwidth)
 #     1: send entire share (low latency, high bandwidth)
 
-if 0:
-    import pygame
-    d = pygame.display.set_mode((512, 512))
-    task.LoopingCall(pygame.display.update).start(.1)
-    def draw_circle(id, color=(255,0,0)):
-        id = repr(id)
-        pygame.draw.circle(d, (255, 0, 0), (hash(id)%512, hash(id)//512%512), 4)
-    def draw_line(id, id2, color):
-        id = repr(id)
-        pygame.draw.line(d, color, (hash(id)%512, hash(id)//512%512), (hash(id2)%512, hash(id2)//512%512))
-else:
-    draw_circle = draw_line = lambda *args, **kwargs: None
-
 class Protocol(bitcoin_p2p.BaseProtocol):
-    version = 0
-    sub_version = ""
+    version = 1
+    sub_version = ''
     
     def __init__(self, node):
         self.node = node
-    
-    @property
-    def _prefix(self):
+        
         if self.node.testnet:
-            return 'f77cea5d16a2183f'.decode('hex')
+            self._prefix = 'f77cea5d16a2183f'.decode('hex')
         else:
-            return '95ec1eda53c5e716'.decode('hex')
+            self._prefix = '95ec1eda53c5e716'.decode('hex')
     
     use_checksum = True
     
@@ -81,11 +66,12 @@ class Protocol(bitcoin_p2p.BaseProtocol):
             ('count', bitcoin_p2p.StructType('<I')),
         ]),
         
-        'getsharesbychain': bitcoin_p2p.ComposedType([
+        'gettobest': bitcoin_p2p.ComposedType([
             ('chain_id', p2pool.chain_id_type),
             ('have', bitcoin_p2p.ListType(bitcoin_p2p.HashType())),
         ]),
         'getshares': bitcoin_p2p.ComposedType([
+            ('chain_id', p2pool.chain_id_type),
             ('hashes', bitcoin_p2p.ListType(bitcoin_p2p.HashType())),
         ]),
         
@@ -105,6 +91,7 @@ class Protocol(bitcoin_p2p.BaseProtocol):
     
     other_version = None
     node_var_watch = None
+    connected2 = False
     
     @property
     def mode(self):
@@ -112,10 +99,9 @@ class Protocol(bitcoin_p2p.BaseProtocol):
     
     def connectionMade(self):
         bitcoin_p2p.BaseProtocol.connectionMade(self)
-        if isinstance(self.factory, ClientFactory):
-            draw_line(self.node.port, self.transport.getPeer().port, (128, 128, 128))
         
-        chain = self.node.current_work['current_chain']
+        chain = self.node.current_work.value['current_chain']
+        highest_share2 = chain.get_highest_share2()
         self.send_version(
             version=self.version,
             services=0,
@@ -133,29 +119,21 @@ class Protocol(bitcoin_p2p.BaseProtocol):
             sub_version=self.sub_version,
             mode=self.node.mode_var.value,
             state=dict(
-                chain_id=dict(
-                    last_p2pool_block_hash=0,
-                    bits=0,
-                ),
+                chain_id=p2pool.chain_id_type.unpack(chain.chain_id_data),
                 highest=dict(
-                    hash=0,
-                    height=0,
+                    hash=highest_share2.share.hash if highest_share2 is not None else 2**256-1,
+                    height=highest_share2.height if highest_share2 is not None else 0,
                 ),
             ),
         )
         
         self.node_var_watch = self.node.mode_var.changed.watch(lambda new_mode: self.send_set_mode(mode=new_mode))
         
-        self.connected2 = False
-        
-        self._think()
-        self._think2()
-        
         reactor.callLater(10, self._connect_timeout)
     
     def _connect_timeout(self):
         if not self.connected2 and self.transport.connected:
-            print "Handshake timed out, disconnecting"
+            print 'Handshake timed out, disconnecting from %s:%i' % (self.transport.getPeer().host, self.transport.getPeer().port)
             self.transport.loseConnection()
     
     @defer.inlineCallbacks
@@ -168,6 +146,7 @@ class Protocol(bitcoin_p2p.BaseProtocol):
     def _think2(self):
         while self.connected2:
             self.send_addrme(port=self.node.port)
+            #print 'sending addrme'
             yield util.sleep(random.expovariate(1/100))
     
     def handle_version(self, version, services, addr_to, addr_from, nonce, sub_version, mode, state):
@@ -176,16 +155,26 @@ class Protocol(bitcoin_p2p.BaseProtocol):
         self.other_mode_var = util.Variable(mode)
         
         if nonce == self.node.nonce:
-            #print "Detected connection to self, disconnecting"
+            #print 'Detected connection to self, disconnecting from %s:%i' % (self.transport.getPeer().host, self.transport.getPeer().port)
+            self.transport.loseConnection()
+            return
+        if nonce in self.node.peers:
+            print 'Detected duplicate connection, disconnecting from %s:%i' % (self.transport.getPeer().host, self.transport.getPeer().port)
             self.transport.loseConnection()
             return
         
-        # XXX use state
-        
+        self.nonce = nonce
         self.connected2 = True
-        self.node.got_conn(self, services)
-        if isinstance(self.factory, ClientFactory):
-            draw_line(self.node.port, self.transport.getPeer().port, (0, 255, 0))
+        self.node.got_conn(self)
+        
+        self._think()
+        self._think2()
+        
+        if state['highest']['hash'] != 2**256 - 1:
+            self.handle_share0s(chains=[dict(
+                chain_id=state['chain_id'],
+                hashes=[state['highest']['hash']],
+            )])
     
     def handle_set_mode(self, mode):
         self.other_mode_var.set(mode)
@@ -194,13 +183,28 @@ class Protocol(bitcoin_p2p.BaseProtocol):
         pass
     
     def handle_addrme(self, port):
-        self.node.got_addr(('::ffff:' + self.transport.getPeer().host, port), self.other_services, int(time.time()))
-        if random.random() < .7 and self.node.peers:
-            random.choice(self.node.peers.values()).send_addrs(addrs=[dict(address=dict(services=self.other_services, address='::ffff:' + self.transport.getPeer().host, port=port), timestamp=int(time.time()))])
+        host = self.transport.getPeer().host
+        #print 'addrme from', host, port
+        if host == '127.0.0.1':
+            if random.random() < .8 and self.node.peers:
+                random.choice(self.node.peers.values()).send_addrme(port=port) # services...
+        else:
+            self.node.got_addr(('::ffff:' + self.transport.getPeer().host, port), self.other_services, int(time.time()))
+            if random.random() < .8 and self.node.peers:
+                random.choice(self.node.peers.values()).send_addrs(addrs=[
+                    dict(
+                        address=dict(
+                            services=self.other_services,
+                            address='::ffff:' + host,
+                            port=port,
+                        ),
+                        timestamp=int(time.time()),
+                    ),
+                ])
     def handle_addrs(self, addrs):
         for addr_record in addrs:
             self.node.got_addr((addr_record['address']['address'], addr_record['address']['port']), addr_record['address']['services'], min(int(time.time()), addr_record['timestamp']))
-            if random.random() < .7 and self.node.peers:
+            if random.random() < .8 and self.node.peers:
                 random.choice(self.node.peers.values()).send_addrs(addrs=[addr_record])
     def handle_getaddrs(self, count):
         self.send_addrs(addrs=[
@@ -215,39 +219,45 @@ class Protocol(bitcoin_p2p.BaseProtocol):
             random.sample(self.node.addr_store.keys(), min(count, len(self.node.addr_store)))
         ])
     
+    def handle_gettobest(self, chain_id, have):
+        self.node.handle_get_to_best(p2pool.chain_id_type.pack(chain_id), have, self)
+    
+    def handle_getshares(self, chain_id, hashes):
+        self.node.handle_get_shares(p2pool.chain_id_type.pack(chain_id), hashes, self)
+    
     def handle_share0s(self, chains):
         for chain in chains:
             for hash_ in chain['hashes']:
-                self.node.handle_share_hash((chain['chain_id']['previous_p2pool_block'], chain['chain_id']['bits']), hash_)
+                self.node.handle_share_hash(p2pool.chain_id_type.pack(chain['chain_id']), hash_, self)
     def handle_share1s(self, share1s):
         for share1 in share1s:
             hash_ = bitcoin_p2p.block_hash(share1['header'])
             if hash_ <= conv.bits_to_target(share1['header']['bits']):
-                print "Dropping peer %s:%i due to invalid share" % (self.transport.getPeer().host, self.transport.getPeer().port)
+                print 'Dropping peer %s:%i due to invalid share' % (self.transport.getPeer().host, self.transport.getPeer().port)
                 self.transport.loseConnection()
                 return
-            share = Share(share1['header'], gentx=share1['gentx'])
-            self.node.handle_share(share)
+            share = p2pool.Share(share1['header'], gentx=share1['gentx'])
+            self.node.handle_share(share, self)
     def handle_share2s(self, share2s):
         for share2 in share2s:
             hash_ = bitcoin_p2p.block_hash(share2['header'])
-            if not hash_ <= conv.bits_to_target(share1['header']['bits']):
-                print "Dropping peer %s:%i due to invalid share" % (self.transport.getPeer().host, self.transport.getPeer().port)
+            if not hash_ <= conv.bits_to_target(share2['header']['bits']):
+                print 'Dropping peer %s:%i due to invalid share' % (self.transport.getPeer().host, self.transport.getPeer().port)
                 self.transport.loseConnection()
                 return
-            share = Share(share1['header'], txns=share1['txns'])
-            self.node.handle_share(share)
+            share = p2pool.Share(share2['header'], txns=share2['txns'])
+            self.node.handle_share(share, self)
     
-    def send_share(self, share):
+    def send_share(self, share, full=False):
         if share.hash <= conv.bits_to_target(share.header['bits']):
             self.send_share2s(share2s=[share.as_block()])
         else:
-            if self.mode == 0:
+            if self.mode == 0 and not full:
                 self.send_share0s(chains=[dict(
                     chain_id=p2pool.chain_id_type.unpack(share.chain_id_data),
                     hashes=[share.hash],
                 )])
-            elif self.mode == 1:
+            elif self.mode == 1 or full:
                 self.send_share1s(share1s=[dict(
                     header=share.header,
                     gentx=share.gentx,
@@ -261,9 +271,6 @@ class Protocol(bitcoin_p2p.BaseProtocol):
         
         if self.connected2:
             self.node.lost_conn(self)
-        
-        if isinstance(self.factory, ClientFactory):
-            draw_line(self.node.port, self.transport.getPeer().port, (255, 0, 0))
 
 class ServerFactory(protocol.ServerFactory):
     def __init__(self, node):
@@ -316,7 +323,7 @@ class AddrStore(util.DictWrapper):
         return v['services'], v['first_seen'], v['last_seen']
 
 class Node(object):
-    def __init__(self, port, testnet, addr_store=None, preferred_addrs=[], mode=0, desired_peers=10, max_attempts=100):
+    def __init__(self, current_work, port, testnet, addr_store=None, preferred_addrs=[], mode=0, desired_peers=10, max_attempts=100):
         if addr_store is None:
             addr_store = {}
         
@@ -327,19 +334,16 @@ class Node(object):
         self.mode_var = util.Variable(mode)
         self.desired_peers = desired_peers
         self.max_attempts = max_attempts
-        
-        self.current_work = dict(current_chain=None)
+        self.current_work = current_work
         
         self.nonce = random.randrange(2**64)
         self.attempts = {}
         self.peers = {}
         self.running = False
-        
-        draw_circle(self.port)
     
     def start(self):
         if self.running:
-            raise ValueError("already running")
+            raise ValueError('already running')
         
         self.running = True
         
@@ -359,10 +363,11 @@ class Node(object):
                         host2, port = random.choice(self.addr_store.keys())
                         prefix = '::ffff:'
                         if not host2.startswith(prefix):
-                            raise ValueError("invalid address")
+                            raise ValueError('invalid address')
                         host = host2[len(prefix):]
                     
-                    if (host, port) not in self.attempts and (host, port) not in self.peers:
+                    if (host, port) not in self.attempts:
+                        #print 'Trying to connect to', host, port
                         reactor.connectTCP(host, port, ClientFactory(self), timeout=10)
             except:
                 traceback.print_exc()
@@ -382,7 +387,7 @@ class Node(object):
     
     def stop(self):
         if not self.running:
-            raise ValueError("already stopped")
+            raise ValueError('already stopped')
         
         self.running = False
         
@@ -392,7 +397,7 @@ class Node(object):
     def attempt_started(self, connector):
         host, port = connector.getDestination().host, connector.getDestination().port
         if (host, port) in self.attempts:
-            raise ValueError("already have attempt")
+            raise ValueError('already have attempt')
         self.attempts[host, port] = connector
     
     def attempt_failed(self, connector):
@@ -403,27 +408,25 @@ class Node(object):
         if (host, port) not in self.attempts:
             raise ValueError("don't have attempt")
         if connector is not self.attempts[host, port]:
-            raise ValueError("wrong connector")
+            raise ValueError('wrong connector')
         del self.attempts[host, port]
     
     
-    def got_conn(self, conn, services):
-        host, port = conn.transport.getPeer().host, conn.transport.getPeer().port
-        if (host, port) in self.peers:
-            raise ValueError("already have peer")
-        self.peers[host, port] = conn
+    def got_conn(self, conn):
+        if conn.nonce in self.peers:
+            raise ValueError('already have peer')
+        self.peers[conn.nonce] = conn
         
-        print "Connected to peer %s:%i" % (host, port)
+        print 'Connected to peer %s:%i' % (conn.transport.getPeer().host, conn.transport.getPeer().port)
     
     def lost_conn(self, conn):
-        host, port = conn.transport.getPeer().host, conn.transport.getPeer().port
-        if (host, port) not in self.peers:
+        if conn.nonce not in self.peers:
             raise ValueError("don't have peer")
-        if conn is not self.peers[host, port]:
-            raise ValueError("wrong conn")
-        del self.peers[host, port]
+        if conn is not self.peers[conn.nonce]:
+            raise ValueError('wrong conn')
+        del self.peers[conn.nonce]
         
-        print "Lost peer %s:%i" % (host, port)
+        print 'Lost peer %s:%i' % (conn.transport.getPeer().host, conn.transport.getPeer().port)
     
     
     def got_addr(self, (host, port), services, timestamp):
@@ -432,6 +435,18 @@ class Node(object):
             self.addr_store[host, port] = services, old_first_seen, max(old_last_seen, timestamp)
         else:
             self.addr_store[host, port] = services, timestamp, timestamp
+    
+    def handle_share(self, share, peer):
+        print 'handle_share', (share, peer)
+    
+    def handle_share_hash(self, chain_id_data, hash, peer):
+        print 'handle_share_hash', (chain_id_data, hash, peer)
+    
+    def handle_get_to_best(self, chain_id_data, have, peer):
+        print 'handle_get_to_best', (chain_id_data, have, peer)
+    
+    def handle_get_shares(self, chain_id_data, hashes, peer):
+        print 'handle_get_shares', (chain_id_data, hashes, peer)
 
 if __name__ == '__main__':
     p = random.randrange(2**15, 2**16)