9 from processor import Session, Dispatcher
10 from utils import print_log
# Event masks passed to poller.register()/poller.modify().
# READ_ONLY: wake on incoming data (POLLIN/POLLPRI) or on hang-up/error.
# READ_WRITE: the same, plus writability (POLLOUT) -- used while a session
# has queued output or an SSL handshake that needs to write.
13 READ_ONLY = select.POLLIN | select.POLLPRI | select.POLLHUP | select.POLLERR
14 READ_WRITE = READ_ONLY | select.POLLOUT
# One client connection served by TcpServer: holds the accepted socket
# (optionally SSL-wrapped), registers itself with the dispatcher, buffers
# incoming text in self.message and queues outgoing JSON lines in
# self.response_queue for the server's poll loop to write.
# NOTE(review): this excerpt omits many lines of the class (try/except,
# if/else and def headers); comments below only describe what is shown.
19 class TcpSession(Session):
# `connection` and `address` are what server.accept() returned. With use_ssl
# the socket is wrapped using ssl_certfile (and presumably ssl_keyfile on an
# omitted line), deferring the handshake to check_do_handshake().
21 def __init__(self, dispatcher, poller, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
22 Session.__init__(self, dispatcher)
23 self.use_ssl = use_ssl
# NOTE(review): self.poller is used by the methods below but is not assigned
# on any line shown here -- presumably set from `poller` on an omitted line.
# The unwrapped socket is kept because poller.modify() is called on it.
25 self.raw_connection = connection
# NOTE(review): ssl.wrap_socket() is deprecated since Python 3.7 and removed
# in 3.12, and PROTOCOL_SSLv23 is a deprecated alias of PROTOCOL_TLS.
# Consider ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + load_cert_chain()
# + context.wrap_socket(..., server_side=True, do_handshake_on_connect=False).
28 self._connection = ssl.wrap_socket(
31 certfile=ssl_certfile,
33 ssl_version=ssl.PROTOCOL_SSLv23,
# handshake is driven later by check_do_handshake(), not here
34 do_handshake_on_connect=False)
# presumably the non-SSL branch (the if/else lines are omitted)
36 self._connection = connection
# "host:port" label used in log messages
38 self.address = address[0] + ":%d"%address[1]
39 self.name = "TCP " if not use_ssl else "SSL "
41 self.dispatcher.add_session(self)
# outgoing messages; the server loop drains it with get_nowait() on POLLOUT
42 self.response_queue = queue.Queue()
# starts True for plain TCP (no handshake needed), False for SSL sessions
45 self.handshake = not self.use_ssl
# Advance the non-blocking SSL handshake. SSL_ERROR_WANT_READ /
# SSL_ERROR_WANT_WRITE mean it must be retried on a later poll event;
# WANT_WRITE switches the poll mask to READ_WRITE so the loop wakes up
# when the socket becomes writable.
# NOTE(review): on Python 3.3+, catching ssl.SSLWantReadError /
# ssl.SSLWantWriteError is the clearer way to express these two cases.
48 def check_do_handshake(self):
52 self._connection.do_handshake()
53 except ssl.SSLError as err:
54 if err.args[0] == ssl.SSL_ERROR_WANT_READ:
56 elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
57 self.poller.modify(self.raw_connection, READ_WRITE)
# presumably reached once the handshake no longer needs to write (the
# surrounding lines are omitted) -- drop back to the read-only mask
62 self.poller.modify(self.raw_connection, READ_ONLY)
# Accessor for the live socket (its def line and guard are omitted): raises
# under some condition -- presumably when the session is stopped -- and
# otherwise returns self._connection.
69 raise Exception("Session was stopped")
71 return self._connection
# Shut down and close the socket (enclosing def omitted). Errors from
# shutdown() appear to be swallowed: their logging is commented out.
75 self._connection.shutdown(socket.SHUT_RDWR)
77 # print_log("problem shutting down", self.address)
78 # traceback.print_exc(file=sys.stdout)
81 self._connection.close()
# Serialize `response` as one newline-terminated JSON line, queue it, and
# switch the poll mask to READ_WRITE so the server loop sends it. Failures
# are printed with a traceback (the try/except lines are omitted).
# NOTE(review): msg is a str; if this runs on Python 3 (the `queue` module
# name suggests so) socket.send() needs bytes -- encode before sending.
83 def send_response(self, response):
85 msg = json.dumps(response) + '\n'
87 traceback.print_exc(file=sys.stdout)
90 self.response_queue.put(msg)
93 self.poller.modify(self.raw_connection, READ_WRITE)
95 traceback.print_exc(file=sys.stdout)
# Split the first newline-terminated command off the receive buffer
# self.message, keep the remainder for later, and record the call time in
# self.time. The "no newline yet" handling and the return are omitted.
100 def parse_message(self):
102 message = self.message
103 self.time = time.time()
# despite its name, raw_buffer is the index of the first '\n' (-1 if none)
105 raw_buffer = message.find('\n')
109 raw_command = message[0:raw_buffer].strip()
110 self.message = message[raw_buffer + 1:]
# Server thread: accepts TCP (or SSL) clients, wraps each in a TcpSession and
# multiplexes the listening socket and all client sockets with select.poll().
117 class TcpServer(threading.Thread):
# `dispatcher` is a wrapper: .shared provides the stopped() flag polled by
# the main loop, and .request_dispatcher receives parsed requests via
# push_request().
# NOTE(review): self.host / self.port are used by the loop but assigned on
# lines omitted from this excerpt.
119 def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
120 self.shared = dispatcher.shared
121 self.dispatcher = dispatcher.request_dispatcher
122 threading.Thread.__init__(self)
# NOTE(review): self.lock is not used anywhere in the visible excerpt
126 self.lock = threading.Lock()
127 self.use_ssl = use_ssl
128 self.ssl_keyfile = ssl_keyfile
129 self.ssl_certfile = ssl_certfile
# socket file descriptor -> TcpSession
131 self.fd_to_session = {}
# maximum number of bytes read per recv() call
132 self.buffer_size = 4096
# Parse one raw JSON command line from `session` and hand it to the request
# dispatcher. Replies {"error": "bad JSON"} when parsing fails, and
# {"error": "syntax error", "request": ...} presumably when 'id' or 'method'
# is missing (the try/except lines are omitted from this excerpt).
# NOTE(review): if the JSON is valid but not an object (e.g. a list or a
# number), command['id'] raises TypeError/IndexError rather than KeyError --
# confirm the omitted except clause catches those too.
138 def handle_command(self, raw_command, session):
140 command = json.loads(raw_command)
142 session.send_response({"error": "bad JSON"})
146 # Try to load vital fields, and return an error if
# message_id / method are only read here to check that the fields exist
148 message_id = command['id']
149 method = command['method']
151 # Return an error JSON in response.
152 session.send_response({"error": "syntax error", "request": raw_command})
154 self.dispatcher.push_request(session, command)
155 ## sleep a bit to prevent a single session from DOSing the queue
# Main server loop -- presumably TcpServer.run(); its def line is omitted and
# it may continue past this excerpt. Creates a non-blocking listening socket,
# then polls it together with every client socket until
# self.shared.stopped() is true.
164 print_log( ("SSL" if self.use_ssl else "TCP") + " server started on port %d"%self.port)
165 server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
166 server.setblocking(0)
167 server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
168 server.bind((self.host, self.port))
# NOTE(review): no server.listen() call is visible; accept() needs one --
# presumably it is on an omitted line, confirm.
170 server_fd = server.fileno()
172 poller = select.poll()
173 poller.register(server)
# Unregister `fd` from the poller (logging any error), drop its session
# from fd_to_session and, on an omitted line, close that session's socket.
175 def stop_session(fd):
177 # unregister before we close s
178 poller.unregister(fd)
180 print_log("unregister error", fd)
181 traceback.print_exc(file=sys.stdout)
183 session = self.fd_to_session.pop(fd)
184 # this will close the socket
190 while not self.shared.stopped():
# TIMEOUT is defined outside this excerpt
195 events = poller.poll(TIMEOUT)
197 for fd, flag in events:
# existing client socket (the guard distinguishing it from server_fd is
# omitted); continue a pending SSL handshake before any other I/O
200 session = self.fd_to_session[fd]
201 s = session._connection
203 session.check_do_handshake()
209 if flag & (select.POLLIN | select.POLLPRI):
# readable listening socket: accept a client and create its session
212 connection, address = server.accept()
214 session = TcpSession(self.dispatcher, poller, connection, address,
215 use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile, ssl_keyfile=self.ssl_keyfile)
# NOTE(review): BaseException also catches KeyboardInterrupt/SystemExit;
# `except Exception` is probably what is intended.
216 except BaseException as e:
217 print_log("cannot start TCP session", str(e), address)
# register the new (possibly SSL-wrapped) socket for read events only
221 connection = session._connection
222 connection.setblocking(0)
223 self.fd_to_session[ connection.fileno() ] = session
224 poller.register(connection, READ_ONLY)
226 session.check_do_handshake()
228 print_log( "handshake failure", address )
229 stop_session(connection.fileno())
# readable client socket: read up to buffer_size bytes
234 data = s.recv(self.buffer_size)
235 except ssl.SSLError as x:
236 if x.args[0] == ssl.SSL_ERROR_WANT_READ:
237 # print_log("error want read", x, fd)
241 except socket.error as x:
242 # print_log("recv err", x)
# a completely filled buffer suggests more data is pending: queue this fd in
# `redo` (initialised on an omitted line) to be read again
247 if len(data) == self.buffer_size:
248 redo.append( (fd, flag) )
# NOTE(review): on Python 3 recv() returns bytes, while parse_message()
# uses str operations on session.message; decode here (e.g.
# data.decode('utf-8')) or the buffer types will not match.
250 session.message += data
# extract a complete command line and dispatch it (the surrounding loop
# and "no complete line yet" checks are omitted)
252 cmd = session.parse_message()
258 self.handle_command(cmd, session)
264 elif flag & select.POLLHUP:
# NOTE(review): unless reassigned on an omitted line, `address` is the one
# bound by the most recent accept(), not this client's -- the other
# branches log session.address, which is presumably what was meant here.
265 print_log('client hung up', address)
269 elif flag & select.POLLOUT:
270 # Socket is ready to send data, if there is any to send.
# a partially sent message from a previous POLLOUT takes priority
271 if session.retry_msg:
272 next_msg = session.retry_msg
275 next_msg = session.response_queue.get_nowait()
277 # No messages waiting so stop checking for writability.
278 poller.modify(s, READ_ONLY)
# NOTE(review): send_response() queues str; on Python 3 socket.send()
# requires bytes, so messages must be encoded before this call.
282 sent = s.send(next_msg)
283 except socket.error as x:
284 # print_log("recv err", x)
# keep the unsent tail to resend on the next POLLOUT event
288 session.retry_msg = next_msg[sent:]
291 elif flag & select.POLLERR:
292 print_log('handling exceptional condition for', session.address)
296 elif flag & select.POLLNVAL:
297 print_log('invalid request', session.address)