use a dict for sessions, instead of a list. Deprecates #48
[electrum-server.git] / transports / stratum_http.py
index b9fe30c..2b86869 100644 (file)
 # You should have received a copy of the GNU Affero General Public
 # License along with this program.  If not, see
 # <http://www.gnu.org/licenses/agpl.html>.
-
-import jsonrpclib
-from jsonrpclib import Fault
-from jsonrpclib.jsonrpc import USE_UNIX_SOCKETS
-import SimpleXMLRPCServer
-import SocketServer
-import socket
-import logging
-import os
-import types
-import traceback
-import sys, threading
-
-try:
-    import fcntl
-except ImportError:
-    # For Windows
-    fcntl = None
-
-import json
-
-
 """
 sessions are identified with cookies
  - each session has a buffer of responses to requests
 
 
-from the processor point of view: 
+from the processor point of view:
  - the user only defines process() ; the rest is session management.  thus sessions should not belong to processor
 
 """
+import json
+import logging
+import os
+import Queue
+import SimpleXMLRPCServer
+import socket
+import SocketServer
+import sys
+import time
+import threading
+import traceback
+import types
 
+import jsonrpclib
+from jsonrpclib import Fault
+from jsonrpclib.jsonrpc import USE_UNIX_SOCKETS
+from OpenSSL import SSL
 
-def random_string(N):
-    import random, string
-    return ''.join(random.choice(string.ascii_uppercase + string.digits) for x in range(N))
+try:
+    import fcntl
+except ImportError:
+    # For Windows
+    fcntl = None
 
 
+from processor import Session
+from utils import random_string, print_log
 
 
 def get_version(request):
@@ -61,38 +59,31 @@ def get_version(request):
     if 'id' in request.keys():
         return 1.0
     return None
-    
+
+
 def validate_request(request):
-    if type(request) is not types.DictType:
-        fault = Fault(
-            -32600, 'Request must be {}, not %s.' % type(request)
-        )
-        return fault
+    if not isinstance(request, types.DictType):
+        return Fault(-32600, 'Request must be {}, not %s.' % type(request))
     rpcid = request.get('id', None)
     version = get_version(request)
     if not version:
-        fault = Fault(-32600, 'Request %s invalid.' % request, rpcid=rpcid)
-        return fault        
+        return Fault(-32600, 'Request %s invalid.' % request, rpcid=rpcid)
     request.setdefault('params', [])
     method = request.get('method', None)
     params = request.get('params')
     param_types = (types.ListType, types.DictType, types.TupleType)
-    if not method or type(method) not in types.StringTypes or \
-        type(params) not in param_types:
-        fault = Fault(
-            -32600, 'Invalid request parameters or method.', rpcid=rpcid
-        )
-        return fault
+    if not method or type(method) not in types.StringTypes or type(params) not in param_types:
+        return Fault(-32600, 'Invalid request parameters or method.', rpcid=rpcid)
     return True
 
+
 class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
 
     def __init__(self, encoding=None):
-        SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self,
-                                        allow_none=True,
-                                        encoding=encoding)
+        # todo: use super
+        SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, allow_none=True, encoding=encoding)
 
-    def _marshaled_dispatch(self, session_id, data, dispatch_method = None):
+    def _marshaled_dispatch(self, session_id, data, dispatch_method=None):
         response = None
         try:
             request = jsonrpclib.loads(data)
@@ -101,23 +92,30 @@ class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
             response = fault.response()
             return response
 
+        session = self.dispatcher.get_session_by_address(session_id)
+        if not session:
+            return 'Error: session not found'
+        session.time = time.time()
+
         responses = []
-        if type(request) is not types.ListType:
-            request = [ request ]
+        if not isinstance(request, types.ListType):
+            request = [request]
 
         for req_entry in request:
             result = validate_request(req_entry)
             if type(result) is Fault:
                 responses.append(result.response())
                 continue
-            resp_entry = self._marshaled_single_dispatch(session_id, req_entry)
-            if resp_entry is not None:
-                responses.append(resp_entry)
 
