fix "wrong version number" error for HTTPS by using SSLv23
[electrum-server.git] / transports / stratum_http.py
1 #!/usr/bin/env python
2 # Copyright(C) 2012 thomasv@gitorious
3
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful, but
10 # WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/agpl.html>.
17
18 import jsonrpclib
19 from jsonrpclib import Fault
20 from jsonrpclib.jsonrpc import USE_UNIX_SOCKETS
21 import SimpleXMLRPCServer
22 import SocketServer
23 import socket
24 import logging
25 import os, time
26 import types
27 import traceback
28 import sys, threading
29
30 from OpenSSL import SSL
31
32 try:
33     import fcntl
34 except ImportError:
35     # For Windows
36     fcntl = None
37
38 import json
39
40
41 """
Sessions are identified with cookies:
 - each session has a buffer of responses to requests.

From the processor's point of view:
 - the user only defines process(); the rest is session management,
   so sessions should not belong to the processor.
49 """
50
51
52 from processor import random_string, print_log
53
54
def get_version(request):
    """Infer the JSON-RPC protocol version of a decoded request.

    ``request`` must be a dict.  A ``jsonrpc`` member marks a 2.0 request;
    failing that, an ``id`` member marks a 1.0 request.  Returns ``None``
    when neither member is present.
    """
    for marker, version in (('jsonrpc', 2.0), ('id', 1.0)):
        if marker in request:
            return version
    return None
62     
def validate_request(request):
    """Check that ``request`` is a well-formed JSON-RPC call.

    Returns ``True`` when the request is valid, otherwise a ``Fault`` with
    code -32600 describing the problem.  Side effect: a request without a
    ``params`` member gets one set to an empty list.
    """
    if type(request) is not types.DictType:
        return Fault(-32600, 'Request must be {}, not %s.' % type(request))

    rpcid = request.get('id', None)
    if not get_version(request):
        return Fault(-32600, 'Request %s invalid.' % request, rpcid=rpcid)

    # setdefault returns the existing params, or the freshly stored [].
    params = request.setdefault('params', [])
    method = request.get('method', None)

    method_ok = method and type(method) in types.StringTypes
    params_ok = type(params) in (types.ListType, types.DictType, types.TupleType)
    if method_ok and params_ok:
        return True
    return Fault(-32600, 'Invalid request parameters or method.', rpcid=rpcid)
85
class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
    """JSON-RPC dispatcher that routes calls through per-client HTTP sessions.

    Requests are not executed here: each valid one is handed to
    ``self.dispatcher.do_dispatch`` and replies are later collected from the
    session's ``pending_responses`` queue.  ``self.dispatcher`` is not set by
    this class; ``HttpServer.run`` assigns it on the server instance.
    """

    def __init__(self, encoding=None):
        SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(
            self, allow_none=True, encoding=encoding)

    def _marshaled_dispatch(self, session_id, data, dispatch_method=None):
        """Dispatch the JSON-RPC payload ``data`` on behalf of ``session_id``.

        ``data`` may hold a single request or a batch (list).  Each valid
        request is forwarded to the request dispatcher; afterwards every
        response queued for the session is drained and returned as a string:
        one JSON document, a JSON array when there are several, or '' when
        there are none.  Returns a parse-error fault for undecodable input,
        and ``'Error: session not found'`` for an unknown session.
        ``dispatch_method`` is unused.
        """
        try:
            request = jsonrpclib.loads(data)
        except Exception as e:
            fault = Fault(-32700, 'Request %s invalid. (%s)' % (data, e))
            return fault.response()

        session = self.dispatcher.get_session_by_address(session_id)
        if not session:
            return 'Error: session not found'
        # Refresh the activity timestamp read by HttpSession.stopped().
        session.time = time.time()

        if type(request) is not types.ListType:
            request = [request]

        responses = []
        for req_entry in request:
            result = validate_request(req_entry)
            if type(result) is Fault:
                responses.append(result.response())
                continue

            self.dispatcher.do_dispatch(session, req_entry)

            if req_entry['method'] == 'server.stop':
                return json.dumps({'result': 'ok'})

        responses.extend(json.dumps(item) for item in self.poll_session(session))

        if len(responses) > 1:
            return '[%s]' % ','.join(responses)
        if responses:
            return responses[0]
        return ''

    def create_session(self):
        """Create and register a new HttpSession; return its random id."""
        session_id = random_string(10)
        session = HttpSession(session_id)
        self.dispatcher.add_session(session)
        return session_id

    def poll_session(self, session):
        """Drain and return every response currently queued for ``session``.

        Uses non-blocking ``get_nowait``.  With a threaded server, two requests
        carrying the same session cookie can poll concurrently; an
        ``empty()``-then-``get()`` loop could then block forever once the other
        thread has emptied the queue.
        """
        q = session.pending_responses
        responses = []
        while True:
            try:
                responses.append(q.get_nowait())
            except Queue.Empty:
                break
        return responses
