add delay if tcp session cannot be started
[electrum-server.git] / transports / stratum_http.py
1 #!/usr/bin/env python
2 # Copyright(C) 2012 thomasv@gitorious
3
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful, but
10 # WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/agpl.html>.
17 """
18 sessions are identified with cookies
19  - each session has a buffer of responses to requests
20
21
22 from the processor point of view:
23  - the user only defines process(); everything else is session management, so sessions should not belong to the processor
24
25 """
26 import json
27 import logging
28 import os
29 import Queue
30 import SimpleXMLRPCServer
31 import socket
32 import SocketServer
33 import sys
34 import time
35 import threading
36 import traceback
37 import types
38
39 import jsonrpclib
40 from jsonrpclib import Fault
41 from jsonrpclib.jsonrpc import USE_UNIX_SOCKETS
42 from OpenSSL import SSL
43
44 try:
45     import fcntl
46 except ImportError:
47     # For Windows
48     fcntl = None
49
50
51 from processor import Session
52 from utils import random_string, print_log
53
54
def get_version(request):
    """Return the JSON-RPC protocol version implied by *request*.

    *request* must be a dict (one decoded JSON-RPC call).

    Returns 2.0 if the request has a ``jsonrpc`` member, 1.0 if it has an
    ``id`` member but no ``jsonrpc`` member, and None otherwise.
    """
    # Test membership on the dict itself: on Python 2, request.keys()
    # builds a throwaway list and then scans it linearly.
    if 'jsonrpc' in request:
        return 2.0
    if 'id' in request:
        return 1.0
    return None
62
63
def validate_request(request):
    """Check that *request* is a well-formed JSON-RPC call.

    Returns True when the call is acceptable, otherwise a ``Fault`` with
    code -32600 describing the problem.  As a side effect, a missing
    ``params`` member is filled in with an empty list.
    """
    if not isinstance(request, types.DictType):
        return Fault(-32600, 'Request must be {}, not %s.' % type(request))

    rpcid = request.get('id', None)
    if not get_version(request):
        return Fault(-32600, 'Request %s invalid.' % request, rpcid=rpcid)

    # setdefault hands back the stored value, whether pre-existing or new.
    params = request.setdefault('params', [])
    method = request.get('method', None)

    method_ok = bool(method) and type(method) in types.StringTypes
    params_ok = type(params) in (types.ListType, types.DictType, types.TupleType)
    if method_ok and params_ok:
        return True
    return Fault(-32600, 'Invalid request parameters or method.', rpcid=rpcid)
78
79
class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
    """JSON-RPC dispatcher that maps HTTP requests onto cookie-identified sessions.

    Used as a mix-in by the HTTP(S) server classes.  ``self.dispatcher`` is
    not set here: it must be assigned after construction.  It must provide
    ``get_session_by_address``, ``do_dispatch`` and ``add_session``.
    """

    def __init__(self, encoding=None):
        # todo: use super
        SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, allow_none=True, encoding=encoding)

    def _marshaled_dispatch(self, session_id, data, dispatch_method=None):
        """Dispatch the raw JSON-RPC payload *data* for session *session_id*.

        Every call in *data* (a single request or a batch) is validated and
        passed to ``self.dispatcher.do_dispatch``.  Then every response
        already queued for the session is drained and returned.  A call's
        own response shows up only if it has reached the session queue by
        then; otherwise a later poll picks it up.

        Returns a JSON string: a parse fault, a single response, a JSON
        array of several, or '' when nothing is pending.  If the session is
        unknown, returns the plain string 'Error: session not found'.
        *dispatch_method* is accepted for signature compatibility with
        SimpleXMLRPCDispatcher and is ignored.
        """
        try:
            request = jsonrpclib.loads(data)
        except Exception as e:
            return Fault(-32700, 'Request %s invalid. (%s)' % (data, e)).response()

        session = self.dispatcher.get_session_by_address(session_id)
        if not session:
            return 'Error: session not found'
        # Record activity; HttpSession.stopped() expires idle sessions.
        session.time = time.time()

        if not isinstance(request, types.ListType):
            request = [request]  # treat a single call as a one-element batch

        responses = []
        for req_entry in request:
            result = validate_request(req_entry)
            if type(result) is Fault:
                responses.append(result.response())
                continue

            self.dispatcher.do_dispatch(session, req_entry)

            if req_entry['method'] == 'server.stop':
                # Acknowledge immediately; remaining calls and queued output are skipped.
                return json.dumps({'result': 'ok'})

        responses.extend(json.dumps(item) for item in self.poll_session(session))

        if len(responses) > 1:
            return '[%s]' % ','.join(responses)
        if responses:
            return responses[0]
        return ''

    def create_session(self):
        """Register a new HttpSession under a random id and return that id."""
        session_id = random_string(10)
        session = HttpSession(session_id)
        self.dispatcher.add_session(session)
        return session_id

    def poll_session(self, session):
        """Drain and return every response currently queued for *session*.

        Never blocks.  With ThreadingMixIn, two requests carrying the same
        SESSION cookie can drain the same queue concurrently.  The previous
        ``empty()`` check followed by a blocking ``get()`` could then hang a
        handler thread forever, once the other thread took the last item.
        """
        q = session.pending_responses
        responses = []
        while True:
            try:
                responses.append(q.get_nowait())
            except Queue.Empty:
                return responses
