server-side support for ssl
[electrum-server.git] / transports / stratum_http.py
1 #!/usr/bin/env python
2 # Copyright(C) 2012 thomasv@gitorious
3
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful, but
10 # WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/agpl.html>.
17
18 import jsonrpclib
19 from jsonrpclib import Fault
20 from jsonrpclib.jsonrpc import USE_UNIX_SOCKETS
21 import SimpleXMLRPCServer
22 import SocketServer
23 import socket
24 import logging
25 import os, time
26 import types
27 import traceback
28 import sys, threading
29
30 from OpenSSL import SSL
31
32 try:
33     import fcntl
34 except ImportError:
35     # For Windows
36     fcntl = None
37
38 import json
39
40
41 """
42 sessions are identified with cookies
43  - each session has a buffer of responses to requests
44
45
46 from the processor point of view: 
47  - the user only defines process() ; the rest is session management.  thus sessions should not belong to processor
48
49 """
50
51
52 from processor import random_string
53
54
def get_version(request):
    """Guess the JSON-RPC protocol version of a request.

    ``request`` must be a dict. A 'jsonrpc' member marks a 2.0 request; an
    'id' member with no 'jsonrpc' marks a 1.0 request. Anything else yields
    None, which callers treat as an invalid request.
    """
    if 'jsonrpc' in request:
        return 2.0
    return 1.0 if 'id' in request else None
62     
def validate_request(request):
    """Check that ``request`` is a well-formed JSON-RPC request object.

    Returns True when the request is acceptable. Otherwise returns (does not
    raise) a jsonrpclib ``Fault`` with code -32600; the fault carries the
    request id when one could be read.

    Side effect: a request with no 'params' member gets an empty list.
    """
    # Exact type check on purpose: dict subclasses are rejected too.
    if type(request) is not dict:
        return Fault(-32600, 'Request must be {}, not %s.' % type(request))

    rpcid = request.get('id', None)
    if not get_version(request):
        return Fault(-32600, 'Request %s invalid.' % request, rpcid=rpcid)

    request.setdefault('params', [])
    method = request.get('method', None)
    params = request.get('params')

    method_ok = bool(method) and type(method) in types.StringTypes
    params_ok = type(params) in (list, dict, tuple)
    if method_ok and params_ok:
        return True
    return Fault(-32600, 'Invalid request parameters or method.', rpcid=rpcid)
85
class StratumJSONRPCDispatcher(SimpleXMLRPCServer.SimpleXMLRPCDispatcher):
    """JSON-RPC dispatch mixin shared by the Stratum HTTP/HTTPS servers.

    Requests are forwarded to ``self.dispatcher``, which is not set here; it
    must be assigned after construction (HttpServer.run does this). Responses
    are produced asynchronously and collected from each session's
    ``pending_responses`` queue.
    """

    def __init__(self, encoding=None):
        SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self,
                                        allow_none=True,
                                        encoding=encoding)

    def _marshaled_dispatch(self, session_id, data, dispatch_method=None):
        """Dispatch the JSON-RPC payload ``data`` on behalf of ``session_id``.

        ``data`` may hold a single request or a batch (list). Each valid entry
        is handed to ``self.dispatcher.do_dispatch``. The session queue is then
        drained, so the reply holds validation faults plus whatever responses
        are ready right now. Later responses are returned on a subsequent poll.

        Returns a JSON string: a single object, a ``[...]`` array when there
        are several items, or '' when there is nothing to send.
        ``dispatch_method`` is unused and kept only for signature compatibility.
        """
        try:
            request = jsonrpclib.loads(data)
        except Exception as e:
            fault = Fault(-32700, 'Request %s invalid. (%s)' % (data, e))
            return fault.response()

        session = self.dispatcher.get_session_by_address(session_id)
        if not session:
            # NOTE(review): plain text, not JSON; JSON clients cannot parse it.
            return 'Error: session not found'
        # Refresh the activity timestamp used to expire idle sessions.
        session.time = time.time()

        if type(request) is not list:
            request = [request]

        responses = []
        for req_entry in request:
            result = validate_request(req_entry)
            if type(result) is Fault:
                responses.append(result.response())
                continue

            self.dispatcher.do_dispatch(session, req_entry)

            # server.stop short-circuits the batch; remaining entries are
            # not dispatched and queued responses are left in place.
            if req_entry['method'] == 'server.stop':
                return json.dumps({'result': 'ok'})

        for item in self.poll_session(session):
            responses.append(json.dumps(item))

        if len(responses) > 1:
            return '[%s]' % ','.join(responses)
        if len(responses) == 1:
            return responses[0]
        return ''

    def create_session(self):
        """Register a new HttpSession with the dispatcher; return its id."""
        session_id = random_string(10)
        session = HttpSession(session_id)
        self.dispatcher.add_session(session)
        return session_id

    def poll_session(self, session):
        """Remove and return every response currently queued for ``session``.

        Uses get_nowait() rather than checking empty() before get(). With the
        threaded server, two requests carrying the same cookie can drain the
        queue concurrently. Between empty() returning False and get(), the
        other thread may take the last item, leaving this get() blocked
        forever.
        """
        q = session.pending_responses
        responses = []
        while True:
            try:
                responses.append(q.get_nowait())
            except Queue.Empty:
                break
        return responses
