server-side support for ssl
authorThomasV <thomasv@gitorious>
Wed, 17 Oct 2012 09:40:06 +0000 (13:40 +0400)
committerThomasV <thomasv@gitorious>
Wed, 17 Oct 2012 09:40:06 +0000 (13:40 +0400)
backends/irc/__init__.py
server.py
transports/stratum_http.py
transports/stratum_tcp.py

index 0372002..5e0bc8c 100644 (file)
@@ -15,6 +15,8 @@ class IrcThread(threading.Thread):
         self.daemon = True
         self.stratum_tcp_port = config.get('server','stratum_tcp_port')
         self.stratum_http_port = config.get('server','stratum_http_port')
+        self.stratum_tcp_ssl_port = config.get('server','stratum_tcp_ssl_port')
+        self.stratum_http_ssl_port = config.get('server','stratum_http_ssl_port')
         self.peers = {}
         self.host = config.get('server','host')
         self.nick = config.get('server', 'irc_nick')
@@ -34,6 +36,10 @@ class IrcThread(threading.Thread):
             s += 't' + self.stratum_tcp_port + ' ' 
         if self.stratum_http_port:
             s += 'h' + self.stratum_http_port + ' '
+        if self.stratum_tcp_port:
+            s += 's' + self.stratum_tcp_ssl_port + ' ' 
+        if self.stratum_http_port:
+            s += 'g' + self.stratum_http_ssl_port + ' '
         return s
 
 
index bfbeecd..1e70984 100755 (executable)
--- a/server.py
+++ b/server.py
@@ -40,6 +40,8 @@ def create_config():
     config.set('server', 'host', 'localhost')
     config.set('server', 'stratum_tcp_port', '50001')
     config.set('server', 'stratum_http_port', '8081')
+    config.set('server', 'stratum_tcp_ssl_port', '50002')
+    config.set('server', 'stratum_http_ssl_port', '8082')
     config.set('server', 'password', '')
     config.set('server', 'irc', 'yes')
     config.set('server', 'irc_nick', '')
@@ -96,6 +98,10 @@ if __name__ == '__main__':
     host = config.get('server', 'host')
     stratum_tcp_port = config.get('server', 'stratum_tcp_port')
     stratum_http_port = config.get('server', 'stratum_http_port')
+    stratum_tcp_ssl_port = config.get('server', 'stratum_tcp_ssl_port')
+    stratum_http_ssl_port = config.get('server', 'stratum_http_ssl_port')
+    ssl_certfile = config.get('server', 'ssl_certfile')
+    ssl_keyfile = config.get('server', 'ssl_keyfile')
 
     if len(sys.argv) > 1:
         run_rpc_command(sys.argv[1], stratum_tcp_port)
@@ -129,12 +135,22 @@ if __name__ == '__main__':
     # Create various transports we need
     if stratum_tcp_port:
         from transports.stratum_tcp import TcpServer
-        tcp_server = TcpServer(dispatcher, host, int(stratum_tcp_port))
+        tcp_server = TcpServer(dispatcher, host, int(stratum_tcp_port), False, None, None)
+        transports.append(tcp_server)
+
+    if stratum_tcp_ssl_port:
+        from transports.stratum_tcp import TcpServer
+        tcp_server = TcpServer(dispatcher, host, int(stratum_tcp_ssl_port), True, ssl_certfile, ssl_keyfile)
         transports.append(tcp_server)
 
     if stratum_http_port:
         from transports.stratum_http import HttpServer
-        http_server = HttpServer(dispatcher, host, int(stratum_http_port))
+        http_server = HttpServer(dispatcher, host, int(stratum_http_port), False, None, None)
+        transports.append(http_server)
+
+    if stratum_http_ssl_port:
+        from transports.stratum_http import HttpServer
+        http_server = HttpServer(dispatcher, host, int(stratum_http_ssl_port), True, ssl_certfile, ssl_keyfile)
         transports.append(http_server)
 
     for server in transports:
index 49f3fd9..c4fe06a 100644 (file)
@@ -27,6 +27,8 @@ import types
 import traceback
 import sys, threading
 
+from OpenSSL import SSL
+
 try:
     import fcntl
 except ImportError:
@@ -131,6 +133,23 @@ class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
         return response
 
 