142
143
class StratumJSONRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
    """HTTP request handler for the Stratum JSON-RPC transport.

    POST carries JSON-RPC calls.  GET sends an empty batch, which only
    drains responses already queued for the session.  Sessions are tracked
    with a ``SESSION`` cookie: a request without one gets a fresh session,
    and every reply re-sends the cookie.  The connection is shut down for
    writing after each reply.
    """

    def do_OPTIONS(self):
        """Answer a CORS preflight request."""
        self.send_response(200)
        self.send_header('Allow', 'GET, POST, OPTIONS')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Headers', 'Cache-Control, Content-Language, Content-Type, Expires, Last-Modified, Pragma, Accept-Language, Accept, Origin')
        self.send_header('Content-Length', '0')
        self.end_headers()

    def do_GET(self):
        """Poll: reply with whatever responses are pending for the caller's session."""
        if not self.is_rpc_path_valid():
            self.report_404()
            return
        session_id = None
        try:
            session_id = self._get_session_id()
            # An empty batch dispatches nothing; it only drains the session queue.
            response = self.server._marshaled_dispatch(session_id, json.dumps([]))
            self.send_response(200)
        except Exception:
            response = self._server_error()
        self._send_body(session_id, response)

    def do_POST(self):
        """Dispatch the JSON-RPC call(s) in the request body for the caller's session."""
        if not self.is_rpc_path_valid():
            self.report_404()
            return
        # Bound before the try: a failure while reading the body must not leave
        # session_id unbound when the error response is written.
        session_id = None
        try:
            data = self._read_body()
            session_id = self._get_session_id()
            response = self.server._marshaled_dispatch(session_id, data)
            self.send_response(200)
        except Exception:
            response = self._server_error()
        self._send_body(session_id, response)

    def _read_body(self):
        """Read and return exactly Content-Length bytes of the request body.

        A missing or non-integer header raises from the lookup or int().
        Raises ValueError on a negative length, and IOError if the peer
        closes before the full body arrives.
        """
        max_chunk_size = 10 * 1024 * 1024
        size_remaining = int(self.headers["content-length"])
        if size_remaining < 0:
            raise ValueError('negative Content-Length: %d' % size_remaining)
        chunks = []
        while size_remaining > 0:
            chunk = self.rfile.read(min(size_remaining, max_chunk_size))
            if not chunk:
                # EOF before the declared length; without this the loop never ends.
                raise IOError('connection closed with %d body bytes outstanding' % size_remaining)
            chunks.append(chunk)
            size_remaining -= len(chunk)
        return ''.join(chunks)

    def _get_session_id(self):
        """Return the caller's session id, creating a new session if none was sent."""
        session_id = self._parse_session_cookie(self.headers.get('cookie'))
        if session_id is None:
            session_id = self.server.create_session()
        return session_id

    @staticmethod
    def _parse_session_cookie(header):
        """Return the non-empty SESSION value from a Cookie header, else None.

        Finds the cookie anywhere in a multi-cookie header such as
        ``'lang=en; SESSION=abc'``.  It does not swallow trailing cookies
        into the id.
        """
        if not header:
            return None
        for part in header.split(';'):
            name, _, value = part.strip().partition('=')
            if name == 'SESSION' and value:
                return value
        return None

    def _server_error(self):
        """Send a 500 status line; return a JSON-RPC fault body for the active exception.

        Must be called from inside an ``except`` clause, because it formats
        the exception currently being handled.
        """
        self.send_response(500)
        err_lines = traceback.format_exc().splitlines()
        trace_string = '%s | %s' % (err_lines[-3], err_lines[-1])
        fault = jsonrpclib.Fault(-32603, 'Server error: %s' % trace_string)
        print("500 %s" % trace_string)
        return fault.response()

    def _send_body(self, session_id, response):
        """Send the remaining headers and *response*, then shut the connection down.

        The status line must already have been sent.  The SESSION cookie is
        (re)issued whenever a session id is known.
        """
        if response is None:
            response = ''
        if session_id:
            self.send_header("Set-Cookie", "SESSION=%s" % session_id)
        self.send_header("Content-type", "application/json-rpc")
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Content-length", str(len(response)))
        self.end_headers()
        self.wfile.write(response)
        self.wfile.flush()
        self.shutdown_connection()

    def shutdown_connection(self):
        """Half-close the socket: no further writes after the reply."""
        self.connection.shutdown(socket.SHUT_WR)