150
151
152
153
class StratumJSONRPCRequestHandler(
        SimpleXMLRPCServer.SimpleXMLRPCRequestHandler):
    """HTTP front end for the Stratum JSON-RPC dispatcher.

    Both verbs go through ``server._marshaled_dispatch``:

    - POST carries JSON-RPC request(s) in its body.
    - GET sends an empty batch, so it only collects responses already queued
      for the session.

    Sessions are tracked with a ``SESSION=<id>`` cookie. A request without
    one gets a freshly created session, and the cookie is (re)sent on every
    reply.
    """

    def do_GET(self):
        """Return the responses queued for the caller's session."""
        if not self.is_rpc_path_valid():
            self.report_404()
            return
        self._serve(lambda: json.dumps([]))

    def do_POST(self):
        """Dispatch the JSON-RPC request(s) in the body and return responses."""
        if not self.is_rpc_path_valid():
            self.report_404()
            return
        self._serve(self._read_body)

    def _read_body(self):
        """Read the request body (Content-Length bytes) in bounded chunks.

        Raises KeyError/ValueError for a missing, malformed or negative
        Content-Length; the caller turns that into a 500 response.
        """
        max_chunk_size = 10 * 1024 * 1024
        size_remaining = int(self.headers["content-length"])
        if size_remaining < 0:
            raise ValueError('negative content-length: %d' % size_remaining)
        chunks = []
        while size_remaining:
            chunk = self.rfile.read(min(size_remaining, max_chunk_size))
            if not chunk:
                # Peer closed early; stop instead of spinning forever. The
                # truncated payload then fails JSON parsing with a fault.
                break
            chunks.append(chunk)
            size_remaining -= len(chunk)
        return ''.join(chunks)

    def _get_session_id(self):
        """Return the SESSION cookie value, creating a session if absent."""
        cookie = self.headers.get('cookie')
        if cookie:
            # A Cookie header may carry several "name=value" pairs.
            for pair in cookie.split(';'):
                pair = pair.strip()
                if pair.startswith('SESSION='):
                    return pair[len('SESSION='):]
        return self.server.create_session()

    def _serve(self, read_data):
        """Run one request/response cycle shared by GET and POST.

        ``read_data`` is a callable returning the JSON payload to dispatch.
        Any exception becomes a 500 reply carrying a JSON-RPC -32603 fault.
        """
        # Bound before the try so the error path can always test it.
        session_id = None
        try:
            data = read_data()
            session_id = self._get_session_id()
            response = self.server._marshaled_dispatch(session_id, data)
            self.send_response(200)
        except Exception:
            self.send_response(500)
            err_lines = traceback.format_exc().splitlines()
            trace_string = '%s | %s' % (err_lines[-3], err_lines[-1])
            fault = jsonrpclib.Fault(-32603, 'Server error: %s' % trace_string)
            response = fault.response()
            print("500 %s" % trace_string)
        if response is None:
            response = ''

        if session_id:
            self.send_header("Set-Cookie", "SESSION=%s" % session_id)
        self.send_header("Content-type", "application/json-rpc")
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Content-length", str(len(response)))
        self.end_headers()
        self.wfile.write(response)
        self.wfile.flush()
        self._shutdown_connection()

    def _shutdown_connection(self):
        """Shut down the sending side once the response has been written."""
        try:
            self.connection.shutdown(socket.SHUT_WR)
        except TypeError:
            # pyOpenSSL's Connection.shutdown() accepts no arguments (it sends
            # the TLS close_notify alert), so retry without the "how" flag.
            self.connection.shutdown()
