poller (rebased)
authorThomasV <thomasv1@gmx.de>
Thu, 17 Apr 2014 14:43:29 +0000 (16:43 +0200)
committerThomasV <thomasv1@gmx.de>
Thu, 17 Apr 2014 14:43:29 +0000 (16:43 +0200)
transports/stratum_tcp.py

index cae116c..ec697ca 100644 (file)
@@ -1,6 +1,7 @@
 import json
 import Queue as queue
 import socket
+import select
 import threading
 import time
 import traceback, sys
@@ -9,11 +10,19 @@ from processor import Session, Dispatcher
 from utils import print_log
 
 
+READ_ONLY = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR
+READ_WRITE = READ_ONLY | select.POLLOUT
+TIMEOUT = 100
+
+import ssl
+
 class TcpSession(Session):
 
-    def __init__(self, dispatcher, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
+    def __init__(self, dispatcher, poller, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
         Session.__init__(self, dispatcher)
         self.use_ssl = use_ssl
+        self.poller = poller
+        self.raw_connection = connection
         if use_ssl:
             import ssl
             self._connection = ssl.wrap_socket(
@@ -29,12 +38,31 @@ class TcpSession(Session):
         self.address = address[0] + ":%d"%address[1]
         self.name = "TCP " if not use_ssl else "SSL "
         self.timeout = 1000
-        self.response_queue = queue.Queue()
         self.dispatcher.add_session(self)
+        self.response_queue = queue.Queue()
+        self.message = ''
+        self.retry_msg = ''
+        self.handshake = not self.use_ssl
+
 
-    def do_handshake(self):
-        if self.use_ssl:
+    def check_do_handshake(self):
+        if self.handshake:
+            return
+        try:
             self._connection.do_handshake()
+        except ssl.SSLError as err:
+            if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+                return
+            elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+                self.poller.modify(self.raw_connection, READ_WRITE)
+                return
+            else:
+                raise
+
+        self.poller.modify(self.raw_connection, READ_ONLY)
+        self.handshake = True
+
+
 
     def connection(self):
         if self.stopped():
@@ -53,82 +81,65 @@ class TcpSession(Session):
         self._connection.close()
 
     def send_response(self, response):
-        self.response_queue.put(response)
+        try:
+            msg = json.dumps(response) + '\n'
+        except:
+            traceback.print_exc(file=sys.stdout)
+            return
 
+        self.response_queue.put(msg)
 
-class TcpClientResponder(threading.Thread):
+        try:
+            self.poller.modify(self.raw_connection, READ_WRITE)
+        except:
+            traceback.print_exc(file=sys.stdout)
+            return
 
-    def __init__(self, session):
-        self.session = session
-        threading.Thread.__init__(self)
 
-    def run(self):
-        while not self.session.stopped():
-            try:
-                response = self.session.response_queue.get(timeout=10)
-            except queue.Empty:
-                continue
-            data = json.dumps(response) + "\n"
-            try:
-                while data:
-                    l = self.session.connection().send(data)
-                    data = data[l:]
-            except:
-                self.session.stop()
 
+    def parse_message(self):
 
+        message = self.message
+        self.time = time.time()
 
-class TcpClientRequestor(threading.Thread):
+        raw_buffer = message.find('\n')
+        if raw_buffer == -1:
+            return False
 
-    def __init__(self, dispatcher, session):
-        self.shared = dispatcher.shared
-        self.dispatcher = dispatcher
-        self.message = ""
-        self.session = session
-        threading.Thread.__init__(self)
+        raw_command = message[0:raw_buffer].strip()
+        self.message = message[raw_buffer + 1:]
+        return raw_command
 
-    def run(self):
-        try:
-            self.session.do_handshake()
-        except:
-            self.session.stop()
-            return
 
-        while not self.shared.stopped():
 
-            data = self.receive()
-            if not data:
-                self.session.stop()
-                break
 
-            self.message += data
-            self.session.time = time.time()
 
-            while self.parse():
-                pass
+class TcpServer(threading.Thread):
+
+    def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
+        self.shared = dispatcher.shared
+        self.dispatcher = dispatcher.request_dispatcher
+        threading.Thread.__init__(self)
+        self.daemon = True
+        self.host = host
+        self.port = port
+        self.lock = threading.Lock()
+        self.use_ssl = use_ssl
+        self.ssl_keyfile = ssl_keyfile
+        self.ssl_certfile = ssl_certfile
+
+        self.fd_to_session = {}
+        self.buffer_size = 4096
 
 
-    def receive(self):
-        try:
-            return self.session.connection().recv(2048)
-        except:
-            return ''
 
-    def parse(self):
-        raw_buffer = self.message.find('\n')
-        if raw_buffer == -1:
-            return False
 
-        raw_command = self.message[0:raw_buffer].strip()
-        self.message = self.message[raw_buffer + 1:]
-        if raw_command == 'quit':
-            self.session.stop()
-            return False
 
+    def handle_command(self, raw_command, session):
         try:
             command = json.loads(raw_command)
         except:
-            self.dispatcher.push_response(self.session, {"error": "bad JSON", "request": raw_command})
+            session.send_response({"error": "bad JSON"})
             return True
 
         try:
@@ -138,57 +149,150 @@ class TcpClientRequestor(threading.Thread):
             method = command['method']
         except KeyError:
             # Return an error JSON in response.
-            self.dispatcher.push_response(self.session, {"error": "syntax error", "request": raw_command})
+            session.send_response({"error": "syntax error", "request": raw_command})
         else:
-            self.dispatcher.push_request(self.session, command)
-            # sleep a bit to prevent a single session from DOSing the queue
-            time.sleep(0.01)
+            self.dispatcher.push_request(session, command)
+            ## sleep a bit to prevent a single session from DOSing the queue
+            #time.sleep(0.01)
 
-        return True
 
 
-class TcpServer(threading.Thread):
 
-    def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
-        self.shared = dispatcher.shared
-        self.dispatcher = dispatcher.request_dispatcher
-        threading.Thread.__init__(self)
-        self.daemon = True
-        self.host = host
-        self.port = port
-        self.lock = threading.Lock()
-        self.use_ssl = use_ssl
-        self.ssl_keyfile = ssl_keyfile
-        self.ssl_certfile = ssl_certfile
 
     def run(self):
+
         print_log( ("SSL" if self.use_ssl else "TCP") + " server started on port %d"%self.port)
-        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        sock.bind((self.host, self.port))
-        sock.listen(5)
+        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        server.setblocking(0)
+        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        server.bind((self.host, self.port))
+        server.listen(5)
+        server_fd = server.fileno()
 
-        while not self.shared.stopped():
+        poller = select.poll()
+        poller.register(server)
 
-            #if self.use_ssl: print_log("SSL: socket listening")
+        def stop_session(fd):
             try:
-                connection, address = sock.accept()
+                # unregister before we close s 
+                poller.unregister(fd)
             except:
+                print_log("unregister error", fd)
                 traceback.print_exc(file=sys.stdout)
-                time.sleep(0.1)
-                continue
 
-            #if self.use_ssl: print_log("SSL: new session", address)
-            try:
-                session = TcpSession(self.dispatcher, connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
-            except BaseException, e:
-                error = str(e)
-                print_log("cannot start TCP session", error, address)
-                connection.close()
-                time.sleep(0.1)
-                continue
-
-            client_req = TcpClientRequestor(self.dispatcher, session)
-            client_req.start()
-            responder = TcpClientResponder(session)
-            responder.start()
+            session = self.fd_to_session.pop(fd)
+            # this will close the socket
+            session.stop()
+
+
+        redo = []
+
+        while not self.shared.stopped():
+            if redo:
+                events = redo
+                redo = []
+            else:
+                events = poller.poll(TIMEOUT)
+
+            for fd, flag in events:
+
+                if fd != server_fd:
+                    session = self.fd_to_session[fd]
+                    s = session._connection
+                    try:
+                        session.check_do_handshake()
+                    except:
+                        stop_session(fd)
+                        continue
+
+                # handle inputs
+                if flag & (select.POLLIN | select.POLLPRI):
+
+                    if fd == server_fd:
+                        connection, address = server.accept()
+                        try:
+                            session = TcpSession(self.dispatcher, poller, connection, address, 
+                                                 use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
+                        except BaseException as e:
+                            print_log("cannot start TCP session", str(e), address)
+                            connection.close()
+                            continue
+
+                        connection = session._connection
+                        connection.setblocking(0)
+                        self.fd_to_session[ connection.fileno() ] = session
+                        poller.register(connection, READ_ONLY)
+                        try:
+                            session.check_do_handshake()
+                        except:
+                            print_log( "handshake failure", address )
+                            stop_session(connection.fileno())
+
+                        continue
+
+                    try:
+                        data = s.recv(self.buffer_size)
+                    except ssl.SSLError as x:
+                        if x.args[0] == ssl.SSL_ERROR_WANT_READ: 
+                            # print_log("error want read", x, fd)
+                            continue 
+                        else: 
+                            raise x
+                    except socket.error as x:
+                        # print_log("recv err", x)
+                        stop_session(fd)
+                        continue
+
+                    if data:
+                        if len(data) == self.buffer_size:
+                            redo.append( (fd, flag) )
+
+                        session.message += data
+                        while True:
+                            cmd = session.parse_message()
+                            if not cmd: 
+                                break
+                            if cmd == 'quit':
+                                data = False
+                                break
+                            self.handle_command(cmd, session)
+
+                    if not data:
+                        stop_session(fd)
+                        continue
+
+                elif flag & select.POLLHUP:
+                    print_log('client hung up', address)
+                    stop_session(fd)
+
+
+                elif flag & select.POLLOUT:
+                    # Socket is ready to send data, if there is any to send.
+                    if session.retry_msg:
+                        next_msg = session.retry_msg
+                    else:
+                        try:
+                            next_msg = session.response_queue.get_nowait()
+                        except queue.Empty:
+                            # No messages waiting so stop checking for writability.
+                            poller.modify(s, READ_ONLY)
+                            continue
+
+                    try:
+                        sent = s.send(next_msg)
+                    except socket.error as x:
+                        # print_log("recv err", x)
+                        stop_session(fd)
+                        continue
+                        
+                    session.retry_msg = next_msg[sent:]
+
+
+                elif flag & select.POLLERR:
+                    print_log('handling exceptional condition for', session.address)
+                    stop_session(fd)
+
+
+                elif flag & select.POLLNVAL:
+                    print_log('invalid request', session.address)
+                    stop_session(fd)