Merge pull request #5 from zootreeves/master
[electrum-server.git] / transports / stratum_http.py
index 959beff..397e926 100644 (file)
@@ -27,6 +27,8 @@ import types
 import traceback
 import sys, threading
 
+from OpenSSL import SSL
+
 try:
     import fcntl
 except ImportError:
@@ -131,6 +133,23 @@ class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
         return response
 
 
+    def create_session(self):
+        session_id = random_string(10)
+        session = HttpSession(session_id)
+        self.dispatcher.add_session(session)
+        return session_id
+
+    def poll_session(self, session):
+        q = session.pending_responses
+        responses = []
+        while not q.empty():
+            r = q.get()
+            responses.append(r)
+        #print "poll: %d responses"%len(responses)
+        return responses
+
+
+
 
 class StratumJSONRPCRequestHandler(
         SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
@@ -181,7 +200,7 @@ class StratumJSONRPCRequestHandler(
         self.end_headers()
         self.wfile.write(response)
         self.wfile.flush()
-        self.connection.shutdown(1)
+        self.shutdown_connection()
 
 
     def do_POST(self):
@@ -230,10 +249,39 @@ class StratumJSONRPCRequestHandler(
         self.end_headers()
         self.wfile.write(response)
         self.wfile.flush()
+        self.shutdown_connection()
+
+    def shutdown_connection(self):
         self.connection.shutdown(1)
 
 
-class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
+class SSLRequestHandler(StratumJSONRPCRequestHandler):
+    def setup(self):
+        self.connection = self.request
+        self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
+        self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
+
+    def shutdown_connection(self):
+        self.connection.shutdown()
+
+
+class SSLTCPServer(SocketServer.TCPServer):
+    def __init__(self, server_address, certfile, keyfile, RequestHandlerClass, bind_and_activate=True):
+        SocketServer.BaseServer.__init__(self, server_address, RequestHandlerClass)
+        ctx = SSL.Context(SSL.SSLv3_METHOD)
+        ctx.use_privatekey_file(keyfile)
+        ctx.use_certificate_file(certfile)
+        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
+        if bind_and_activate:
+            self.server_bind()
+            self.server_activate()
+
+    def shutdown_request(self,request):
+        #request.shutdown()
+        pass
+
+
+class StratumHTTPServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
 
     allow_reuse_address = True
 
@@ -255,12 +303,42 @@ class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
                     os.unlink(addr)
                 except OSError:
                     logging.warning("Could not unlink socket %s", addr)
-        # if python 2.5 and lower
-        if vi[0] < 3 and vi[1] < 6:
-            SocketServer.TCPServer.__init__(self, addr, requestHandler)
-        else:
-            SocketServer.TCPServer.__init__(self, addr, requestHandler,
-                bind_and_activate)
+
+        SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
+
+        if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
+            flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
+            flags |= fcntl.FD_CLOEXEC
+            fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
+
+
+class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
+
+    allow_reuse_address = True
+
+    def __init__(self, addr, certfile, keyfile,
+                 requestHandler=SSLRequestHandler,
+                 logRequests=False, encoding=None, bind_and_activate=True,
+                 address_family=socket.AF_INET):
+
+        self.logRequests = logRequests
+        StratumJSONRPCDispatcher.__init__(self, encoding)
+        # TCPServer.__init__ has an extra parameter on 2.6+, so
+        # check Python version and decide on how to call it
+        vi = sys.version_info
+        self.address_family = address_family
+        if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
+            # Unix sockets can't be bound if they already exist in the
+            # filesystem. The convention of e.g. X11 is to unlink
+            # before binding again.
+            if os.path.exists(addr): 
+                try:
+                    os.unlink(addr)
+                except OSError:
+                    logging.warning("Could not unlink socket %s", addr)
+
+        SSLTCPServer.__init__(self, addr, certfile, keyfile, requestHandler, bind_and_activate)
+
         if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
             flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
             flags |= fcntl.FD_CLOEXEC
@@ -268,20 +346,7 @@ class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
 
 
 
-    def create_session(self):
-        session_id = random_string(10)
-        session = HttpSession(session_id)
-        self.dispatcher.add_session(session)
-        return session_id
 
-    def poll_session(self, session):
-        q = session.pending_responses
-        responses = []
-        while not q.empty():
-            r = q.get()
-            responses.append(r)
-        #print "poll: %d responses"%len(responses)
-        return responses
 
 
 from processor import Session
@@ -306,25 +371,34 @@ class HttpSession(Session):
             return self._stopped
 
 class HttpServer(threading.Thread):
-    def __init__(self, dispatcher, host, port):
+    def __init__(self, dispatcher, host, port, use_ssl, certfile, keyfile):
         self.shared = dispatcher.shared
         self.dispatcher = dispatcher.request_dispatcher
         threading.Thread.__init__(self)
         self.daemon = True
         self.host = host
         self.port = port
+        self.use_ssl = use_ssl
+        self.certfile = certfile
+        self.keyfile = keyfile
         self.lock = threading.Lock()
 
+
     def run(self):
         # see http://code.google.com/p/jsonrpclib/
         from SocketServer import ThreadingMixIn
-        class StratumThreadedJSONRPCServer(ThreadingMixIn, StratumJSONRPCServer): pass
+        if self.use_ssl:
+            class StratumThreadedServer(ThreadingMixIn, StratumHTTPSSLServer): pass
+            self.server = StratumThreadedServer(( self.host, self.port), self.certfile, self.keyfile)
+            print "HTTPS server started."
+        else:
+            class StratumThreadedServer(ThreadingMixIn, StratumHTTPServer): pass
+            self.server = StratumThreadedServer(( self.host, self.port))
+            print "HTTP server started."
 
-        self.server = StratumThreadedJSONRPCServer(( self.host, self.port))
         self.server.dispatcher = self.dispatcher
         self.server.register_function(None, 'server.stop')
         self.server.register_function(None, 'server.info')
 
-        print "HTTP server started."
         self.server.serve_forever()