Merge branch 'master' of github.com:spesmilo/electrum-server
[electrum-server.git] / transports / stratum_http.py
1 #!/usr/bin/env python
2 # Copyright(C) 2012 thomasv@gitorious
3
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful, but
10 # WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/agpl.html>.
17 """
18 sessions are identified with cookies
19  - each session has a buffer of responses to requests
20
21
22 from the processor point of view:
23  - the user only defines process() ; the rest is session management.  thus sessions should not belong to processor
24
25 """
26 import json
27 import logging
28 import os
29 import Queue
30 import SimpleXMLRPCServer
31 import socket
32 import SocketServer
33 import sys
34 import time
35 import threading
36 import traceback
37 import types
38
39 import jsonrpclib
40 from jsonrpclib import Fault
41 from jsonrpclib.jsonrpc import USE_UNIX_SOCKETS
42 from OpenSSL import SSL
43
44 try:
45     import fcntl
46 except ImportError:
47     # For Windows
48     fcntl = None
49
50
51 from processor import Session
52 from utils import random_string, print_log
53
54
55 def get_version(request):
56     # must be a dict
57     if 'jsonrpc' in request.keys():
58         return 2.0
59     if 'id' in request.keys():
60         return 1.0
61     return None
62
63
64 def validate_request(request):
65     if not isinstance(request, types.DictType):
66         return Fault(-32600, 'Request must be {}, not %s.' % type(request))
67     rpcid = request.get('id', None)
68     version = get_version(request)
69     if not version:
70         return Fault(-32600, 'Request %s invalid.' % request, rpcid=rpcid)
71     request.setdefault('params', [])
72     method = request.get('method', None)
73     params = request.get('params')
74     param_types = (types.ListType, types.DictType, types.TupleType)
75     if not method or type(method) not in types.StringTypes or type(params) not in param_types:
76         return Fault(-32600, 'Invalid request parameters or method.', rpcid=rpcid)
77     return True
78
79
80 class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
81
82     def __init__(self, encoding=None):
83         # todo: use super
84         SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, allow_none=True, encoding=encoding)
85
86     def _marshaled_dispatch(self, session_id, data, dispatch_method=None):
87         response = None
88         try:
89             request = jsonrpclib.loads(data)
90         except Exception, e:
91             fault = Fault(-32700, 'Request %s invalid. (%s)' % (data, e))
92             response = fault.response()
93             return response
94
95         session = self.dispatcher.get_session_by_address(session_id)
96         if not session:
97             return 'Error: session not found'
98         session.time = time.time()
99
100         responses = []
101         if not isinstance(request, types.ListType):
102             request = [request]
103
104         for req_entry in request:
105             result = validate_request(req_entry)
106             if type(result) is Fault:
107                 responses.append(result.response())
108                 continue
109
110             self.dispatcher.do_dispatch(session, req_entry)
111
112             if req_entry['method'] == 'server.stop':
113                 return json.dumps({'result': 'ok'})
114
115         r = self.poll_session(session)
116         for item in r:
117             responses.append(json.dumps(item))
118
119         if len(responses) > 1:
120             response = '[%s]' % ','.join(responses)
121         elif len(responses) == 1:
122             response = responses[0]
123         else:
124             response = ''
125
126         return response
127
128     def create_session(self):
129         session_id = random_string(20)
130         session = HttpSession(self.dispatcher, session_id)
131         return session_id
132
133     def poll_session(self, session):
134         q = session.pending_responses
135         responses = []
136         while not q.empty():
137             r = q.get()
138             responses.append(r)
139         #print "poll: %d responses"%len(responses)
140         return responses
141
142
143 class StratumJSONRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
144
145     def do_OPTIONS(self):
146         self.send_response(200)
147         self.send_header('Allow', 'GET, POST, OPTIONS')
148         self.send_header('Access-Control-Allow-Origin', '*')
149         self.send_header('Access-Control-Allow-Headers', 'Cache-Control, Content-Language, Content-Type, Expires, Last-Modified, Pragma, Accept-Language, Accept, Origin')
150         self.send_header('Content-Length', '0')
151         self.end_headers()
152
153     def do_GET(self):
154         if not self.is_rpc_path_valid():
155             self.report_404()
156             return
157         try:
158             session_id = None
159             c = self.headers.get('cookie')
160             if c:
161                 if c[0:8] == 'SESSION=':
162                     #print "found cookie", c[8:]
163                     session_id = c[8:]
164
165             if session_id is None:
166                 session_id = self.server.create_session()
167                 #print "setting cookie", session_id
168
169             data = json.dumps([])
170             response = self.server._marshaled_dispatch(session_id, data)
171             self.send_response(200)
172         except Exception, e:
173             self.send_response(500)
174             err_lines = traceback.format_exc().splitlines()
175             trace_string = '%s | %s' % (err_lines[-3], err_lines[-1])
176             fault = jsonrpclib.Fault(-32603, 'Server error: %s' % trace_string)
177             response = fault.response()
178             print "500", trace_string
179         if response is None:
180             response = ''
181
182         if session_id:
183             self.send_header("Set-Cookie", "SESSION=%s" % session_id)
184
185         self.send_header("Content-type", "application/json-rpc")
186         self.send_header("Access-Control-Allow-Origin", "*")
187         self.send_header("Content-length", str(len(response)))
188         self.end_headers()
189         self.wfile.write(response)
190         self.wfile.flush()
191         self.shutdown_connection()
192
193     def do_POST(self):
194         if not self.is_rpc_path_valid():
195             self.report_404()
196             return
197         try:
198             max_chunk_size = 10*1024*1024
199             size_remaining = int(self.headers["content-length"])
200             L = []
201             while size_remaining:
202                 chunk_size = min(size_remaining, max_chunk_size)
203                 L.append(self.rfile.read(chunk_size))
204                 size_remaining -= len(L[-1])
205             data = ''.join(L)
206
207             session_id = None
208             c = self.headers.get('cookie')
209             if c:
210                 if c[0:8] == 'SESSION=':
211                     #print "found cookie", c[8:]
212                     session_id = c[8:]
213
214             if session_id is None:
215                 session_id = self.server.create_session()
216                 #print "setting cookie", session_id
217
218             response = self.server._marshaled_dispatch(session_id, data)
219             self.send_response(200)
220         except Exception, e:
221             self.send_response(500)
222             err_lines = traceback.format_exc().splitlines()
223             trace_string = '%s | %s' % (err_lines[-3], err_lines[-1])
224             fault = jsonrpclib.Fault(-32603, 'Server error: %s' % trace_string)
225             response = fault.response()
226             print "500", trace_string
227         if response is None:
228             response = ''
229
230         if session_id:
231             self.send_header("Set-Cookie", "SESSION=%s" % session_id)
232
233         self.send_header("Content-type", "application/json-rpc")
234         self.send_header("Access-Control-Allow-Origin", "*")
235         self.send_header("Content-length", str(len(response)))
236         self.end_headers()
237         self.wfile.write(response)
238         self.wfile.flush()
239         self.shutdown_connection()
240
241     def shutdown_connection(self):
242         self.connection.shutdown(1)
243
244
245 class SSLRequestHandler(StratumJSONRPCRequestHandler):
246     def setup(self):
247         self.connection = self.request
248         self.rfile = socket._fileobject(self.request, "rb", self.rbufsize)
249         self.wfile = socket._fileobject(self.request, "wb", self.wbufsize)
250
251     def shutdown_connection(self):
252         self.connection.shutdown()
253
254
255 class SSLTCPServer(SocketServer.TCPServer):
256     def __init__(self, server_address, certfile, keyfile, RequestHandlerClass, bind_and_activate=True):
257         SocketServer.BaseServer.__init__(self, server_address, RequestHandlerClass)
258         ctx = SSL.Context(SSL.SSLv23_METHOD)
259         ctx.use_privatekey_file(keyfile)
260         ctx.use_certificate_file(certfile)
261         self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
262         if bind_and_activate:
263             self.server_bind()
264             self.server_activate()
265
266     def shutdown_request(self, request):
267         #request.shutdown()
268         pass
269
270
271 class StratumHTTPServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
272
273     allow_reuse_address = True
274
275     def __init__(self, addr, requestHandler=StratumJSONRPCRequestHandler,
276                  logRequests=False, encoding=None, bind_and_activate=True,
277                  address_family=socket.AF_INET):
278         self.logRequests = logRequests
279         StratumJSONRPCDispatcher.__init__(self, encoding)
280         # TCPServer.__init__ has an extra parameter on 2.6+, so
281         # check Python version and decide on how to call it
282         vi = sys.version_info
283         self.address_family = address_family
284         if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
285             # Unix sockets can't be bound if they already exist in the
286             # filesystem. The convention of e.g. X11 is to unlink
287             # before binding again.
288             if os.path.exists(addr):
289                 try:
290                     os.unlink(addr)
291                 except OSError:
292                     logging.warning("Could not unlink socket %s", addr)
293
294         SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
295
296         if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
297             flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
298             flags |= fcntl.FD_CLOEXEC
299             fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
300
301
302 class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
303
304     allow_reuse_address = True
305
306     def __init__(self, addr, certfile, keyfile,
307                  requestHandler=SSLRequestHandler,
308                  logRequests=False, encoding=None, bind_and_activate=True,
309                  address_family=socket.AF_INET):
310
311         self.logRequests = logRequests
312         StratumJSONRPCDispatcher.__init__(self, encoding)
313         # TCPServer.__init__ has an extra parameter on 2.6+, so
314         # check Python version and decide on how to call it
315         vi = sys.version_info
316         self.address_family = address_family
317         if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
318             # Unix sockets can't be bound if they already exist in the
319             # filesystem. The convention of e.g. X11 is to unlink
320             # before binding again.
321             if os.path.exists(addr):
322                 try:
323                     os.unlink(addr)
324                 except OSError:
325                     logging.warning("Could not unlink socket %s", addr)
326
327         SSLTCPServer.__init__(self, addr, certfile, keyfile, requestHandler, bind_and_activate)
328
329         if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
330             flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
331             flags |= fcntl.FD_CLOEXEC
332             fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
333
334
335 class HttpSession(Session):
336
337     def __init__(self, dispatcher, session_id):
338         Session.__init__(self, dispatcher)
339         self.pending_responses = Queue.Queue()
340         self.address = session_id
341         self.name = "HTTP"
342         self.timeout = 60
343         self.dispatcher.add_session(self)
344
345     def send_response(self, response):
346         raw_response = json.dumps(response)
347         self.pending_responses.put(response)
348
349
350
351 class HttpServer(threading.Thread):
352     def __init__(self, dispatcher, host, port, use_ssl, certfile, keyfile):
353         self.shared = dispatcher.shared
354         self.dispatcher = dispatcher.request_dispatcher
355         threading.Thread.__init__(self)
356         self.daemon = True
357         self.host = host
358         self.port = port
359         self.use_ssl = use_ssl
360         self.certfile = certfile
361         self.keyfile = keyfile
362         self.lock = threading.Lock()
363
364     def run(self):
365         # see http://code.google.com/p/jsonrpclib/
366         from SocketServer import ThreadingMixIn
367         if self.use_ssl:
368             class StratumThreadedServer(ThreadingMixIn, StratumHTTPSSLServer):
369                 pass
370             self.server = StratumThreadedServer((self.host, self.port), self.certfile, self.keyfile)
371             print_log("HTTPS server started.")
372         else:
373             class StratumThreadedServer(ThreadingMixIn, StratumHTTPServer):
374                 pass
375             self.server = StratumThreadedServer((self.host, self.port))
376             print_log("HTTP server started.")
377
378         self.server.dispatcher = self.dispatcher
379         self.server.register_function(None, 'server.stop')
380         self.server.register_function(None, 'server.info')
381
382         self.server.serve_forever()