150
151
152
153
class StratumJSONRPCRequestHandler(
        SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
    """HTTP request handler for the Stratum JSON-RPC transport.

    Sessions are tracked with a ``SESSION`` cookie, and a new session is
    created when the client sends none.  POST submits a JSON-RPC request or
    batch.  GET submits an empty batch, so it only returns (and drains) the
    responses already queued for the session.  Every reply carries permissive
    CORS headers.
    """

    def do_OPTIONS(self):
        """Answer a CORS preflight request with an empty 200 response."""
        self.send_response(200)
        self.send_header('Allow', 'GET, POST, OPTIONS')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.send_header('Access-Control-Allow-Headers', '*')
        self.send_header('Content-Length', '0')
        self.end_headers()

    def do_GET(self):
        """Return the responses currently queued for the caller's session."""
        if not self.is_rpc_path_valid():
            self.report_404()
            return
        # Bound before the try because the error path still reads it.
        session_id = None
        try:
            session_id = self._get_or_create_session_id()
            data = json.dumps([])
            response = self.server._marshaled_dispatch(session_id, data)
            self.send_response(200)
        except Exception:
            response = self._server_error_response()
        self._write_json_response(response, session_id)

    def do_POST(self):
        """Dispatch the JSON-RPC request(s) carried in the POST body."""
        if not self.is_rpc_path_valid():
            self.report_404()
            return
        # Bound before the try: reading the body can fail (e.g. missing or
        # malformed Content-Length), and the error path still reads it.
        session_id = None
        try:
            data = self._read_body()
            session_id = self._get_or_create_session_id()
            response = self.server._marshaled_dispatch(session_id, data)
            self.send_response(200)
        except Exception:
            response = self._server_error_response()
        self._write_json_response(response, session_id)

    def _read_body(self):
        """Read the Content-Length-sized request body in bounded chunks."""
        max_chunk_size = 10 * 1024 * 1024
        size_remaining = int(self.headers["content-length"])
        chunks = []
        # "> 0" so a negative Content-Length cannot loop forever.
        while size_remaining > 0:
            chunk = self.rfile.read(min(size_remaining, max_chunk_size))
            if not chunk:
                # Client closed before sending the full body; stop instead of
                # spinning on empty reads.
                break
            chunks.append(chunk)
            size_remaining -= len(chunk)
        return ''.join(chunks)

    def _get_or_create_session_id(self):
        """Return the SESSION cookie value, creating a new session if absent.

        The Cookie header may list several ``name=value`` pairs separated by
        ';'.  Only the SESSION pair is used.
        """
        cookie = self.headers.get('cookie')
        if cookie:
            for part in cookie.split(';'):
                name, sep, value = part.strip().partition('=')
                if sep and name == 'SESSION':
                    return value
        return self.server.create_session()

    def _server_error_response(self):
        """Send a 500 status line and return a JSON-RPC server-error fault.

        Must be called from inside an ``except`` block, so that
        ``traceback.format_exc()`` sees the exception being handled.
        """
        self.send_response(500)
        err_lines = traceback.format_exc().splitlines()
        trace_string = '%s | %s' % (err_lines[-3], err_lines[-1])
        fault = jsonrpclib.Fault(-32603, 'Server error: %s' % trace_string)
        print("500 %s" % trace_string)
        return fault.response()

    def _write_json_response(self, response, session_id):
        """Send headers (including the session cookie) and ``response``, then close."""
        if response is None:
            response = ''
        if session_id:
            self.send_header("Set-Cookie", "SESSION=%s" % session_id)
        self.send_header("Content-type", "application/json-rpc")
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Content-length", str(len(response)))
        self.end_headers()
        self.wfile.write(response)
        self.wfile.flush()
        self.shutdown_connection()

    def shutdown_connection(self):
        """Half-close the connection: no more data will be sent."""
        self.connection.shutdown(socket.SHUT_WR)
256
257
class SSLRequestHandler(StratumJSONRPCRequestHandler):
    """Request handler for connections accepted by SSLTCPServer."""

    def setup(self):
        """Attach buffered read/write file objects to the SSL connection.

        The files are built with ``socket._fileobject`` directly, presumably
        because the base class's makefile()-based setup does not suit the SSL
        connection object (NOTE(review): confirm).
        """
        conn = self.request
        self.connection = conn
        self.rfile = socket._fileobject(conn, "rb", self.rbufsize)
        self.wfile = socket._fileobject(conn, "wb", self.wbufsize)

    def shutdown_connection(self):
        # Called without the socket "how" argument used by the plain handler.
        self.connection.shutdown()
266
267
class SSLTCPServer(SocketServer.TCPServer):
    """TCP server whose listening socket is a pyOpenSSL ``SSL.Connection``.

    ``SSLv23_METHOD`` lets the handshake negotiate the best protocol version
    both peers support.  The obsolete SSLv2/SSLv3 protocols are disabled.
    """

    def __init__(self, server_address, certfile, keyfile, RequestHandlerClass, bind_and_activate=True):
        SocketServer.BaseServer.__init__(self, server_address, RequestHandlerClass)
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        # SSLv23_METHOD alone still permits the broken SSLv2/SSLv3 protocols.
        ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
        ctx.use_privatekey_file(keyfile)
        ctx.use_certificate_file(certfile)
        # Fail at startup rather than on the first handshake if the key does
        # not match the certificate.
        ctx.check_privatekey()
        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except Exception:
                # Like TCPServer: don't leak the listening socket on failure.
                self.server_close()
                raise

    def shutdown_request(self, request):
        """Release an accepted connection after its handler has finished.

        TCPServer's default passes a "how" argument to ``request.shutdown``.
        SSLRequestHandler already calls the connection's argument-less
        ``shutdown()``, so only the close step is performed here.
        """
        self.close_request(request)
282
283
def _remove_stale_unix_socket(addr, address_family):
    """Unlink a leftover Unix-domain socket file so ``addr`` can be bound again.

    Does nothing unless jsonrpclib reports Unix-socket support and
    ``address_family`` is AF_UNIX.
    """
    if not (USE_UNIX_SOCKETS and address_family == socket.AF_UNIX):
        return
    # Unix sockets can't be bound if they already exist in the filesystem.
    # The convention of e.g. X11 is to unlink before binding again.
    if os.path.exists(addr):
        try:
            os.unlink(addr)
        except OSError:
            logging.warning("Could not unlink socket %s", addr)


def _set_close_on_exec(fd):
    """Set FD_CLOEXEC on ``fd`` where fcntl supports it (no-op otherwise)."""
    if fcntl is None or not hasattr(fcntl, 'FD_CLOEXEC'):
        return
    flags = fcntl.fcntl(fd, fcntl.F_GETFD)
    fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC)


class StratumHTTPServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
    """Plain-HTTP Stratum JSON-RPC server."""

    allow_reuse_address = True

    def __init__(self, addr, requestHandler=StratumJSONRPCRequestHandler,
                 logRequests=False, encoding=None, bind_and_activate=True,
                 address_family=socket.AF_INET):
        self.logRequests = logRequests
        StratumJSONRPCDispatcher.__init__(self, encoding)
        # Must be set before TCPServer.__init__ creates the socket.
        self.address_family = address_family
        _remove_stale_unix_socket(addr, address_family)
        SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
        _set_close_on_exec(self.fileno())


class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
    """HTTPS Stratum JSON-RPC server using ``certfile``/``keyfile``."""

    allow_reuse_address = True

    def __init__(self, addr, certfile, keyfile,
                 requestHandler=SSLRequestHandler,
                 logRequests=False, encoding=None, bind_and_activate=True,
                 address_family=socket.AF_INET):
        self.logRequests = logRequests
        StratumJSONRPCDispatcher.__init__(self, encoding)
        # Must be set before SSLTCPServer.__init__ creates the socket.
        self.address_family = address_family
        _remove_stale_unix_socket(addr, address_family)
        SSLTCPServer.__init__(self, addr, certfile, keyfile, requestHandler, bind_and_activate)
        _set_close_on_exec(self.fileno())