245
246
247
248
class SSLTCPServer(SocketServer.TCPServer):
    """TCPServer whose listening socket is a pyOpenSSL ``SSL.Connection``.

    ``TCPServer.__init__`` is bypassed on purpose because it would create a
    plain socket. This constructor builds the TLS-wrapped socket itself, then
    optionally binds and activates it.
    """

    def __init__(self, server_address, certfile, keyfile, RequestHandlerClass, bind_and_activate=True):
        SocketServer.BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.certfile = certfile
        self.keyfile = keyfile

        # SSLv23_METHOD negotiates the best protocol both peers support.
        # SSLv2/SSLv3 are disabled as insecure; modern OpenSSL builds no
        # longer provide SSLv3_METHOD at all.
        ctx = SSL.Context(SSL.SSLv23_METHOD)
        ctx.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
        ctx.use_certificate_file(certfile)
        ctx.use_privatekey_file(keyfile)
        # Fail at start-up, not on the first handshake, if key and cert differ.
        ctx.check_privatekey()

        self.socket = SSL.Connection(ctx, socket.socket(self.address_family, self.socket_type))
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except Exception:
                # Same cleanup TCPServer.__init__ does: don't leak the socket.
                self.server_close()
                raise

    def shutdown_request(self, request):
        """Send the TLS close_notify, then always close the request socket."""
        try:
            request.shutdown()
        except (SSL.Error, socket.error):
            # The peer may already be gone; closing is still required.
            pass
        self.close_request(request)
267
268
class StratumHTTPServer(SocketServer.TCPServer, StratumJSONRPCDispatcher):
    """Plain-HTTP TCP server that dispatches Stratum JSON-RPC requests."""

    allow_reuse_address = True

    def __init__(self, addr, requestHandler=StratumJSONRPCRequestHandler,
                 logRequests=False, encoding=None, bind_and_activate=True,
                 address_family=socket.AF_INET):
        self.logRequests = logRequests
        StratumJSONRPCDispatcher.__init__(self, encoding)
        # Must be set before TCPServer.__init__, which creates the socket
        # from self.address_family.
        self.address_family = address_family
        if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
            # Unix sockets can't be bound if they already exist in the
            # filesystem. The convention of e.g. X11 is to unlink
            # before binding again.
            if os.path.exists(addr):
                try:
                    os.unlink(addr)
                except OSError:
                    logging.warning("Could not unlink socket %s", addr)

        SocketServer.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)

        # Mark the listening socket close-on-exec so child processes
        # don't inherit it (skipped where fcntl is unavailable).
        if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
            flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
            flags |= fcntl.FD_CLOEXEC
            fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