+    def create_session(self):
+        session_id = random_string(10)
+        session = HttpSession(session_id)
+        self.dispatcher.add_session(session)
+        return session_id
+
+    def poll_session(self, session):
+        q = session.pending_responses
+        responses = []
+        while not q.empty():
+            r = q.get()
+            responses.append(r)
+        #print "poll: %d responses"%len(responses)
+        return responses
+
+
+
 
 class StratumJSONRPCRequestHandler(
         SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
@@ -225,7 +244,29 @@ class StratumJSONRPCRequestHandler(
         self.connection.shutdown(1)
 
 
-class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
+
+
+class SSLTCPServer(SocketServer.TCPServer):
+
+    def __init__(self, server_address, certfile, keyfile, RequestHandlerClass, bind_and_activate=True):
+        SocketServer.BaseServer.__init__(self, server_address, RequestHandlerClass)
+        ctx = SSL.Context(SSL.SSLv3_METHOD)
+        self.certfile = certfile
+        self.keyfile = keyfile
+        #certfile = '/etc/ssl/certs/cacert.pem'
+        #keyfile = '/etc/ssl/private/cakey.pem'
+        ctx.use_privatekey_file(keyfile)
+        ctx.use_certificate_file(certfile)
+        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
+        if bind_and_activate:
+            self.server_bind()
+            self.server_activate()
+
+    def shutdown_request(self,request):
+        request.shutdown()
+
+
+class StratumHTTPServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
 
     allow_reuse_address = True
 
@@ -247,12 +288,42 @@ class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
                     os.unlink(addr)
                 except OSError:
                     logging.warning("Could not unlink socket %s", addr)
-        # if python 2.5 and lower
-        if vi[0] < 3 and vi[1] < 6:
-            SocketServer.TCPServer.__init__(self, addr, requestHandler)
-        else:
-            SocketServer.TCPServer.__init__(self, addr, requestHandler,
-                bind_and_activate)
+
+        SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
+
+        if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
+            flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
+            flags |= fcntl.FD_CLOEXEC
+            fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
+
+
+class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
+
+    allow_reuse_address = True
+
+    def __init__(self, addr, certfile, keyfile,
+                 requestHandler=StratumJSONRPCRequestHandler,
+                 logRequests=False, encoding=None, bind_and_activate=True,
+                 address_family=socket.AF_INET):
+
+        self.logRequests = logRequests
+        StratumJSONRPCDispatcher.__init__(self, encoding)
+        # TCPServer.__init__ has an extra parameter on 2.6+, so
+        # check Python version and decide on how to call it
+        vi = sys.version_info
+        self.address_family = address_family
+        if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
+            # Unix sockets can't be bound if they already exist in the
+            # filesystem. The convention of e.g. X11 is to unlink
+            # before binding again.
+            if os.path.exists(addr): 
+                try:
+                    os.unlink(addr)
+                except OSError:
+                    logging.warning("Could not unlink socket %s", addr)
+
+        SSLTCPServer.__init__(self, addr, certfile, keyfile, requestHandler, bind_and_activate)
+
         if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
             flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
             flags |= fcntl.FD_CLOEXEC
@@ -260,20 +331,7 @@ class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
 
 
 
-    def create_session(self):
-        session_id = random_string(10)
-        session = HttpSession(session_id)
-        self.dispatcher.add_session(session)
-        return session_id
 
-    def poll_session(self, session):
-        q = session.pending_responses
-        responses = []
-        while not q.empty():
-            r = q.get()
-            responses.append(r)
-        #print "poll: %d responses"%len(responses)
-        return responses
 
 
 from processor import Session
@@ -298,25 +356,34 @@ class HttpSession(Session):
             return self._stopped
 
 class HttpServer(threading.Thread):
-    def __init__(self, dispatcher, host, port):
+    def __init__(self, dispatcher, host, port, use_ssl, certfile, keyfile):
         self.shared = dispatcher.shared
         self.dispatcher = dispatcher.request_dispatcher
         threading.Thread.__init__(self)
         self.daemon = True
         self.host = host
         self.port = port
+        self.use_ssl = use_ssl
+        self.certfile = certfile
+        self.keyfile = keyfile
         self.lock = threading.Lock()
 
+
     def run(self):
         # see http://code.google.com/p/jsonrpclib/
         from SocketServer import ThreadingMixIn
-        class StratumThreadedJSONRPCServer(ThreadingMixIn, StratumJSONRPCServer): pass
+        if self.use_ssl:
+            class StratumThreadedServer(ThreadingMixIn, StratumHTTPSSLServer): pass
+            self.server = StratumThreadedServer(( self.host, self.port), self.certfile, self.keyfile)
+            print "HTTPS server started."
+        else:
+            class StratumThreadedServer(ThreadingMixIn, StratumHTTPServer): pass
+            self.server = StratumThreadedServer(( self.host, self.port))
+            print "HTTP server started."
 
-        self.server = StratumThreadedJSONRPCServer(( self.host, self.port))
         self.server.dispatcher = self.dispatcher
         self.server.register_function(None, 'server.stop')
         self.server.register_function(None, 'server.info')
 
-        print "HTTP server started."
         self.server.serve_forever()
 
index 4bcefab..9fb17f7 100644 (file)
@@ -8,9 +8,20 @@ from processor import Session, Dispatcher, timestr
 
 class TcpSession(Session):
 
-    def __init__(self, connection, address):
+    def __init__(self, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
         Session.__init__(self)
-        self._connection = connection
+        print connection, address, use_ssl
+        if use_ssl:
+            import ssl
+            self._connection = ssl.wrap_socket(
+                connection,
+                server_side=True,
+                certfile=ssl_certfile,
+                keyfile=ssl_keyfile,
+                ssl_version=ssl.PROTOCOL_TLSv1)
+        else:
+            self._connection = connection
+
         self.address = address[0]
         self.name = "TCP"
 
@@ -108,7 +119,7 @@ class TcpClientRequestor(threading.Thread):
 
 class TcpServer(threading.Thread):
 
-    def __init__(self, dispatcher, host, port):
+    def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
         self.shared = dispatcher.shared
         self.dispatcher = dispatcher.request_dispatcher
         threading.Thread.__init__(self)
@@ -116,15 +127,18 @@ class TcpServer(threading.Thread):
         self.host = host
         self.port = port
         self.lock = threading.Lock()
+        self.use_ssl = use_ssl
+        self.ssl_keyfile = ssl_keyfile
+        self.ssl_certfile = ssl_certfile
 
     def run(self):
-        print "TCP server started."
+        print "TCP server started.", self.use_ssl
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         sock.bind((self.host, self.port))
         sock.listen(1)
         while not self.shared.stopped():
-            session = TcpSession(*sock.accept())
+            session = TcpSession(*sock.accept(), use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
             self.dispatcher.add_session(session)
             self.dispatcher.collect_garbage()
             client_req = TcpClientRequestor(self.dispatcher, session)