Do not perform the SSL handshake on connect
[electrum-server.git] / transports / stratum_tcp.py
1 import json
2 import Queue as queue
3 import socket
4 import threading
5 import time
6 import traceback, sys
7
8 from processor import Session, Dispatcher
9 from utils import print_log
10
11
class TcpSession(Session):
    """Session bound to one accepted client socket (plain TCP or SSL).

    For SSL the socket is wrapped here with ``do_handshake_on_connect=False``,
    so constructing the session never waits on the client; the handshake is
    carried out later by an explicit call to ``do_handshake()``.
    """

    def __init__(self, connection, address, use_ssl, ssl_certfile, ssl_keyfile):
        Session.__init__(self)
        self.use_ssl = use_ssl
        if not use_ssl:
            self._connection = connection
        else:
            import ssl
            wrap_options = dict(
                server_side=True,
                certfile=ssl_certfile,
                keyfile=ssl_keyfile,
                ssl_version=ssl.PROTOCOL_SSLv23,
                # Deferred on purpose; see do_handshake().
                do_handshake_on_connect=False,
            )
            self._connection = ssl.wrap_socket(connection, **wrap_options)

        # address is the pair returned by accept(); only the host part is kept.
        self.address = address[0]
        self.name = "SSL " if use_ssl else "TCP "

    def do_handshake(self):
        """Complete the deferred TLS handshake; does nothing for plain TCP."""
        if not self.use_ssl:
            return
        self._connection.do_handshake()

    def connection(self):
        """Return the underlying socket.

        Raises:
            Exception: if the session has already been stopped.
        """
        if not self.stopped():
            return self._connection
        raise Exception("Session was stopped")

    def stop(self):
        """Shut down and close the socket, then mark the session stopped.

        Returns immediately if the session is already stopped.  Errors from
        ``shutdown`` are deliberately ignored so that ``close`` always runs.
        """
        if self.stopped():
            return

        try:
            self._connection.shutdown(socket.SHUT_RDWR)
        except:
            # Best effort only; closing below must still happen.
            pass

        self._connection.close()
        with self.lock:
            self._stopped = True

    def send_response(self, response):
        """Send *response* to the client as one newline-terminated JSON line.

        Any failure while obtaining the socket or sending (including the
        session having been stopped) stops the session instead of raising.
        """
        payload = json.dumps(response) + "\n"
        # NOTE(review): this may run while another thread calls stop(); the
        # catch-all below turns such a race into a (repeated, harmless) stop.
        try:
            sock = self.connection()
            while payload:
                sent = sock.send(payload)
                payload = payload[sent:]
        except:
            self.stop()
69
70
class TcpClientRequestor(threading.Thread):
    """Per-connection reader thread.

    Completes the session's (possibly deferred) handshake, then reads
    newline-delimited JSON requests from the session's socket and forwards
    them to the dispatcher until the client disconnects, sends ``quit``, or
    the server is stopped.
    """

    def __init__(self, dispatcher, session):
        self.shared = dispatcher.shared
        self.dispatcher = dispatcher
        # Received data not yet terminated by a newline.
        self.message = ""
        self.session = session
        threading.Thread.__init__(self)

    def run(self):
        """Serve the session until it ends or the shared stop flag is set."""
        try:
            self.session.do_handshake()
        except Exception:
            # The session is already registered with the dispatcher and its
            # socket is open; stop it so a failed handshake does not leak them.
            self.session.stop()
            return

        while not self.shared.stopped():
            if not self.update():
                break

            # Timestamp of the most recent data received from the client.
            self.session.time = time.time()

            while self.parse():
                pass

    def update(self):
        """Append newly received data to the buffer.

        Returns:
            False (after stopping the session) if the peer closed the
            connection or the read failed, True otherwise.
        """
        data = self.receive()
        if not data:
            self.session.stop()
            return False

        self.message += data
        return True

    def receive(self):
        """Read up to 2048 bytes; any error is reported as '' (disconnect)."""
        try:
            return self.session.connection().recv(2048)
        except Exception:
            return ''

    def parse(self):
        """Consume one newline-terminated line from the buffer.

        Returns:
            True if a line was consumed and parsing may continue; False if no
            complete line is buffered or the client sent ``quit`` (in which
            case the session is stopped).
        """
        newline_pos = self.message.find('\n')
        if newline_pos == -1:
            return False

        raw_command = self.message[:newline_pos].strip()
        self.message = self.message[newline_pos + 1:]
        if raw_command == 'quit':
            self.session.stop()
            return False

        # NOTE(review): error responses are pushed without the session, while
        # push_request receives it; confirm push_response can route these
        # errors back to this client.
        try:
            command = json.loads(raw_command)
        except Exception:
            self.dispatcher.push_response({"error": "bad JSON", "request": raw_command})
            return True

        # A request must be a JSON object carrying both 'id' and 'method'.
        # Non-objects (lists, numbers, strings) are rejected here as well;
        # indexing them by key would raise TypeError and kill this thread.
        if not isinstance(command, dict) or 'id' not in command or 'method' not in command:
            self.dispatcher.push_response({"error": "syntax error", "request": raw_command})
        else:
            self.dispatcher.push_request(self.session, command)

        return True