298
299
class StratumHTTPSSLServer(SSLTCPServer, StratumJSONRPCDispatcher):
    """HTTPS (pyOpenSSL) variant of the Stratum JSON-RPC HTTP server."""

    allow_reuse_address = True

    def __init__(self, addr, certfile, keyfile,
                 requestHandler=StratumJSONRPCRequestHandler,
                 logRequests=False, encoding=None, bind_and_activate=True,
                 address_family=socket.AF_INET):
        self.logRequests = logRequests
        StratumJSONRPCDispatcher.__init__(self, encoding)
        # Must be set before SSLTCPServer.__init__, which creates the socket
        # from self.address_family.
        self.address_family = address_family
        if USE_UNIX_SOCKETS and address_family == socket.AF_UNIX:
            # Unix sockets can't be bound if they already exist in the
            # filesystem. The convention of e.g. X11 is to unlink
            # before binding again.
            if os.path.exists(addr):
                try:
                    os.unlink(addr)
                except OSError:
                    logging.warning("Could not unlink socket %s", addr)

        SSLTCPServer.__init__(self, addr, certfile, keyfile, requestHandler, bind_and_activate)

        # Mark the listening socket close-on-exec so child processes
        # don't inherit it (skipped where fcntl is unavailable).
        if fcntl is not None and hasattr(fcntl, 'FD_CLOEXEC'):
            flags = fcntl.fcntl(self.fileno(), fcntl.F_GETFD)
            flags |= fcntl.FD_CLOEXEC
            fcntl.fcntl(self.fileno(), fcntl.F_SETFD, flags)
331
332
333
334
335
336
337 from processor import Session
338 import Queue
339
class HttpSession(Session):
    """Session for HTTP clients: responses wait in a queue until polled.

    HTTP clients have no persistent connection, so responses are queued
    here. They are drained and JSON-encoded when the client next sends a
    GET or POST.
    """

    # Seconds of inactivity (since self.time) after which stopped() reports
    # the session as dead. Override per class or instance if needed.
    timeout = 60

    def __init__(self, session_id):
        Session.__init__(self)
        self.pending_responses = Queue.Queue()
        # The cookie value doubles as the session's lookup address.
        self.address = session_id
        self.name = "HTTP"

    def send_response(self, response):
        """Queue ``response`` (un-serialized) for the client's next poll."""
        # Serialization happens when the queue is drained, so there is no need
        # to json.dumps() here (the old code did so and discarded the result).
        self.pending_responses.put(response)

    def stopped(self):
        """Return True once the session has stopped or has been idle too long."""
        # NOTE(review): self.lock and self._stopped are presumably provided by
        # Session; self.time is refreshed on every dispatch. Confirm that
        # Session.__init__ initialises self.time before the first check.
        with self.lock:
            if time.time() - self.time > self.timeout:
                self._stopped = True
            return self._stopped
357
class HttpServer(threading.Thread):
    """Daemon thread hosting the Stratum JSON-RPC server over HTTP or HTTPS.

    ``dispatcher`` must expose ``shared`` and ``request_dispatcher``; the
    latter is what the HTTP server forwards requests to.
    """

    def __init__(self, dispatcher, host, port, use_ssl, certfile, keyfile):
        self.shared = dispatcher.shared
        self.dispatcher = dispatcher.request_dispatcher
        threading.Thread.__init__(self)
        self.daemon = True
        self.host, self.port = host, port
        self.use_ssl = use_ssl
        self.certfile, self.keyfile = certfile, keyfile
        self.lock = threading.Lock()

    def run(self):
        """Create the threaded server, attach the dispatcher, serve forever."""
        # see http://code.google.com/p/jsonrpclib/
        from SocketServer import ThreadingMixIn

        address = (self.host, self.port)
        if not self.use_ssl:
            class StratumThreadedServer(ThreadingMixIn, StratumHTTPServer):
                pass
            server = StratumThreadedServer(address)
            banner = "HTTP server started."
        else:
            class StratumThreadedServer(ThreadingMixIn, StratumHTTPSSLServer):
                pass
            server = StratumThreadedServer(address, self.certfile, self.keyfile)
            banner = "HTTPS server started."
        self.server = server
        print(banner)

        server.dispatcher = self.dispatcher
        for method_name in ('server.stop', 'server.info'):
            server.register_function(None, method_name)

        server.serve_forever()
389