use set instead of list in memorypool_update
[electrum-server.git] / transports / stratum_tcp.py
index 5fb26a5..aad4c17 100644 (file)
@@ -11,8 +11,8 @@ from utils import print_log
 
 class TcpSession(Session):
 
-    def __init__(self, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
-        Session.__init__(self)
+    def __init__(self, dispatcher, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
+        Session.__init__(self, dispatcher)
         self.use_ssl = use_ssl
         if use_ssl:
             import ssl
@@ -26,9 +26,11 @@ class TcpSession(Session):
         else:
             self._connection = connection
 
-        self.address = address[0]
+        self.address = address[0] + ":%d"%address[1]
         self.name = "TCP " if not use_ssl else "SSL "
+        self.timeout = 1000
         self.response_queue = queue.Queue()
+        self.dispatcher.add_session(self)
 
     def do_handshake(self):
         if self.use_ssl:
@@ -40,10 +42,7 @@ class TcpSession(Session):
         else:
             return self._connection
 
-    def stop(self):
-        if self.stopped():
-            return
-
+    def shutdown(self):
         try:
             self._connection.shutdown(socket.SHUT_RDWR)
         except:
@@ -52,8 +51,6 @@ class TcpSession(Session):
             pass
 
         self._connection.close()
-        with self.lock:
-            self._stopped = True
 
     def send_response(self, response):
         self.response_queue.put(response)
@@ -67,7 +64,10 @@ class TcpClientResponder(threading.Thread):
 
     def run(self):
         while not self.session.stopped():
-            response = self.session.response_queue.get()
+            try:
+                response = self.session.response_queue.get(timeout=10)
+            except queue.Empty:
+                continue
             data = json.dumps(response) + "\n"
             try:
                 while data:
@@ -91,26 +91,22 @@ class TcpClientRequestor(threading.Thread):
         try:
             self.session.do_handshake()
         except:
+            self.session.stop()
             return
 
         while not self.shared.stopped():
-            if not self.update():
+
+            data = self.receive()
+            if not data:
+                self.session.stop()
                 break
 
+            self.message += data
             self.session.time = time.time()
 
             while self.parse():
                 pass
 
-    def update(self):
-        data = self.receive()
-        if not data:
-            # close_session
-            self.session.stop()
-            return False
-
-        self.message += data
-        return True
 
     def receive(self):
         try:
@@ -145,6 +141,8 @@ class TcpClientRequestor(threading.Thread):
             self.dispatcher.push_response({"error": "syntax error", "request": raw_command})
         else:
             self.dispatcher.push_request(self.session, command)
+            # sleep a bit to prevent a single session from DOSing the queue
+            time.sleep(0.01)
 
         return True
 
@@ -182,7 +180,7 @@ class TcpServer(threading.Thread):
 
             #if self.use_ssl: print_log("SSL: new session", address)
             try:
-                session = TcpSession(connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
+                session = TcpSession(self.dispatcher, connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
             except BaseException, e:
                 error = str(e)
                 print_log("cannot start TCP session", error, address)
@@ -190,8 +188,6 @@ class TcpServer(threading.Thread):
                 time.sleep(0.1)
                 continue
 
-            self.dispatcher.add_session(session)
-            self.dispatcher.collect_garbage()
             client_req = TcpClientRequestor(self.dispatcher, session)
             client_req.start()
             responder = TcpClientResponder(session)