-        r = self.poll_session(session_id)
+            self.dispatcher.do_dispatch(session, req_entry)
+
+            if req_entry['method'] == 'server.stop':
+                return json.dumps({'result': 'ok'})
+
+        r = self.poll_session(session)
         for item in r:
             responses.append(json.dumps(item))
-            
+
         if len(responses) > 1:
             response = '[%s]' % ','.join(responses)
         elif len(responses) == 1:
@@ -127,69 +125,31 @@ class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
 
         return response
 
-    def _marshaled_single_dispatch(self, session_id, request):
-        # TODO - Use the multiprocessing and skip the response if
-        # it is a notification
-        # Put in support for custom dispatcher here
-        # (See SimpleXMLRPCServer._marshaled_dispatch)
-        method = request.get('method')
-        params = request.get('params')
-        try:
-            response = self._dispatch(method, session_id, request)
-        except:
-            exc_type, exc_value, exc_tb = sys.exc_info()
-            fault = Fault(-32603, '%s:%s' % (exc_type, exc_value))
-            return fault.response()
-        if 'id' not in request.keys() or request['id'] == None:
-            # It's a notification
-            return None
+    def create_session(self):
+        session_id = random_string(20)
+        session = HttpSession(self.dispatcher, session_id)
+        return session_id
 
-        try:
-            response = jsonrpclib.dumps(response,
-                                        methodresponse=True,
-                                        rpcid=request['id']
-                                        )
-            return response
-        except:
-            exc_type, exc_value, exc_tb = sys.exc_info()
-            fault = Fault(-32603, '%s:%s' % (exc_type, exc_value))
-            return fault.response()
+    def poll_session(self, session):
+        q = session.pending_responses
+        responses = []
+        while not q.empty():
+            r = q.get()
+            responses.append(r)
+        #print "poll: %d responses"%len(responses)
+        return responses
 
-    def _dispatch(self, method, session_id, request):
-        func = None
-        try:
-            func = self.funcs[method]
-        except KeyError:
-            if self.instance is not None:
-                if hasattr(self.instance, '_dispatch'):
-                    return self.instance._dispatch(method, params)
-                else:
-                    try:
-                        func = SimpleXMLRPCServer.resolve_dotted_attribute(
-                            self.instance,
-                            method,
-                            True
-                            )
-                    except AttributeError:
-                        pass
-        if func is not None:
-            try:
-                response = func(session_id, request)
-                return response
-            except TypeError:
-                return Fault(-32602, 'Invalid parameters.')
-            except:
-                err_lines = traceback.format_exc().splitlines()
-                trace_string = '%s | %s' % (err_lines[-3], err_lines[-1])
-                fault = jsonrpclib.Fault(-32603, 'Server error: %s' % 
-                                         trace_string)
-                return fault
-        else:
-            return Fault(-32601, 'Method %s not supported.' % method)
 
