broadcast shares in serial
authorForrest Voight <forrest@forre.st>
Wed, 4 Jul 2012 02:26:18 +0000 (22:26 -0400)
committerForrest Voight <forrest@forre.st>
Sat, 7 Jul 2012 21:51:13 +0000 (17:51 -0400)
p2pool/main.py
p2pool/p2p.py
p2pool/util/p2protocol.py
p2pool/util/variable.py

index 29522dc..4073f7f 100644 (file)
@@ -391,6 +391,7 @@ def main(args, net, datadir_path, merged_urls, worker_endpoint):
             for peer in p2p_node.peers.itervalues():
                 peer.send_bestblock(header=header)
         
+        @defer.inlineCallbacks
         def broadcast_share(share_hash):
             shares = []
             for share in tracker.get_chain(share_hash, min(5, tracker.get_height(share_hash))):
@@ -399,8 +400,8 @@ def main(args, net, datadir_path, merged_urls, worker_endpoint):
                 shared_share_hashes.add(share.hash)
                 shares.append(share)
             
-            for peer in p2p_node.peers.itervalues():
-                peer.sendShares([share for share in shares if share.peer is not peer])
+            for peer in list(p2p_node.peers.itervalues()):
+                yield peer.sendShares([share for share in shares if share.peer is not peer])
         
         # send share when the chain changes to their chain
         best_share_var.changed.watch(broadcast_share)
index 9899f5f..871a2d9 100644 (file)
@@ -25,6 +25,8 @@ class Protocol(p2protocol.Protocol):
         self.connected2 = False
     
     def connectionMade(self):
+        p2protocol.Protocol.connectionMade(self)
+        
         self.factory.proto_made_connection(self)
         
         self.addr = self.transport.getPeer().host, self.transport.getPeer().port
@@ -205,12 +207,14 @@ class Protocol(p2protocol.Protocol):
     def sendShares(self, shares):
         def att(f, **kwargs):
             try:
-                f(**kwargs)
+                return f(**kwargs)
             except p2protocol.TooLong:
                 att(f, **dict((k, v[:len(v)//2]) for k, v in kwargs.iteritems()))
-                att(f, **dict((k, v[len(v)//2:]) for k, v in kwargs.iteritems()))
+                return att(f, **dict((k, v[len(v)//2:]) for k, v in kwargs.iteritems()))
         if shares:
-            att(self.send_shares, shares=[share.as_share() for share in shares])
+            return att(self.send_shares, shares=[share.as_share() for share in shares])
+        else:
+            return defer.succeed(None)
     
     
     message_sharereq = pack.ComposedType([
index 1886713..ec3a6f9 100644 (file)
@@ -9,7 +9,7 @@ from twisted.internet import protocol
 from twisted.python import log
 
 import p2pool
-from p2pool.util import datachunker
+from p2pool.util import datachunker, variable
 
 class TooLong(Exception):
     pass
@@ -19,6 +19,19 @@ class Protocol(protocol.Protocol):
         self._message_prefix = message_prefix
         self._max_payload_length = max_payload_length
         self.dataReceived = datachunker.DataChunker(self.dataReceiver())
+        self.paused_var = variable.Variable(False)
+    
+    def connectionMade(self):
+        self.transport.registerProducer(self, True)
+    
+    def pauseProducing(self):
+        self.paused_var.set(True)
+    
+    def resumeProducing(self):
+        self.paused_var.set(False)
+    
+    def stopProducing(self):
+        pass
     
     def dataReceiver(self):
         while True:
@@ -74,6 +87,7 @@ class Protocol(protocol.Protocol):
         if len(payload) > self._max_payload_length:
             raise TooLong('payload too long')
         self.transport.write(self._message_prefix + struct.pack('<12sI', command, len(payload)) + hashlib.sha256(hashlib.sha256(payload).digest()).digest()[:4] + payload)
+        return self.paused_var.get_when_satisfies(lambda paused: not paused)
     
     def __getattr__(self, attr):
         prefix = 'send_'
index 8155e2c..f11b94f 100644 (file)
@@ -66,10 +66,12 @@ class Variable(object):
         self.changed.happened(value)
         self.transitioned.happened(oldvalue, value)
     
+    @defer.inlineCallbacks
+    def get_when_satisfies(self, func):
+        while True:
+            if func(self.value):
+                defer.returnValue(self.value)
+            yield self.changed.once.get_deferred()
+    
     def get_not_none(self):
-        if self.value is not None:
-            return defer.succeed(self.value)
-        else:
-            df = defer.Deferred()
-            self.changed.once.watch(df.callback)
-            return df
+        return self.get_when_satisfies(lambda val: val is not None)