244
245
class SSLRequestHandler(StratumJSONRPCRequestHandler):
    """Request handler for connections accepted on a pyOpenSSL listening socket.

    setup() wraps the accepted connection in ``socket._fileobject`` buffers
    directly, instead of calling ``makefile()`` on it.
    NOTE(review): presumably because pyOpenSSL's Connection has no usable
    makefile() — confirm.
    """

    def setup(self):
        conn = self.request
        self.connection = conn
        # Read side first, then write side, each with the handler's buffer size.
        self.rfile = socket._fileobject(conn, "rb", self.rbufsize)
        self.wfile = socket._fileobject(conn, "wb", self.wbufsize)

    def shutdown_connection(self):
        # pyOpenSSL's Connection.shutdown() takes no "how" argument (unlike
        # socket.shutdown); it sends the TLS shutdown message.
        self.connection.shutdown()
254
255
class SSLTCPServer(SocketServer.TCPServer):
    """TCPServer whose listening socket is a pyOpenSSL ``SSL.Connection``.

    TCPServer.__init__ is bypassed so the plain socket can be wrapped in
    TLS before it is bound.

    Parameters:
        server_address: address passed to bind().
        certfile, keyfile: paths to the PEM certificate and private key.
        RequestHandlerClass: handler class for accepted connections.
        bind_and_activate: bind and listen immediately (default True).

    Raises ``OpenSSL.SSL.Error`` if the key/certificate cannot be loaded or
    do not match.
    """

    def __init__(self, server_address, certfile, keyfile, RequestHandlerClass, bind_and_activate=True):
        SocketServer.BaseServer.__init__(self, server_address, RequestHandlerClass)
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        # SSLv23_METHOD negotiates the best common protocol but still accepts
        # the broken SSLv2/SSLv3 protocols unless they are explicitly disabled.
        ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
        ctx.use_privatekey_file(keyfile)
        ctx.use_certificate_file(certfile)
        # Fail at startup, rather than on every handshake, if key and cert differ.
        ctx.check_privatekey()
        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except Exception:
                # Do not leak the listening socket if bind/listen fails
                # (TCPServer.__init__ does the same).
                self.server_close()
                raise

    def shutdown_request(self, request):
        """Release an accepted connection once its handler has finished.

        The transport-level shutdown was already performed by the handler's
        shutdown_connection(); calling ``request.shutdown(SHUT_WR)`` as
        TCPServer does would fail on a pyOpenSSL connection.  Only the close
        step is kept, so the descriptor is freed deterministically instead
        of waiting for garbage collection.
        """
        self.close_request(request)
270
271
class StratumHTTPServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
    """Plain-HTTP Stratum JSON-RPC server (TCPServer + session dispatcher).

    Parameters:
        addr: address to bind.  For AF_UNIX with Unix-socket support, this
            is a filesystem path, and a stale socket file there is removed
            first.
        requestHandler: handler class for accepted connections.
        logRequests: stored on the instance as ``logRequests``.
        encoding: passed to the dispatcher.
        bind_and_activate: bind and listen immediately.
        address_family: socket family for the listening socket.
    """

    allow_reuse_address = True

    def __init__(self, addr, requestHandler=StratumJSONRPCRequestHandler,
                 logRequests=False, encoding=None, bind_and_activate=True,
                 address_family=socket.AF_INET):
        self.logRequests = logRequests
        StratumJSONRPCDispatcher.__init__(self, encoding)
        # Must be set before TCPServer.__init__, which creates the socket
        # from self.address_family.
        self.address_family = address_family
        if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
            # Unix sockets can't be bound if they already exist in the
            # filesystem. The convention of e.g. X11 is to unlink
            # before binding again.
            if os.path.exists(addr):
                try:
                    os.unlink(addr)
                except OSError:
                    logging.warning("Could not unlink socket %s", addr)

        SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)

        # Mark the listening descriptor close-on-exec so exec'd children do not inherit it.
        if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
            flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
            flags |= fcntl.FD_CLOEXEC
            fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