140
141
class TcpServer(threading.Thread):
    """Daemon thread running the accept loop of the Stratum TCP/SSL transport.

    Every accepted socket is wrapped in a ``TcpSession``, registered with the
    request dispatcher, and served by its own ``TcpClientRequestor`` thread,
    so a slow client (e.g. one stalling the SSL handshake) does not block
    further accepts.
    """

    def __init__(self, dispatcher, host, port, use_ssl, ssl_certfile, ssl_keyfile):
        """
        :param dispatcher: object exposing ``shared`` (holder of the stop
            flag) and ``request_dispatcher`` (receives the new sessions).
        :param host: interface address to bind.
        :param port: TCP port to bind.
        :param use_ssl: wrap client sockets in SSL when true.
        :param ssl_certfile: certificate file passed to ``TcpSession``.
        :param ssl_keyfile: private key file passed to ``TcpSession``.
        """
        self.shared = dispatcher.shared
        self.dispatcher = dispatcher.request_dispatcher
        threading.Thread.__init__(self)
        self.daemon = True
        self.host = host
        self.port = port
        # Not used within this class; kept for compatibility with external users.
        self.lock = threading.Lock()
        self.use_ssl = use_ssl
        self.ssl_keyfile = ssl_keyfile
        self.ssl_certfile = ssl_certfile

    def run(self):
        """Bind the listening socket and accept clients until shutdown.

        Per-connection failures are logged and the loop keeps going.  The
        listening socket is closed when the loop exits or when an unexpected
        error (such as a failed bind) propagates out of this method.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self.host, self.port))
            sock.listen(5)
            # Announce only once the port is actually bound and listening,
            # so a bind failure is not preceded by a false success message.
            print_log(("SSL" if self.use_ssl else "TCP") + " server started on port %d" % self.port)

            # accept() blocks, so the stop flag is re-checked only after the
            # next connection (or accept error) arrives.
            while not self.shared.stopped():
                try:
                    connection, address = sock.accept()
                except Exception:
                    traceback.print_exc(file=sys.stdout)
                    time.sleep(0.1)
                    continue

                self._start_session(connection, address)
        finally:
            sock.close()

    def _start_session(self, connection, address):
        """Wrap an accepted socket in a session and start its reader thread."""
        try:
            session = TcpSession(connection, address, use_ssl=self.use_ssl,
                                 ssl_certfile=self.ssl_certfile,
                                 ssl_keyfile=self.ssl_keyfile)
        except Exception as e:
            print_log("cannot start TCP session", str(e), address)
            connection.close()
            time.sleep(0.1)
            return

        self.dispatcher.add_session(session)
        self.dispatcher.collect_garbage()
        client_req = TcpClientRequestor(self.dispatcher, session)
        try:
            client_req.start()
        except Exception as e:
            # e.g. the process cannot create more threads.  Without this, the
            # exception would end run() and the server would stop accepting.
            print_log("cannot start TCP session", str(e), address)
            session.stop()
            time.sleep(0.1)