fix download link to bitcoind
[electrum-server.git] / transports / stratum_tcp.py
index cb0f203..88ec0cf 100644 (file)
@@ -1,19 +1,31 @@
 import json
+import Queue as queue
 import socket
 import threading
 import time
-import Queue as queue
+import traceback, sys
+
+from processor import Session, Dispatcher
+from utils import print_log
 
-from processor import Session, Dispatcher, timestr
 
 class TcpSession(Session):
 
-    def __init__(self, connection, address):
+    def __init__(self, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
         Session.__init__(self)
-        self._connection = connection
+        if use_ssl:
+            import ssl
+            self._connection = ssl.wrap_socket(
+                connection,
+                server_side=True,
+                certfile=ssl_certfile,
+                keyfile=ssl_keyfile,
+                ssl_version=ssl.PROTOCOL_SSLv23)
+        else:
+            self._connection = connection
+
         self.address = address[0]
-        self.version = 'unknown'
-        self.name = "TCP session"
+        self.name = "TCP " if not use_ssl else "SSL "
 
     def connection(self):
         if self.stopped():
@@ -22,24 +34,34 @@ class TcpSession(Session):
             return self._connection
 
     def stop(self):
+        if self.stopped():
+            return
+
+        try:
+            self._connection.shutdown(socket.SHUT_RDWR)
+        except:
+            # print_log("problem shutting down", self.address)
+            # traceback.print_exc(file=sys.stdout)
+            pass
+
         self._connection.close()
-        #print "Terminating connection:", self.address
         with self.lock:
             self._stopped = True
 
     def send_response(self, response):
-        raw_response = json.dumps(response)
+        data = json.dumps(response) + "\n"
         # Possible race condition here by having session
         # close connection?
         # I assume Python connections are thread safe interfaces
         try:
             connection = self.connection()
-            connection.send(raw_response + "\n")
+            while data:
+                l = connection.send(data)
+                data = data[l:]
         except:
             self.stop()
 
 
-
 class TcpClientRequestor(threading.Thread):
 
     def __init__(self, dispatcher, session):
@@ -54,6 +76,8 @@ class TcpClientRequestor(threading.Thread):
             if not self.update():
                 break
 
+            self.session.time = time.time()
+
             while self.parse():
                 pass
 
@@ -69,7 +93,7 @@ class TcpClientRequestor(threading.Thread):
 
     def receive(self):
         try:
-            return self.session.connection().recv(1024)
+            return self.session.connection().recv(2048)
         except:
             return ''
 
@@ -80,7 +104,7 @@ class TcpClientRequestor(threading.Thread):
 
         raw_command = self.message[0:raw_buffer].strip()
         self.message = self.message[raw_buffer + 1:]
-        if raw_command == 'quit': 
+        if raw_command == 'quit':
             self.session.stop()
             return False
 
@@ -99,13 +123,14 @@ class TcpClientRequestor(threading.Thread):
             # Return an error JSON in response.
             self.dispatcher.push_response({"error": "syntax error", "request": raw_command})
         else:
-            self.dispatcher.push_request(self.session,command)
+            self.dispatcher.push_request(self.session, command)
 
         return True
 
+
 class TcpServer(threading.Thread):
 
-    def __init__(self, dispatcher, host, port):
+    def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
         self.shared = dispatcher.shared
         self.dispatcher = dispatcher.request_dispatcher
         threading.Thread.__init__(self)
@@ -113,20 +138,39 @@ class TcpServer(threading.Thread):
         self.host = host
         self.port = port
         self.lock = threading.Lock()
+        self.use_ssl = use_ssl
+        self.ssl_keyfile = ssl_keyfile
+        self.ssl_certfile = ssl_certfile
 
     def run(self):
-        print "TCP server started."
+        if self.use_ssl:
+            print_log("TCP/SSL server started.")
+        else:
+            print_log("TCP server started.")
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         sock.bind((self.host, self.port))
-        sock.listen(1)
-        while not self.shared.stopped():
-            session = TcpSession(*sock.accept())
-            client_req = TcpClientRequestor(self.dispatcher, session)
-            client_req.start()
-            self.dispatcher.add_session(session)
-            self.dispatcher.collect_garbage()
-
+        sock.listen(5)
 
+        while not self.shared.stopped():
 
+            try:
+                connection, address = sock.accept()
+            except:
+                traceback.print_exc(file=sys.stdout)
+                time.sleep(0.1)
+                continue
+
+            try:
+                session = TcpSession(connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
+            except BaseException, e:
+                error = str(e)
+                print_log("cannot start TCP session", error, address)
+                connection.close()
+                time.sleep(0.1)
+                continue
 
+            self.dispatcher.add_session(session)
+            self.dispatcher.collect_garbage()
+            client_req = TcpClientRequestor(self.dispatcher, session)
+            client_req.start()