poller (rebased)
[electrum-server.git] / transports / stratum_tcp.py
1 import json
2 import Queue as queue
3 import socket
4 import select
5 import threading
6 import time
7 import traceback, sys
8
9 from processor import Session, Dispatcher
10 from utils import print_log
11
12
# Poll masks for client sockets: input, urgent data, hang-up and errors are
# always watched; POLLOUT is added only while a session has output pending.
READ_ONLY = select.POLLERR | select.POLLHUP | select.POLLPRI | select.POLLIN
READ_WRITE = select.POLLOUT | READ_ONLY
# poll() timeout in milliseconds, so the serving loop re-checks its stop flag.
TIMEOUT = 100
16
17 import ssl
18
class TcpSession(Session):
    """A single client connection (plain TCP or SSL) driven by TcpServer.

    The server's poll loop appends received bytes to ``message`` and pulls
    complete commands out with :meth:`parse_message`.  Responses are queued
    by :meth:`send_response` and written by the server once the socket is
    reported writable.
    """

    def __init__(self, dispatcher, poller, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
        """Wrap an accepted socket (optionally in SSL) and register the session.

        :param dispatcher: request dispatcher; ``add_session`` is called on it
            once the session is fully initialised.
        :param poller: the server's ``select.poll`` object, used to toggle
            write interest on this connection.
        :param connection: the socket returned by ``accept()``.
        :param address: ``(host, port)`` tuple returned by ``accept()``.
        :param use_ssl: wrap the socket in server-side SSL when true.
        :param ssl_certfile: certificate file for the SSL wrapper.
        :param ssl_keyfile: private key file for the SSL wrapper.
        """
        Session.__init__(self, dispatcher)
        self.use_ssl = use_ssl
        self.poller = poller
        # Poll-mask changes are made through the raw socket.
        self.raw_connection = connection
        if use_ssl:
            # The handshake is deferred to check_do_handshake() so that it can
            # proceed step by step on a non-blocking socket.
            self._connection = ssl.wrap_socket(
                connection,
                server_side=True,
                certfile=ssl_certfile,
                keyfile=ssl_keyfile,
                ssl_version=ssl.PROTOCOL_SSLv23,
                do_handshake_on_connect=False)
        else:
            self._connection = connection

        self.address = address[0] + ":%d" % address[1]
        self.name = "TCP " if not use_ssl else "SSL "
        self.timeout = 1000
        self.response_queue = queue.Queue()
        self.message = ''        # received bytes not yet split into commands
        self.retry_msg = ''      # unsent tail of a partially written response
        self.handshake = not self.use_ssl  # True once application data may flow
        # Register last, so the dispatcher never sees a half-built session.
        self.dispatcher.add_session(self)

    def check_do_handshake(self):
        """Advance a pending non-blocking SSL handshake by one step.

        Returns immediately once the handshake is complete (always the case
        for plain TCP).  When OpenSSL needs to write, POLLOUT interest is
        enabled.  When the handshake finishes, write interest is kept only if
        responses are already waiting to be sent.

        :raises ssl.SSLError: for any handshake error other than
            SSL_ERROR_WANT_READ / SSL_ERROR_WANT_WRITE.  Other socket errors
            raised by ``do_handshake`` propagate unchanged.
        """
        if self.handshake:
            return
        try:
            self._connection.do_handshake()
        except ssl.SSLError as err:
            if err.args[0] == ssl.SSL_ERROR_WANT_READ:
                return
            if err.args[0] == ssl.SSL_ERROR_WANT_WRITE:
                self.poller.modify(self.raw_connection, READ_WRITE)
                return
            raise

        # Responses queued during the handshake must not lose write interest.
        output_pending = self.retry_msg or not self.response_queue.empty()
        self.poller.modify(self.raw_connection, READ_WRITE if output_pending else READ_ONLY)
        self.handshake = True

    def connection(self):
        """Return the (possibly SSL-wrapped) socket of a live session.

        :raises Exception: if the session has already been stopped.
        """
        if self.stopped():
            raise Exception("Session was stopped")
        return self._connection

    def shutdown(self):
        """Shut down both directions of the connection, then close it.

        The shutdown step is best effort (e.g. the peer may already be
        gone), so its errors are ignored.  ``close()`` is always attempted.
        """
        try:
            self._connection.shutdown(socket.SHUT_RDWR)
        except Exception:
            pass
        self._connection.close()

    def send_response(self, response):
        """Serialise ``response`` as one JSON line and queue it for sending.

        Also enables POLLOUT on the connection so the server's poll loop
        writes it out.  If serialisation fails, the traceback is printed and
        the response is dropped.  If the poller update fails (``poll.modify``
        raises, e.g. for an fd that is no longer registered), the traceback
        is printed and the message stays queued.
        """
        try:
            msg = json.dumps(response) + '\n'
        except Exception:
            traceback.print_exc(file=sys.stdout)
            return

        self.response_queue.put(msg)

        try:
            self.poller.modify(self.raw_connection, READ_WRITE)
        except Exception:
            traceback.print_exc(file=sys.stdout)

    def parse_message(self):
        """Remove and return the next non-empty command line from ``message``.

        Blank lines are discarded rather than returned.  The caller stops
        parsing on a falsy result, so returning ``''`` would strand any
        complete commands buffered after the blank line.  Also records the
        current time in ``self.time``.

        :returns: the stripped command string, or ``False`` when the buffer
            contains no further non-empty newline-terminated line.
        """
        self.time = time.time()
        while True:
            newline = self.message.find('\n')
            if newline == -1:
                return False
            raw_command = self.message[:newline].strip()
            self.message = self.message[newline + 1:]
            if raw_command:
                return raw_command
112
113
114
115
116
class TcpServer(threading.Thread):
    """Daemon thread that serves Stratum JSON-lines clients over TCP or SSL.

    A single ``select.poll`` loop in :meth:`run` accepts connections, drives
    SSL handshakes, reads newline-terminated JSON requests and writes the
    responses each session has queued.
    """

    def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
        """Configure (but do not start) the server thread.

        :param dispatcher: object exposing ``shared`` (whose ``stopped()``
            ends :meth:`run`) and ``request_dispatcher`` (which receives
            decoded requests via ``push_request``).
        :param host: address to bind the listening socket to.
        :param port: TCP port to listen on.
        :param use_ssl: serve SSL instead of plain TCP.
        :param ssl_certfile: certificate file handed to each SSL session.
        :param ssl_keyfile: private key file handed to each SSL session.
        """
        self.shared = dispatcher.shared
        self.dispatcher = dispatcher.request_dispatcher
        threading.Thread.__init__(self)
        self.daemon = True
        self.host = host
        self.port = port
        self.lock = threading.Lock()
        self.use_ssl = use_ssl
        self.ssl_keyfile = ssl_keyfile
        self.ssl_certfile = ssl_certfile

        # Maps each client socket's fd to its TcpSession.
        self.fd_to_session = {}
        # Maximum number of bytes read from a client per recv() call.
        self.buffer_size = 4096

    def handle_command(self, raw_command, session):
        """Decode one JSON request line and hand it to the request dispatcher.

        Malformed requests are answered on ``session`` instead of being
        dispatched.  Undecodable JSON gets ``{"error": "bad JSON"}``.  A
        request that is not a JSON object, or lacks ``id`` or ``method``,
        gets a "syntax error" reply echoing the raw line.

        :returns: ``True`` when the line was not valid JSON, ``None``
            otherwise (kept for backward compatibility).
        """
        try:
            command = json.loads(raw_command)
        except Exception:
            # Broad on purpose: besides ValueError, pathological nesting can
            # exhaust the recursion limit; neither may kill the server thread.
            session.send_response({"error": "bad JSON"})
            return True

        # Non-object JSON (list, number, string...) cannot carry the vital
        # fields, and indexing it with 'id' would raise TypeError.
        if not isinstance(command, dict) or 'id' not in command or 'method' not in command:
            session.send_response({"error": "syntax error", "request": raw_command})
            return

        self.dispatcher.push_request(session, command)

    def run(self):
        """Serve clients until ``self.shared.stopped()`` returns true.

        All sockets are multiplexed on one ``select.poll`` object.  The
        listening socket is polled for new connections.  Each client socket
        is polled for input and hang-up/errors, and for writability while
        responses are queued.  The listening socket is closed when the loop
        ends.
        """
        import errno

        print_log(("SSL" if self.use_ssl else "TCP") + " server started on port %d" % self.port)
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.setblocking(0)
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((self.host, self.port))
        server.listen(5)
        server_fd = server.fileno()

        poller = select.poll()
        poller.register(server)

        # (fd, flag) events to process again on the next pass even if poll()
        # does not report them (see read_from).
        redo = []

        def would_block(err):
            """Return True if ``err`` only means the socket is not ready yet."""
            if isinstance(err, ssl.SSLError):
                return err.args[0] in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE)
            return getattr(err, 'errno', None) in (errno.EAGAIN, errno.EWOULDBLOCK)

        def stop_session(fd):
            """Unregister ``fd`` from the poller and stop its session."""
            try:
                # Unregister before the session closes the socket.
                poller.unregister(fd)
            except Exception:
                print_log("unregister error", fd)
                traceback.print_exc(file=sys.stdout)
            # Forget pending re-reads, so a later connection that reuses this
            # fd number is not handed a stale event.
            redo[:] = [event for event in redo if event[0] != fd]
            session = self.fd_to_session.pop(fd, None)
            if session is not None:
                # Session.stop() is relied on to close the socket.
                session.stop()

        def accept_connection():
            """Accept one pending client, wrap it in a session and register it."""
            try:
                connection, address = server.accept()
            except socket.error as err:
                # The client may have gone away between poll() and accept().
                if not would_block(err):
                    print_log("accept error", str(err))
                return
            try:
                session = TcpSession(self.dispatcher, poller, connection, address,
                                     use_ssl=self.use_ssl, ssl_certfile=self.ssl_certfile,
                                     ssl_keyfile=self.ssl_keyfile)
            except BaseException as e:
                print_log("cannot start TCP session", str(e), address)
                connection.close()
                return

            connection = session._connection
            connection.setblocking(0)
            fd = connection.fileno()
            self.fd_to_session[fd] = session
            poller.register(connection, READ_ONLY)
            try:
                session.check_do_handshake()
            except Exception:
                print_log("handshake failure", address)
                stop_session(fd)

        def read_from(fd, session):
            """Receive data for ``session`` and dispatch each complete command."""
            s = session._connection
            try:
                data = s.recv(self.buffer_size)
            except socket.error as err:
                # ssl.SSLError is a socket.error subclass, so SSL failures land
                # here too.  Only "not ready yet" is tolerated; any other error
                # ends this client's session, not the whole server thread.
                if not would_block(err):
                    stop_session(fd)
                return

            if not data:
                # The client closed the connection.
                stop_session(fd)
                return

            # A full buffer may leave more bytes in the kernel, and SSL may hold
            # already-decrypted bytes that poll() cannot report: read again on
            # the next pass instead of waiting for a poll event.
            if len(data) == self.buffer_size or (session.use_ssl and s.pending()):
                redo.append((fd, select.POLLIN))

            session.message += data
            while True:
                cmd = session.parse_message()
                if not cmd:
                    break
                if cmd == 'quit':
                    stop_session(fd)
                    return
                self.handle_command(cmd, session)

        def write_to(fd, session):
            """Send the partially sent or next queued response of ``session``."""
            s = session._connection
            if session.retry_msg:
                next_msg = session.retry_msg
            else:
                try:
                    next_msg = session.response_queue.get_nowait()
                except queue.Empty:
                    # No messages waiting, so stop checking for writability.
                    poller.modify(s, READ_ONLY)
                    return
            try:
                sent = s.send(next_msg)
            except socket.error as err:
                if would_block(err):
                    # Keep the whole message and retry on the next POLLOUT.
                    session.retry_msg = next_msg
                else:
                    stop_session(fd)
                return
            session.retry_msg = next_msg[sent:]

        try:
            while not self.shared.stopped():
                # Do not wait while re-reads are pending, but still poll, so a
                # single busy client cannot starve the others.
                events = poller.poll(0 if redo else TIMEOUT)
                if redo:
                    merged = dict(events)
                    for fd, flag in redo:
                        merged[fd] = merged.get(fd, 0) | flag
                    events = list(merged.items())
                    del redo[:]

                for fd, flag in events:
                    if fd == server_fd:
                        if flag & (select.POLLIN | select.POLLPRI):
                            accept_connection()
                        else:
                            print_log("unexpected event on listening socket", flag)
                        continue

                    session = self.fd_to_session.get(fd)
                    if session is None:
                        # Defensive: the fd no longer belongs to a live session.
                        continue

                    try:
                        session.check_do_handshake()
                    except Exception:
                        stop_session(fd)
                        continue

                    if not session.handshake:
                        # SSL handshake still in progress: no application data yet.
                        if flag & (select.POLLHUP | select.POLLERR | select.POLLNVAL):
                            stop_session(fd)
                        continue

                    if flag & (select.POLLIN | select.POLLPRI):
                        read_from(fd, session)
                    elif flag & select.POLLHUP:
                        print_log('client hung up', session.address)
                        stop_session(fd)
                    elif flag & select.POLLOUT:
                        write_to(fd, session)
                    elif flag & select.POLLERR:
                        print_log('handling exceptional condition for', session.address)
                        stop_session(fd)
                    elif flag & select.POLLNVAL:
                        print_log('invalid request', session.address)
                        stop_session(fd)
        finally:
            server.close()