346
347
348
349
350
351
352 from processor import Session
353 import Queue
354
class HttpSession(Session):
    """Session for an HTTP client, whose replies are buffered until it polls.

    ``address`` holds the cookie session id; the HTTP dispatcher looks the
    session up by that value.
    """

    def __init__(self, session_id):
        Session.__init__(self)
        self.name = "HTTP"
        self.address = session_id
        self.pending_responses = Queue.Queue()

    def send_response(self, response):
        """Queue ``response`` for delivery on the client's next poll."""
        # The encoded text is discarded: the queue stores the object itself,
        # which is encoded again when polled.  This call still makes a
        # non-JSON-serializable response raise here, in the caller.
        json.dumps(response)
        self.pending_responses.put(response)

    def stopped(self):
        """Report whether the session is stopped; expire it after 60 s idle.

        ``self.time`` is refreshed each time a request is dispatched for this
        session.
        """
        with self.lock:
            idle_seconds = time.time() - self.time
            if idle_seconds > 60:
                self._stopped = True
            return self._stopped
372
class HttpServer(threading.Thread):
    """Daemon thread that runs the Stratum JSON-RPC server over HTTP or HTTPS."""

    def __init__(self, dispatcher, host, port, use_ssl, certfile, keyfile):
        self.shared = dispatcher.shared
        self.dispatcher = dispatcher.request_dispatcher
        threading.Thread.__init__(self)
        self.daemon = True
        self.host, self.port = host, port
        self.use_ssl = use_ssl
        self.certfile, self.keyfile = certfile, keyfile
        self.lock = threading.Lock()

    def run(self):
        """Build the threaded server, register the RPC names, serve forever."""
        # see http://code.google.com/p/jsonrpclib/
        from SocketServer import ThreadingMixIn
        address = (self.host, self.port)
        if not self.use_ssl:
            class StratumThreadedServer(ThreadingMixIn, StratumHTTPServer):
                pass
            self.server = StratumThreadedServer(address)
            print_log("HTTP server started.")
        else:
            class StratumThreadedServer(ThreadingMixIn, StratumHTTPSSLServer):
                pass
            self.server = StratumThreadedServer(address, self.certfile, self.keyfile)
            print_log("HTTPS server started.")

        server = self.server
        server.dispatcher = self.dispatcher
        for rpc_name in ('server.stop', 'server.info'):
            server.register_function(None, rpc_name)
        server.serve_forever()
404