301
302
class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
    """HTTPS Stratum JSON-RPC server (SSLTCPServer + session dispatcher).

    Parameters:
        addr: address to bind.  For AF_UNIX with Unix-socket support, this
            is a filesystem path, and a stale socket file there is removed
            first.
        certfile, keyfile: PEM certificate and private key paths.
        requestHandler: handler class for accepted connections.
        logRequests: stored on the instance as ``logRequests``.
        encoding: passed to the dispatcher.
        bind_and_activate: bind and listen immediately.
        address_family: socket family for the listening socket.
    """

    allow_reuse_address = True

    def __init__(self, addr, certfile, keyfile,
                 requestHandler=SSLRequestHandler,
                 logRequests=False, encoding=None, bind_and_activate=True,
                 address_family=socket.AF_INET):

        self.logRequests = logRequests
        StratumJSONRPCDispatcher.__init__(self, encoding)
        # Must be set before SSLTCPServer.__init__, which creates the
        # socket from self.address_family.
        self.address_family = address_family
        if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
            # Unix sockets can't be bound if they already exist in the
            # filesystem. The convention of e.g. X11 is to unlink
            # before binding again.
            if os.path.exists(addr):
                try:
                    os.unlink(addr)
                except OSError:
                    logging.warning("Could not unlink socket %s", addr)

        SSLTCPServer.__init__(self, addr, certfile, keyfile, requestHandler, bind_and_activate)

        # Mark the listening descriptor close-on-exec so exec'd children do not inherit it.
        if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
            flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
            flags |= fcntl.FD_CLOEXEC
            fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
334
335
class HttpSession(Session):
    """Session for an HTTP client, keyed by its SESSION cookie value.

    Responses are buffered in ``pending_responses`` until the client polls.
    The session counts as stopped once it has been idle for more than 60
    seconds.
    """

    def __init__(self, session_id):
        Session.__init__(self)
        self.name = "HTTP"
        self.address = session_id
        self.pending_responses = Queue.Queue()

    def send_response(self, response):
        # The serialized text is discarded; the queue holds the original
        # object.  Serializing here makes an unencodable response raise in
        # the caller, rather than later when _marshaled_dispatch encodes it.
        json.dumps(response)
        self.pending_responses.put(response)

    def stopped(self):
        with self.lock:
            idle_for = time.time() - self.time
            if idle_for > 60:
                self._stopped = True
            return self._stopped
353
354
class HttpServer(threading.Thread):
    """Daemon thread that runs the Stratum HTTP or HTTPS JSON-RPC server.

    Parameters:
        dispatcher: object exposing ``shared`` and ``request_dispatcher``.
            The latter becomes the server's session dispatcher.
        host, port: address to listen on.
        use_ssl: serve HTTPS with certfile/keyfile instead of plain HTTP.
        certfile, keyfile: PEM paths, used only when use_ssl is true.
    """

    def __init__(self, dispatcher, host, port, use_ssl, certfile, keyfile):
        self.shared = dispatcher.shared
        self.dispatcher = dispatcher.request_dispatcher
        threading.Thread.__init__(self)
        self.daemon = True
        self.host, self.port = host, port
        self.use_ssl = use_ssl
        self.certfile, self.keyfile = certfile, keyfile
        self.lock = threading.Lock()

    def run(self):
        """Build the threaded server, wire in the dispatcher, and serve forever."""
        # see http://code.google.com/p/jsonrpclib/
        address = (self.host, self.port)
        if self.use_ssl:
            class StratumThreadedServer(SocketServer.ThreadingMixIn, StratumHTTPSSLServer):
                pass
            self.server = StratumThreadedServer(address, self.certfile, self.keyfile)
            banner = "HTTPS server started."
        else:
            class StratumThreadedServer(SocketServer.ThreadingMixIn, StratumHTTPServer):
                pass
            self.server = StratumThreadedServer(address)
            banner = "HTTP server started."
        print_log(banner)

        server = self.server
        server.dispatcher = self.dispatcher
        # Registered with no callable.  NOTE(review): _marshaled_dispatch routes
        # every call through self.dispatcher, so these entries appear to serve
        # only as a method listing — confirm.
        for method_name in ('server.stop', 'server.info'):
            server.register_function(None, method_name)
        server.serve_forever()