-class StratumJSONRPCRequestHandler(
-        SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
-    
+class StratumJSONRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
+
+    def do_OPTIONS(self):
+        self.send_response(200)
+        self.send_header('Allow', 'GET, POST, OPTIONS')
+        self.send_header('Access-Control-Allow-Origin', '*')
+        self.send_header('Access-Control-Allow-Headers', 'Cache-Control, Content-Language, Content-Type, Expires, Last-Modified, Pragma, Accept-Language, Accept, Origin')
+        self.send_header('Content-Length', '0')
+        self.end_headers()
+
     def do_GET(self):
         if not self.is_rpc_path_valid():
             self.report_404()
@@ -198,7 +158,7 @@ class StratumJSONRPCRequestHandler(
             session_id = None
             c = self.headers.get('cookie')
             if c:
-                if c[0:8]=='SESSION=':
+                if c[0:8] == 'SESSION=':
                     #print "found cookie", c[8:]
                     session_id = c[8:]
 
@@ -216,19 +176,19 @@ class StratumJSONRPCRequestHandler(
             fault = jsonrpclib.Fault(-32603, 'Server error: %s' % trace_string)
             response = fault.response()
             print "500", trace_string
-        if response == None:
+        if response is None:
             response = ''
 
         if session_id:
-            self.send_header("Set-Cookie", "SESSION=%s"%session_id)
+            self.send_header("Set-Cookie", "SESSION=%s" % session_id)
 
         self.send_header("Content-type", "application/json-rpc")
+        self.send_header("Access-Control-Allow-Origin", "*")
         self.send_header("Content-length", str(len(response)))
         self.end_headers()
         self.wfile.write(response)
         self.wfile.flush()
-        self.connection.shutdown(1)
-
+        self.shutdown_connection()
 
     def do_POST(self):
         if not self.is_rpc_path_valid():
@@ -247,13 +207,13 @@ class StratumJSONRPCRequestHandler(
             session_id = None
             c = self.headers.get('cookie')
             if c:
-                if c[0:8]=='SESSION=':
-                    print "found cookie", c[8:]
+                if c[0:8] == 'SESSION=':
+                    #print "found cookie", c[8:]
                     session_id = c[8:]
 
             if session_id is None:
                 session_id = self.server.create_session()
-                print "setting cookie", session_id
+                #print "setting cookie", session_id
 
             response = self.server._marshaled_dispatch(session_id, data)
             self.send_response(200)
@@ -264,21 +224,51 @@ class StratumJSONRPCRequestHandler(
             fault = jsonrpclib.Fault(-32603, 'Server error: %s' % trace_string)
             response = fault.response()
             print "500", trace_string
-        if response == None:
+        if response is None:
             response = ''
 
         if session_id:
-            self.send_header("Set-Cookie", "SESSION=%s"%session_id)
+            self.send_header("Set-Cookie", "SESSION=%s" % session_id)
 
         self.send_header("Content-type", "application/json-rpc")
+        self.send_header("Access-Control-Allow-Origin", "*")
         self.send_header("Content-length", str(len(response)))
         self.end_headers()
         self.wfile.write(response)
         self.wfile.flush()
+        self.shutdown_connection()
+
+    def shutdown_connection(self):
         self.connection.shutdown(1)
 
 
-class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
+class SSLRequestHandler(StratumJSONRPCRequestHandler):
+    def setup(self):
+        self.connection = self.request
+        self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
+        self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
+
+    def shutdown_connection(self):
+        self.connection.shutdown()
+
+
+class SSLTCPServer(SocketServer.TCPServer):
+    def __init__(self, server_address, certfile, keyfile, RequestHandlerClass, bind_and_activate=True):
+        SocketServer.BaseServer.__init__(self, server_address, RequestHandlerClass)
+        ctx = SSL.Context(SSL.SSLv23_METHOD)
+        ctx.use_privatekey_file(keyfile)
+        ctx.use_certificate_file(certfile)
+        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
+        if bind_and_activate:
+            self.server_bind()
+            self.server_activate()
+
+    def shutdown_request(self, request):
+        #request.shutdown()
+        pass
+
+
+class StratumHTTPServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
 
     allow_reuse_address = True
 
@@ -295,92 +285,98 @@ class StratumJSONRPCServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
             # Unix sockets can't be bound if they already exist in the
             # filesystem. The convention of e.g. X11 is to unlink
             # before binding again.
-            if os.path.exists(addr): 
+            if os.path.exists(addr):
                 try:
                     os.unlink(addr)
                 except OSError:
                     logging.warning("Could not unlink socket %s", addr)
-        # if python 2.5 and lower
-        if vi[0] < 3 and vi[1] < 6:
-            SocketServer.TCPServer.__init__(self, addr, requestHandler)
-        else:
-            SocketServer.TCPServer.__init__(self, addr, requestHandler,
-                bind_and_activate)
+
+        SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
+
         if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
             flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
             flags |= fcntl.FD_CLOEXEC
             fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
 
-        self.sessions = {}
 
+class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
 
+    allow_reuse_address = True
 
-    def create_session(self):
-        session_id = random_string(10)
-        self.sessions[session_id] = HttpSession(session_id)
-        return session_id
+    def __init__(self, addr, certfile, keyfile,
+                 requestHandler=SSLRequestHandler,
+                 logRequests=False, encoding=None, bind_and_activate=True,
+                 address_family=socket.AF_INET):
 
-    def poll_session(self,session_id):
-        responses = self.sessions[session_id].pending_responses[:]
-        self.sessions[session_id].pending_responses = []
-        print "poll: %d responses"%len(responses)
-        return responses
+        self.logRequests = logRequests
+        StratumJSONRPCDispatcher.__init__(self, encoding)
+        # TCPServer.__init__ has an extra parameter on 2.6+, so
+        # check Python version and decide on how to call it
+        vi = sys.version_info
+        self.address_family = address_family
+        if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
+            # Unix sockets can't be bound if they already exist in the
+            # filesystem. The convention of e.g. X11 is to unlink
+            # before binding again.
+            if os.path.exists(addr):
+                try:
+                    os.unlink(addr)
+                except OSError:
+                    logging.warning("Could not unlink socket %s", addr)
 
+        SSLTCPServer.__init__(self, addr, certfile, keyfile, requestHandler, bind_and_activate)
+
+        if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
+            flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
+            flags |= fcntl.FD_CLOEXEC
+            fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
 
-from processor import Session
 
 class HttpSession(Session):
 
-    def __init__(self, session_id):
-        Session.__init__(self)
-        self.pending_responses = []
-        print "new http session", session_id
+    def __init__(self, dispatcher, session_id):
+        Session.__init__(self, dispatcher)
+        self.pending_responses = Queue.Queue()
+        self.address = session_id
+        self.name = "HTTP"
+        self.timeout = 60
+        self.dispatcher.add_session(self)
 
     def send_response(self, response):
         raw_response = json.dumps(response)
-        self.pending_responses.append(response)
+        self.pending_responses.put(response)
+
+
 
 class HttpServer(threading.Thread):
-    def __init__(self, dispatcher, host, port, password):
+    def __init__(self, dispatcher, host, port, use_ssl, certfile, keyfile):
         self.shared = dispatcher.shared
         self.dispatcher = dispatcher.request_dispatcher
         threading.Thread.__init__(self)
         self.daemon = True
         self.host = host
         self.port = port
-        self.password = password
+        self.use_ssl = use_ssl
+        self.certfile = certfile
+        self.keyfile = keyfile
         self.lock = threading.Lock()
 
     def run(self):
         # see http://code.google.com/p/jsonrpclib/
         from SocketServer import ThreadingMixIn
-        class StratumThreadedJSONRPCServer(ThreadingMixIn, StratumJSONRPCServer): pass
-        self.server = StratumThreadedJSONRPCServer(( self.host, self.port))
-        for s in ['server.peers.subscribe', 'server.banner', 'blockchain.transaction.broadcast', \
-                      'blockchain.address.get_history','blockchain.address.subscribe', \
-                      'blockchain.numblocks.subscribe', 'client.version' ]:
-            self.server.register_function(self.process, s)
+        if self.use_ssl:
+            class StratumThreadedServer(ThreadingMixIn, StratumHTTPSSLServer):
+                pass
+            self.server = StratumThreadedServer((self.host, self.port), self.certfile, self.keyfile)
+            print_log("HTTPS server started.")
+        else:
+            class StratumThreadedServer(ThreadingMixIn, StratumHTTPServer):
+                pass
+            self.server = StratumThreadedServer((self.host, self.port))
+            print_log("HTTP server started.")
 
-        self.server.register_function(self.do_stop, 'stop')
+        self.server.dispatcher = self.dispatcher
+        self.server.register_function(None, 'server.stop')
+        self.server.register_function(None, 'server.info')
 
-        print "HTTP server started."
         self.server.serve_forever()
-
-
-    def process(self, session_id, request):
-        #print session, request
-        session = self.server.sessions.get(session_id)
-        if session:
-            self.dispatcher.process(session, request)
-
-    def do_stop(self, session, request):
-        try:
-            password = request['params'][0]
-        except:
-            password = None
-        if password == self.password:
-            self.shared.stop()
-            return 'ok'
-
-
-