import json
import Queue as queue
import socket
+import select
import threading
import time
import traceback, sys
from utils import print_log
+READ_ONLY = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR
+READ_WRITE = READ_ONLY | select.POLLOUT
+TIMEOUT = 100
+
+import ssl
+
class TcpSession(Session):
- def __init__(self, dispatcher, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
+ def __init__(self, dispatcher, poller, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
Session.__init__(self, dispatcher)
self.use_ssl = use_ssl
+ self.poller = poller
+ self.raw_connection = connection
if use_ssl:
import ssl
self._connection = ssl.wrap_socket(
self.address = address[0] + ":%d"%address[1]
self.name = "TCP " if not use_ssl else "SSL "
self.timeout = 1000
- self.response_queue = queue.Queue()
self.dispatcher.add_session(self)
+ self.response_queue = queue.Queue()
+ self.message = ''
+ self.retry_msg = ''
+ self.handshake = not self.use_ssl
+
- def do_handshake(self):
- if self.use_ssl:
+ def check_do_handshake(self):
+ if self.handshake:
+ return
+ try:
self._connection.do_handshake()
+ except ssl.SSLError as err:
+ if err.args[0] == ssl.SSL_ERROR_WANT_READ:
+ return
+ elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
+ self.poller.modify(self.raw_connection, READ_WRITE)
+ return
+ else:
+ raise
+
+ self.poller.modify(self.raw_connection, READ_ONLY)
+ self.handshake = True
+
+
def connection(self):
if self.stopped():
self._connection.close()
def send_response(self, response):
- self.response_queue.put(response)
+ try:
+ msg = json.dumps(response) + '\n'
+ except:
+ traceback.print_exc(file=sys.stdout)
+ return
+ self.response_queue.put(msg)
-class TcpClientResponder(threading.Thread):
+ try:
+ self.poller.modify(self.raw_connection, READ_WRITE)
+ except:
+ traceback.print_exc(file=sys.stdout)
+ return
- def __init__(self, session):
- self.session = session
- threading.Thread.__init__(self)
- def run(self):
- while not self.session.stopped():
- try:
- response = self.session.response_queue.get(timeout=10)
- except queue.Empty:
- continue
- data = json.dumps(response) + "\n"
- try:
- while data:
- l = self.session.connection().send(data)
- data = data[l:]
- except:
- self.session.stop()
+ def parse_message(self):
+ message = self.message
+ self.time = time.time()
-class TcpClientRequestor(threading.Thread):
+ raw_buffer = message.find('\n')
+ if raw_buffer == -1:
+ return False
- def __init__(self, dispatcher, session):
- self.shared = dispatcher.shared
- self.dispatcher = dispatcher
- self.message = ""
- self.session = session
- threading.Thread.__init__(self)
+ raw_command = message[0:raw_buffer].strip()
+ self.message = message[raw_buffer + 1:]
+ return raw_command
- def run(self):
- try:
- self.session.do_handshake()
- except:
- self.session.stop()
- return
- while not self.shared.stopped():
- data = self.receive()
- if not data:
- self.session.stop()
- break
- self.message += data
- self.session.time = time.time()
- while self.parse():
- pass
+class TcpServer(threading.Thread):
+
+ def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
+ self.shared = dispatcher.shared
+ self.dispatcher = dispatcher.request_dispatcher
+ threading.Thread.__init__(self)
+ self.daemon = True
+ self.host = host
+ self.port = port
+ self.lock = threading.Lock()
+ self.use_ssl = use_ssl
+ self.ssl_keyfile = ssl_keyfile
+ self.ssl_certfile = ssl_certfile
+
+ self.fd_to_session = {}
+ self.buffer_size = 4096
- def receive(self):
- try:
- return self.session.connection().recv(2048)
- except:
- return ''
- def parse(self):
- raw_buffer = self.message.find('\n')
- if raw_buffer == -1:
- return False
- raw_command = self.message[0:raw_buffer].strip()
- self.message = self.message[raw_buffer + 1:]
- if raw_command == 'quit':
- self.session.stop()
- return False
+ def handle_command(self, raw_command, session):
try:
command = json.loads(raw_command)
except:
- self.dispatcher.push_response(self.session, {"error": "bad JSON", "request": raw_command})
+ session.send_response({"error": "bad JSON"})
return True
try:
method = command['method']
except KeyError:
# Return an error JSON in response.
- self.dispatcher.push_response(self.session, {"error": "syntax error", "request": raw_command})
+ session.send_response({"error": "syntax error", "request": raw_command})
else:
- self.dispatcher.push_request(self.session, command)
- # sleep a bit to prevent a single session from DOSing the queue
- time.sleep(0.01)
+ self.dispatcher.push_request(session, command)
+ ## sleep a bit to prevent a single session from DOSing the queue
+ #time.sleep(0.01)
- return True
-class TcpServer(threading.Thread):
- def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
- self.shared = dispatcher.shared
- self.dispatcher = dispatcher.request_dispatcher
- threading.Thread.__init__(self)
- self.daemon = True
- self.host = host
- self.port = port
- self.lock = threading.Lock()
- self.use_ssl = use_ssl
- self.ssl_keyfile = ssl_keyfile
- self.ssl_certfile = ssl_certfile
def run(self):
+
print_log( ("SSL" if self.use_ssl else "TCP") + " server started on port %d"%self.port)
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- sock.bind((self.host, self.port))
- sock.listen(5)
+ server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ server.setblocking(0)
+ server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ server.bind((self.host, self.port))
+ server.listen(5)
+ server_fd = server.fileno()
- while not self.shared.stopped():
+ poller = select.poll()
+ poller.register(server)
- #if self.use_ssl: print_log("SSL: socket listening")
+ def stop_session(fd):
try:
- connection, address = sock.accept()
+ # unregister before we close s
+ poller.unregister(fd)
except:
+ print_log("unregister error", fd)
traceback.print_exc(file=sys.stdout)
- time.sleep(0.1)
- continue
- #if self.use_ssl: print_log("SSL: new session", address)
- try:
- session = TcpSession(self.dispatcher, connection, address, use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
- except BaseException, e:
- error = str(e)
- print_log("cannot start TCP session", error, address)
- connection.close()
- time.sleep(0.1)
- continue
-
- client_req = TcpClientRequestor(self.dispatcher, session)
- client_req.start()
- responder = TcpClientResponder(session)
- responder.start()
+ session = self.fd_to_session.pop(fd)
+ # this will close the socket
+ session.stop()
+
+
+ redo = []
+
+ while not self.shared.stopped():
+ if redo:
+ events = redo
+ redo = []
+ else:
+ events = poller.poll(TIMEOUT)
+
+ for fd, flag in events:
+
+ if fd != server_fd:
+ session = self.fd_to_session[fd]
+ s = session._connection
+ try:
+ session.check_do_handshake()
+ except:
+ stop_session(fd)
+ continue
+
+ # handle inputs
+ if flag & (select.POLLIN | select.POLLPRI):
+
+ if fd == server_fd:
+ connection, address = server.accept()
+ try:
+ session = TcpSession(self.dispatcher, poller, connection, address,
+ use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
+ except BaseException as e:
+ print_log("cannot start TCP session", str(e), address)
+ connection.close()
+ continue
+
+ connection = session._connection
+ connection.setblocking(0)
+ self.fd_to_session[ connection.fileno() ] = session
+ poller.register(connection, READ_ONLY)
+ try:
+ session.check_do_handshake()
+ except:
+ print_log( "handshake failure", address )
+ stop_session(connection.fileno())
+
+ continue
+
+ try:
+ data = s.recv(self.buffer_size)
+ except ssl.SSLError as x:
+ if x.args[0] == ssl.SSL_ERROR_WANT_READ:
+ # print_log("error want read", x, fd)
+ continue
+ else:
+ raise x
+ except socket.error as x:
+ # print_log("recv err", x)
+ stop_session(fd)
+ continue
+
+ if data:
+ if len(data) == self.buffer_size:
+ redo.append( (fd, flag) )
+
+ session.message += data
+ while True:
+ cmd = session.parse_message()
+ if not cmd:
+ break
+ if cmd == 'quit':
+ data = False
+ break
+ self.handle_command(cmd, session)
+
+ if not data:
+ stop_session(fd)
+ continue
+
+ elif flag & select.POLLHUP:
+ print_log('client hung up', address)
+ stop_session(fd)
+
+
+ elif flag & select.POLLOUT:
+ # Socket is ready to send data, if there is any to send.
+ if session.retry_msg:
+ next_msg = session.retry_msg
+ else:
+ try:
+ next_msg = session.response_queue.get_nowait()
+ except queue.Empty:
+ # No messages waiting so stop checking for writability.
+ poller.modify(s, READ_ONLY)
+ continue
+
+ try:
+ sent = s.send(next_msg)
+ except socket.error as x:
+ # print_log("recv err", x)
+ stop_session(fd)
+ continue
+
+ session.retry_msg = next_msg[sent:]
+
+
+ elif flag & select.POLLERR:
+ print_log('handling exceptional condition for', session.address)
+ stop_session(fd)
+
+
+ elif flag & select.POLLNVAL:
+ print_log('invalid request', session.address)
+ stop_session(fd)