From 521034aa46e65bbabbaed2cd89f582fb67a0ddd6 Mon Sep 17 00:00:00 2001 From: ThomasV Date: Thu, 17 Apr 2014 16:43:29 +0200 Subject: [PATCH] poller (rebased) --- transports/stratum_tcp.py | 302 ++++++++++++++++++++++++++++++--------------- 1 files changed, 203 insertions(+), 99 deletions(-) diff --git a/transports/stratum_tcp.py b/transports/stratum_tcp.py index cae116c..ec697ca 100644 --- a/transports/stratum_tcp.py +++ b/transports/stratum_tcp.py @@ -1,6 +1,7 @@ import json import Queue as queue import socket +import select import threading import time import traceback, sys @@ -9,11 +10,19 @@ from processor import Session, Dispatcher from utils import print_log +READ_ONLY = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR +READ_WRITE = READ_ONLY | select.POLLOUT +TIMEOUT = 100 + +import ssl + class TcpSession(Session): - def __init__(self, dispatcher, connection, address, use_ssl, ssl_certfile, ssl_keyfile): + def __init__(self, dispatcher, poller, connection, address, use_ssl, ssl_certfile, ssl_keyfile): Session.__init__(self, dispatcher) self.use_ssl = use_ssl + self.poller = poller + self.raw_connection = connection if use_ssl: import ssl self._connection = ssl.wrap_socket( @@ -29,12 +38,31 @@ class TcpSession(Session): self.address = address[0] + ":%d"%address[1] self.name = "TCP " if not use_ssl else "SSL " self.timeout = 1000 - self.response_queue = queue.Queue() self.dispatcher.add_session(self) + self.response_queue = queue.Queue() + self.message = '' + self.retry_msg = '' + self.handshake = not self.use_ssl + - def do_handshake(self): - if self.use_ssl: + def check_do_handshake(self): + if self.handshake: + return + try: self._connection.do_handshake() + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + return + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + self.poller.modify(self.raw_connection, READ_WRITE) + return + else: + raise + + self.poller.modify(self.raw_connection, READ_ONLY) + self.handshake = True + + def connection(self): if self.stopped(): @@ -53,82 +81,65 @@ class TcpSession(Session): self._connection.close() def send_response(self, response): - self.response_queue.put(response) + try: + msg = json.dumps(response) + '\n' + except: + traceback.print_exc(file=sys.stdout) + return + self.response_queue.put(msg) -class TcpClientResponder(threading.Thread): + try: + self.poller.modify(self.raw_connection, READ_WRITE) + except: + traceback.print_exc(file=sys.stdout) + return - def __init__(self, session): - self.session = session - threading.Thread.__init__(self) - def run(self): - while not self.session.stopped(): - try: - response = self.session.response_queue.get(timeout=10) - except queue.Empty: - continue - data = json.dumps(response) + "\n" - try: - while data: - l = self.session.connection().send(data) - data = data[l:] - except: - self.session.stop() + def parse_message(self): + message = self.message + self.time = time.time() -class TcpClientRequestor(threading.Thread): + raw_buffer = message.find('\n') + if raw_buffer == -1: + return False - def __init__(self, dispatcher, session): - self.shared = dispatcher.shared - self.dispatcher = dispatcher - self.message = "" - self.session = session - threading.Thread.__init__(self) + raw_command = message[0:raw_buffer].strip() + self.message = message[raw_buffer + 1:] + return raw_command - def run(self): - try: - self.session.do_handshake() - except: - self.session.stop() - return - while not self.shared.stopped(): - data = self.receive() - if not data: - self.session.stop() - break - self.message += data - self.session.time = time.time() - while self.parse(): - pass +class TcpServer(threading.Thread): + + def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile): + self.shared = dispatcher.shared + self.dispatcher = dispatcher.request_dispatcher + threading.Thread.__init__(self) + self.daemon = True + self.host = host + self.port = port + self.lock = threading.Lock() + self.use_ssl = use_ssl + self.ssl_keyfile = ssl_keyfile + self.ssl_certfile = ssl_certfile + + self.fd_to_session = {} + self.buffer_size = 4096 - def receive(self): - try: - return self.session.connection().recv(2048) - except: - return '' - def parse(self): - raw_buffer = self.message.find('\n') - if raw_buffer == -1: - return False - raw_command = self.message[0:raw_buffer].strip() - self.message = self.message[raw_buffer + 1:] - if raw_command == 'quit': - self.session.stop() - return False + def handle_command(self, raw_command, session): try: command = json.loads(raw_command) except: - self.dispatcher.push_response(self.session, {"error": "bad JSON", "request": raw_command}) + session.send_response({"error": "bad JSON"}) return True try: @@ -138,57 +149,150 @@ class TcpClientRequestor(threading.Thread): method = command['method'] except KeyError: # Return an error JSON in response. - self.dispatcher.push_response(self.session, {"error": "syntax error", "request": raw_command}) + session.send_response({"error": "syntax error", "request": raw_command}) else: - self.dispatcher.push_request(self.session, command) - # sleep a bit to prevent a single session from DOSing the queue - time.sleep(0.01) + self.dispatcher.push_request(session, command) + ## sleep a bit to prevent a single session from DOSing the queue + #time.sleep(0.01) - return True -class TcpServer(threading.Thread): - def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile): - self.shared = dispatcher.shared - self.dispatcher = dispatcher.request_dispatcher - threading.Thread.__init__(self) - self.daemon = True - self.host = host - self.port = port - self.lock = threading.Lock() - self.use_ssl = use_ssl - self.ssl_keyfile = ssl_keyfile - self.ssl_certfile = ssl_certfile def run(self): + print_log( ("SSL" if self.use_ssl else "TCP") + " server started on port %d"%self.port) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((self.host, self.port)) - sock.listen(5) + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setblocking(0) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind((self.host, self.port)) + server.listen(5) + server_fd = server.fileno() - while not self.shared.stopped(): + poller = select.poll() + poller.register(server) - #if self.use_ssl: print_log("SSL: socket listening") + def stop_session(fd): try: - connection, address = sock.accept() + # unregister before we close s + poller.unregister(fd) except: + print_log("unregister error", fd) traceback.print_exc(file=sys.stdout) - time.sleep(0.1) - continue - #if self.use_ssl: print_log("SSL: new session", address) - try: - session = TcpSession(self.dispatcher, connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile) - except BaseException, e: - error = str(e) - print_log("cannot start TCP session", error, address) - connection.close() - time.sleep(0.1) - continue - - client_req = TcpClientRequestor(self.dispatcher, session) - client_req.start() - responder = TcpClientResponder(session) - responder.start() + session = self.fd_to_session.pop(fd) + # this will close the socket + session.stop() + + + redo = [] + + while not self.shared.stopped(): + if redo: + events = redo + redo = [] + else: + events = poller.poll(TIMEOUT) + + for fd, flag in events: + + if fd != server_fd: + session = self.fd_to_session[fd] + s = session._connection + try: + session.check_do_handshake() + except: + stop_session(fd) + continue + + # handle inputs + if flag & (select.POLLIN | select.POLLPRI): + + if fd == server_fd: + connection, address = server.accept() + try: + session = TcpSession(self.dispatcher, poller, connection, address, + use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile) + except BaseException as e: + print_log("cannot start TCP session", str(e), address) + connection.close() + continue + + connection = session._connection + connection.setblocking(0) + self.fd_to_session[ connection.fileno() ] = session + poller.register(connection, READ_ONLY) + try: + session.check_do_handshake() + except: + print_log( "handshake failure", address ) + stop_session(connection.fileno()) + + continue + + try: + data = s.recv(self.buffer_size) + except ssl.SSLError as x: + if x.args[0] == ssl.SSL_ERROR_WANT_READ: + # print_log("error want read", x, fd) + continue + else: + raise x + except socket.error as x: + # print_log("recv err", x) + stop_session(fd) + continue + + if data: + if len(data) == self.buffer_size: + redo.append( (fd, flag) ) + + session.message += data + while True: + cmd = session.parse_message() + if not cmd: + break + if cmd == 'quit': + data = False + break + self.handle_command(cmd, session) + + if not data: + stop_session(fd) + continue + + elif flag & select.POLLHUP: + print_log('client hung up', address) + stop_session(fd) + + + elif flag & select.POLLOUT: + # Socket is ready to send data, if there is any to send. + if session.retry_msg: + next_msg = session.retry_msg + else: + try: + next_msg = session.response_queue.get_nowait() + except queue.Empty: + # No messages waiting so stop checking for writability. + poller.modify(s, READ_ONLY) + continue + + try: + sent = s.send(next_msg) + except socket.error as x: + # print_log("recv err", x) + stop_session(fd) + continue + + session.retry_msg = next_msg[sent:] + + + elif flag & select.POLLERR: + print_log('handling exceptional condition for', session.address) + stop_session(fd) + + + elif flag & select.POLLNVAL: + print_log('invalid request', session.address) + stop_session(fd) -- 1.7.1