don't close session twice; get connection inside try statement
[electrum-server.git] / server.py
1 #!/usr/bin/env python
2 # Copyright(C) 2012 thomasv@gitorious
3
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as
6 # published by the Free Software Foundation, either version 3 of the
7 # License, or (at your option) any later version.
8 #
9 # This program is distributed in the hope that it will be useful, but
10 # WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 # Affero General Public License for more details.
13 #
14 # You should have received a copy of the GNU Affero General Public
15 # License along with this program.  If not, see
16 # <http://www.gnu.org/licenses/agpl.html>.
17
18 """
19 Todo:
20    * server should check and return bitcoind status..
21    * improve txpoint sorting
22    * command to check cache
23
24  mempool transactions do not need to be added to the database; it slows it down
25 """
26
27 import abe_backend
28
29
30
31
32 import time, json, socket, operator, thread, ast, sys, re, traceback
33 import ConfigParser
34 from json import dumps, loads
35 import urllib
36
37
38 config = ConfigParser.ConfigParser()
39 # set some defaults, which will be overwritten by the config file
40 config.add_section('server')
41 config.set('server','banner', 'Welcome to Electrum!')
42 config.set('server', 'host', 'localhost')
43 config.set('server', 'port', '50000')
44 config.set('server', 'password', '')
45 config.set('server', 'irc', 'yes')
46 config.set('server', 'ircname', 'Electrum server')
47 config.add_section('database')
48 config.set('database', 'type', 'psycopg2')
49 config.set('database', 'database', 'abe')
50
51 try:
52     f = open('/etc/electrum.conf','r')
53     config.readfp(f)
54     f.close()
55 except:
56     print "Could not read electrum.conf. I will use the default values."
57
58 try:
59     f = open('/etc/electrum.banner','r')
60     config.set('server','banner', f.read())
61     f.close()
62 except:
63     pass
64
65
66 password = config.get('server','password')
67
68 stopping = False
69 block_number = -1
70 sessions = {}
71 sessions_sub_numblocks = {} # sessions that have subscribed to the service
72
73 m_sessions = [{}] # served by http
74
75 peer_list = {}
76
77 from Queue import Queue
78 input_queue = Queue()
79 output_queue = Queue()
80
81
82
83
84 def random_string(N):
85     import random, string
86     return ''.join(random.choice(string.ascii_uppercase + string.digits) for x in range(N))
87
88     
89
90 def cmd_stop(_,__,pw):
91     global stopping
92     if password == pw:
93         stopping = True
94         return 'ok'
95     else:
96         return 'wrong password'
97
98 def cmd_load(_,__,pw):
99     if password == pw:
100         return repr( len(sessions) )
101     else:
102         return 'wrong password'
103
104
105
106
107
108 def modified_addresses(a_session):
109     #t1 = time.time()
110     import copy
111     session = copy.deepcopy(a_session)
112     addresses = session['addresses']
113     session['last_time'] = time.time()
114     ret = {}
115     k = 0
116     for addr in addresses:
117         status = store.get_status( addr )
118         msg_id, last_status = addresses.get( addr )
119         if last_status != status:
120             addresses[addr] = msg_id, status
121             ret[addr] = status
122
123     #t2 = time.time() - t1 
124     #if t2 > 10: print "high load:", session_id, "%d/%d"%(k,len(addresses)), t2
125     return ret, addresses
126
127
128 def poll_session(session_id): 
129     # native
130     session = sessions.get(session_id)
131     if session is None:
132         print time.asctime(), "session not found", session_id
133         return -1, {}
134     else:
135         sessions[session_id]['last_time'] = time.time()
136         ret, addresses = modified_addresses(session)
137         if ret: sessions[session_id]['addresses'] = addresses
138         return repr( (block_number,ret))
139
140
141 def poll_session_json(session_id, message_id):
142     session = m_sessions[0].get(session_id)
143     if session is None:
144         raise BaseException("session not found %s"%session_id)
145     else:
146         m_sessions[0][session_id]['last_time'] = time.time()
147         out = []
148         ret, addresses = modified_addresses(session)
149         if ret: 
150             m_sessions[0][session_id]['addresses'] = addresses
151             for addr in ret:
152                 msg_id, status = addresses[addr]
153                 out.append(  { 'id':msg_id, 'result':status } )
154
155         msg_id, last_nb = session.get('numblocks')
156         if last_nb:
157             if last_nb != block_number:
158                 m_sessions[0][session_id]['numblocks'] = msg_id, block_number
159                 out.append( {'id':msg_id, 'result':block_number} )
160
161         return out
162
163
164 def do_update_address(addr):
165     # an address was involved in a transaction; we check if it was subscribed to in a session
166     # the address can be subscribed in several sessions; the cache should ensure that we don't do redundant requests
167
168     for session_id in sessions.keys():
169         session = sessions[session_id]
170         if session.get('type') != 'persistent': continue
171         addresses = session['addresses'].keys()
172
173         if addr in addresses:
174             status = store.get_status( addr )
175             message_id, last_status = session['addresses'][addr]
176             if last_status != status:
177                 #print "sending new status for %s:"%addr, status
178                 send_status(session_id,message_id,addr,status)
179                 sessions[session_id]['addresses'][addr] = (message_id,status)
180
181
182
183 def send_numblocks(session_id):
184     message_id = sessions_sub_numblocks[session_id]
185     out = json.dumps( {'id':message_id, 'result':block_number} )
186     output_queue.put((session_id, out))
187
188 def send_status(session_id, message_id, address, status):
189     out = json.dumps( { 'id':message_id, 'result':status } )
190     output_queue.put((session_id, out))
191
192 def address_get_history_json(_,message_id,address):
193     return store.get_history(address)
194
195 def subscribe_to_numblocks(session_id, message_id):
196     sessions_sub_numblocks[session_id] = message_id
197     send_numblocks(session_id)
198
199 def subscribe_to_numblocks_json(session_id, message_id):
200     global m_sessions
201     m_sessions[0][session_id]['numblocks'] = message_id,block_number
202     return block_number
203
204 def subscribe_to_address(session_id, message_id, address):
205     status = store.get_status(address)
206     sessions[session_id]['addresses'][address] = (message_id, status)
207     sessions[session_id]['last_time'] = time.time()
208     send_status(session_id, message_id, address, status)
209
210 def add_address_to_session_json(session_id, message_id, address):
211     global m_sessions
212     sessions = m_sessions[0]
213     status = store.get_status(address)
214     sessions[session_id]['addresses'][address] = (message_id, status)
215     sessions[session_id]['last_time'] = time.time()
216     m_sessions[0] = sessions
217     return status
218
219 def add_address_to_session(session_id, address):
220     status = store.get_status(address)
221     sessions[session_id]['addresses'][address] = ("", status)
222     sessions[session_id]['last_time'] = time.time()
223     return status
224
225 def new_session(version, addresses):
226     session_id = random_string(10)
227     sessions[session_id] = { 'addresses':{}, 'version':version }
228     for a in addresses:
229         sessions[session_id]['addresses'][a] = ('','')
230     out = repr( (session_id, config.get('server','banner').replace('\\n','\n') ) )
231     sessions[session_id]['last_time'] = time.time()
232     return out
233
234
235 def client_version_json(session_id, _, version):
236     global m_sessions
237     sessions = m_sessions[0]
238     sessions[session_id]['version'] = version
239     m_sessions[0] = sessions
240
241 def create_session_json(_, __):
242     sessions = m_sessions[0]
243     session_id = random_string(10)
244     print "creating session", session_id
245     sessions[session_id] = { 'addresses':{}, 'numblocks':('','') }
246     sessions[session_id]['last_time'] = time.time()
247     m_sessions[0] = sessions
248     return session_id
249
250
251
252 def get_banner(_,__):
253     return config.get('server','banner').replace('\\n','\n')
254
255 def update_session(session_id,addresses):
256     """deprecated in 0.42"""
257     sessions[session_id]['addresses'] = {}
258     for a in addresses:
259         sessions[session_id]['addresses'][a] = ''
260     sessions[session_id]['last_time'] = time.time()
261     return 'ok'
262
263 def native_server_thread():
264     s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
265     s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
266     s.bind((config.get('server','host'), config.getint('server','port')))
267     s.listen(1)
268     while not stopping:
269         conn, addr = s.accept()
270         try:
271             thread.start_new_thread(native_client_thread, (addr, conn,))
272         except:
273             # can't start new thread if there is no memory..
274             traceback.print_exc(file=sys.stdout)
275
276
277 def native_client_thread(ipaddr,conn):
278     #print "client thread", ipaddr
279     try:
280         ipaddr = ipaddr[0]
281         msg = ''
282         while 1:
283             d = conn.recv(1024)
284             msg += d
285             if not d: 
286                 break
287             if '#' in msg:
288                 msg = msg.split('#', 1)[0]
289                 break
290         try:
291             cmd, data = ast.literal_eval(msg)
292         except:
293             print "syntax error", repr(msg), ipaddr
294             conn.close()
295             return
296
297         out = do_command(cmd, data, ipaddr)
298         if out:
299             #print ipaddr, cmd, len(out)
300             try:
301                 conn.send(out)
302             except:
303                 print "error, could not send"
304
305     finally:
306         conn.close()
307
308
309 def timestr():
310     return time.strftime("[%d/%m/%Y-%H:%M:%S]")
311
312 # used by the native handler
313 def do_command(cmd, data, ipaddr):
314
315     if cmd=='b':
316         out = "%d"%block_number
317
318     elif cmd in ['session','new_session']:
319         try:
320             if cmd == 'session':
321                 addresses = ast.literal_eval(data)
322                 version = "old"
323             else:
324                 version, addresses = ast.literal_eval(data)
325                 if version[0]=="0": version = "v" + version
326         except:
327             print "error", data
328             return None
329         print timestr(), "new session", ipaddr, addresses[0] if addresses else addresses, len(addresses), version
330         out = new_session(version, addresses)
331
332     elif cmd=='address.subscribe':
333         try:
334             session_id, addr = ast.literal_eval(data)
335         except:
336             traceback.print_exc(file=sys.stdout)
337             print data
338             return None
339         out = add_address_to_session(session_id,addr)
340
341     elif cmd=='update_session':
342         try:
343             session_id, addresses = ast.literal_eval(data)
344         except:
345             traceback.print_exc(file=sys.stdout)
346             return None
347         print timestr(), "update session", ipaddr, addresses[0] if addresses else addresses, len(addresses)
348         out = update_session(session_id,addresses)
349             
350     elif cmd=='poll': 
351         out = poll_session(data)
352
353     elif cmd == 'h': 
354         # history
355         address = data
356         out = repr( store.get_history( address ) )
357
358     elif cmd == 'load': 
359         out = cmd_load(None,None,data)
360
361     elif cmd =='tx':
362         out = store.send_tx(data)
363         print timestr(), "sent tx:", ipaddr, out
364
365     elif cmd == 'stop':
366         out = cmd_stop(data)
367
368     elif cmd == 'peers':
369         out = repr(peer_list.values())
370
371     else:
372         out = None
373
374     return out
375
376
377 def clean_session_thread():
378     while not stopping:
379         time.sleep(30)
380         t = time.time()
381         for k,s in sessions.items():
382             if s.get('type') == 'persistent': continue
383             t0 = s['last_time']
384             if t - t0 > 5*60:
385                 sessions.pop(k)
386                 print "lost session", k
387             
388
389 ####################################################################
390
391
392 import stratum
393
394 class AbeProcessor(stratum.Processor):
395     def process(self,session,request):
396         message_id = request['id']
397         method = request['method']
398         params = request.get('params',[])
399         #print request
400
401         result = ''
402         if method == 'numblocks.subscribe':
403             session.subscribe_to_numblocks(message_id)
404             result = block_number
405         elif method == 'address.subscribe':
406             address = params[0]
407             status = store.get_status(address)
408             session.subscribe_to_address(address,message_id,status)
409             result = status
410         elif method == 'client.version':
411             session.version = params[0]
412         elif method == 'server.banner':
413             result = config.get('server','banner').replace('\\n','\n')
414         elif method == 'server.peers':
415             result = peer_list.values()
416         elif method == 'address.get_history':
417             address = params[0]
418             result = store.get_history( address ) 
419         elif method == 'transaction.broadcast':
420             txo = store.send_tx(params[0])
421             print "sent tx:", txo
422             result = txo 
423         else:
424             print "unknown method", request
425
426         if result!='':
427             response = { 'id':message_id, 'result':result }
428             self.push_response(session,response)
429
430
431
432 ####################################################################
433
434
435
436
437
438 def irc_thread():
439     global peer_list
440     NICK = 'E_'+random_string(10)
441     while not stopping:
442         try:
443             s = socket.socket()
444             s.connect(('irc.freenode.net', 6667))
445             s.send('USER electrum 0 * :'+config.get('server','host')+' '+config.get('server','ircname')+'\n')
446             s.send('NICK '+NICK+'\n')
447             s.send('JOIN #electrum\n')
448             sf = s.makefile('r', 0)
449             t = 0
450             while not stopping:
451                 line = sf.readline()
452                 line = line.rstrip('\r\n')
453                 line = line.split()
454                 if line[0]=='PING': 
455                     s.send('PONG '+line[1]+'\n')
456                 elif '353' in line: # answer to /names
457                     k = line.index('353')
458                     for item in line[k+1:]:
459                         if item[0:2] == 'E_':
460                             s.send('WHO %s\n'%item)
461                 elif '352' in line: # answer to /who
462                     # warning: this is a horrible hack which apparently works
463                     k = line.index('352')
464                     ip = line[k+4]
465                     ip = socket.gethostbyname(ip)
466                     name = line[k+6]
467                     host = line[k+9]
468                     peer_list[name] = (ip,host)
469                 if time.time() - t > 5*60:
470                     s.send('NAMES #electrum\n')
471                     t = time.time()
472                     peer_list = {}
473         except:
474             traceback.print_exc(file=sys.stdout)
475         finally:
476             sf.close()
477             s.close()
478
479
480 def get_peers_json(_,__):
481     return peer_list.values()
482
483 def http_server_thread():
484     # see http://code.google.com/p/jsonrpclib/
485     from SocketServer import ThreadingMixIn
486     from StratumJSONRPCServer import StratumJSONRPCServer
487     class StratumThreadedJSONRPCServer(ThreadingMixIn, StratumJSONRPCServer): pass
488     server = StratumThreadedJSONRPCServer(( config.get('server','host'), 8081))
489     server.register_function(get_peers_json, 'server.peers')
490     server.register_function(cmd_stop, 'stop')
491     server.register_function(cmd_load, 'load')
492     server.register_function(get_banner, 'server.banner')
493     server.register_function(lambda a,b,c: store.send_tx(c), 'transaction.broadcast')
494     server.register_function(address_get_history_json, 'address.get_history')
495     server.register_function(add_address_to_session_json, 'address.subscribe')
496     server.register_function(subscribe_to_numblocks_json, 'numblocks.subscribe')
497     server.register_function(client_version_json, 'client.version')
498     server.register_function(create_session_json, 'session.create')   # internal message (not part of protocol)
499     server.register_function(poll_session_json, 'session.poll')       # internal message (not part of protocol)
500     server.serve_forever()
501
502
503 if __name__ == '__main__':
504
505     if len(sys.argv)>1:
506         import jsonrpclib
507         server = jsonrpclib.Server('http://%s:8081'%config.get('server','host'))
508         cmd = sys.argv[1]
509         if cmd == 'load':
510             out = server.load(password)
511         elif cmd == 'peers':
512             out = server.server.peers()
513         elif cmd == 'stop':
514             out = server.stop(password)
515         elif cmd == 'clear_cache':
516             out = server.clear_cache(password)
517         elif cmd == 'get_cache':
518             out = server.get_cache(password,sys.argv[2])
519         elif cmd == 'h':
520             out = server.address.get_history(sys.argv[2])
521         elif cmd == 'tx':
522             out = server.transaction.broadcast(sys.argv[2])
523         elif cmd == 'b':
524             out = server.numblocks.subscribe()
525         else:
526             out = "Unknown command: '%s'" % cmd
527         print out
528         sys.exit(0)
529
530     # backend
531     store = abe_backend.AbeStore(config)
532
533     # supported protocols
534     thread.start_new_thread(native_server_thread, ())
535
536     thread.start_new_thread(http_server_thread, ())
537     thread.start_new_thread(clean_session_thread, ())
538
539     #tcp stratum
540     stratum_processor = AbeProcessor()
541     shared = stratum.Shared()
542     # Bind shared to processor since constructor is user defined
543     stratum_processor.shared = shared
544     stratum_processor.start()
545     # Create various transports we need
546     server = stratum.TcpServer(shared, stratum_processor, "ecdsa.org",50001)
547     server.start()
548
549     if (config.get('server','irc') == 'yes' ):
550         thread.start_new_thread(irc_thread, ())
551
552     print "starting Electrum server"
553
554     old_block_number = None
555     while not stopping:
556         block_number = store.main_iteration()
557
558         if block_number != old_block_number:
559             old_block_number = block_number
560             for session_id in sessions_sub_numblocks.keys():
561                 send_numblocks(session_id)
562
563             for session in stratum_processor.sessions:
564                 if session.numblocks_sub is not None:
565                     response = { 'id':session.numblocks_sub, 'result':block_number }
566                     stratum_processor.push_response(session,response)
567
568         while True:
569             try:
570                 addr = store.address_queue.get(False)
571             except:
572                 break
573             do_update_address(addr)
574
575             for session in stratum_processor.sessions:
576                 m = session.addresses_sub.get(addr)
577                 if m:
578                     status = store.get_status( addr )
579                     message_id, last_status = m
580                     if status != last_status:
581                         session.subscribe_to_address(message_id, status)
582                         response = { 'id':message_id, 'result':status }
583                         stratum_processor.push_response(session,response)
584
585         time.sleep(10)
586     print "server stopped"
587