join threads during server shutdown
[electrum-server.git] / backends / bitcoind / blockchain_processor.py
1 import ast
2 import hashlib
3 from json import dumps, loads
4 import os
5 from Queue import Queue
6 import random
7 import sys
8 import time
9 import threading
10 import traceback
11 import urllib
12
13 from backends.bitcoind import deserialize
14 from processor import Processor, print_log
15 from utils import *
16
17 from storage import Storage
18
19
class BlockchainProcessor(Processor):
    """Serve Electrum ``blockchain.*`` requests from bitcoind plus a local Storage.

    Construction blocks until bitcoind answers ``getinfo``, brings the on-disk
    headers file and ``Storage`` up to bitcoind's tip, loads the memory pool,
    and then schedules ``main_iteration`` on a ``threading.Timer`` every 10s.
    """

    def __init__(self, config, shared):
        """Connect to bitcoind, catch up with the chain and start the timer.

        :param config: ConfigParser-style object; reads the ``leveldb`` and
            ``bitcoind`` sections.
        :param shared: server-wide state exposing ``stop()`` / ``stopped()``.
        """
        Processor.__init__(self)

        self.mtimes = {}  # monitoring: accumulated seconds per phase name
        self.now = time.time()  # reference instant used by mtime()
        self.shared = shared
        self.config = config
        self.up_to_date = False

        # subscription bookkeeping, guarded by watch_lock
        self.watch_lock = threading.Lock()
        self.watch_blocks = []
        self.watch_headers = []
        self.watched_addresses = {}

        # history_cache and chunk_cache are guarded by cache_lock
        self.history_cache = {}
        self.chunk_cache = {}
        self.cache_lock = threading.Lock()
        self.headers_data = ''  # serialized headers not yet flushed to disk
        self.headers_path = config.get('leveldb', 'path_fulltree')

        self.mempool_addresses = {}
        self.mempool_hist = {}
        self.mempool_hashes = set()
        self.mempool_lock = threading.Lock()

        # (address, sessions) pairs awaiting a subscription notification
        self.address_queue = Queue()

        try:
            # simulate random blockchain reorgs
            self.test_reorgs = config.getboolean('leveldb', 'test_reorgs')
        except Exception:
            self.test_reorgs = False
        self.storage = Storage(config, shared, self.test_reorgs)

        self.dblock = threading.Lock()

        self.bitcoind_url = 'http://%s:%s@%s:%s/' % (
            config.get('bitcoind', 'user'),
            config.get('bitcoind', 'password'),
            config.get('bitcoind', 'host'),
            config.get('bitcoind', 'port'))

        # Wait for bitcoind. bitcoind() raises BaseException on RPC errors,
        # so the handler has to be at least that broad.
        while True:
            try:
                self.bitcoind('getinfo')
                break
            except BaseException:
                print_log('cannot contact bitcoind...')
                time.sleep(5)
                continue

        self.sent_height = 0
        self.sent_header = None

        # catch_up headers
        self.init_headers(self.storage.height)

        threading.Timer(0, lambda: self.catch_up(sync=False)).start()
        while not shared.stopped() and not self.up_to_date:
            try:
                time.sleep(1)
            except BaseException:
                print("keyboard interrupt: stopping threads")
                shared.stop()
                sys.exit(0)

        print_log("Blockchain is up to date.")
        self.memorypool_update()
        print_log("Memory pool initialized.")

        self.timer = threading.Timer(10, self.main_iteration)
        self.timer.start()

    def mtime(self, name):
        """Add the seconds elapsed since the previous call to ``mtimes[name]``.

        An empty ``name`` only resets the reference instant.
        """
        now = time.time()
        if name != '':
            delta = now - self.now
            self.mtimes[name] = self.mtimes.get(name, 0) + delta
        self.now = now

    def print_mtime(self):
        """Log every accumulated timing as ``name:seconds`` pairs."""
        print_log(''.join('%s:%.2f ' % (k, v) for k, v in self.mtimes.items()))

    def _bitcoind_post(self, postdata):
        """POST a JSON-RPC payload to bitcoind and return the raw response body.

        The HTTP response is always closed. On any failure the traceback is
        printed to stdout, the shared server state is stopped, and the
        original exception is re-raised.
        """
        try:
            response = urllib.urlopen(self.bitcoind_url, postdata)
            try:
                return response.read()
            finally:
                response.close()
        except BaseException:
            traceback.print_exc(file=sys.stdout)
            self.shared.stop()
            raise

    def bitcoind(self, method, params=None):
        """Call a bitcoind JSON-RPC method and return its ``result`` field.

        :param method: RPC method name, e.g. ``'getblockhash'``.
        :param params: list of positional RPC parameters; defaults to none.
        :raises BaseException: wrapping bitcoind's ``error`` object when the
            RPC reports an error. Transport errors are re-raised as-is after
            the server is stopped.
        """
        if params is None:
            params = []
        postdata = dumps({"method": method, 'params': params, 'id': 'jsonrpc'})
        r = loads(self._bitcoind_post(postdata))
        if r['error'] is not None:
            raise BaseException(r['error'])
        return r.get('result')

    def block2header(self, b):
        """Convert a bitcoind ``getblock`` result into a header dict.

        ``bits`` arrives as a hex string and is converted to an int.
        """
        return {
            "block_height": b.get('height'),
            "version": b.get('version'),
            "prev_block_hash": b.get('previousblockhash'),
            "merkle_root": b.get('merkleroot'),
            "timestamp": b.get('time'),
            "bits": int(b.get('bits'), 16),
            "nonce": b.get('nonce'),
        }

    def get_header(self, height):
        """Fetch the header dict of the block at ``height`` from bitcoind."""
        block_hash = self.bitcoind('getblockhash', [height])
        b = self.bitcoind('getblock', [block_hash])
        return self.block2header(b)

    def init_headers(self, db_height):
        """Make the headers file cover heights 0..db_height.

        Missing headers are fetched from bitcoind. Each one must link to its
        predecessor via ``prev_block_hash``.
        """
        self.chunk_cache = {}
        self.headers_filename = os.path.join(self.headers_path, 'blockchain_headers')

        if os.path.exists(self.headers_filename):
            # every header occupies 80 bytes on disk
            height = os.path.getsize(self.headers_filename) // 80 - 1   # the current height
            if height > 0:
                prev_hash = self.hash_header(self.read_header(height))
            else:
                prev_hash = None
        else:
            open(self.headers_filename, 'wb').close()
            prev_hash = None
            height = -1

        if height < db_height:
            print_log("catching up missing headers:", height, db_height)

        try:
            while height < db_height:
                height = height + 1
                header = self.get_header(height)
                if height > 1:
                    assert prev_hash == header.get('prev_block_hash')
                self.write_header(header, sync=False)
                prev_hash = self.hash_header(header)
                if (height % 1000) == 0:
                    print_log("headers file:", height)
        except KeyboardInterrupt:
            self.flush_headers()
            sys.exit()

        self.flush_headers()

    def hash_header(self, header):
        """Return the hex block hash of a header dict (byte-reversed double hash)."""
        return rev_hex(Hash(header_to_string(header).decode('hex')).encode('hex'))

    def read_header(self, block_height):
        """Read one header from disk; None if the file or the record is missing."""
        if os.path.exists(self.headers_filename):
            with open(self.headers_filename, 'rb') as f:
                f.seek(block_height * 80)
                h = f.read(80)
            if len(h) == 80:
                return header_from_string(h)
        return None

    def read_chunk(self, index):
        """Return the hex encoding of the on-disk 2016-header chunk ``index``."""
        with open(self.headers_filename, 'rb') as f:
            f.seek(index * 2016 * 80)
            chunk = f.read(2016 * 80)
        return chunk.encode('hex')

    def write_header(self, header, sync=True):
        """Buffer a header for writing and invalidate its cached chunk.

        The buffer is flushed when ``sync`` is true or when it exceeds
        4000 bytes.
        """
        if not self.headers_data:
            self.headers_offset = header.get('block_height')

        self.headers_data += header_to_string(header).decode('hex')
        if sync or len(self.headers_data) > 40 * 100:
            self.flush_headers()

        with self.cache_lock:
            chunk_index = header.get('block_height') // 2016
            self.chunk_cache.pop(chunk_index, None)

    def pop_header(self):
        """Drop the last buffered (unflushed) header.

        Each serialized header is 80 bytes, matching read_header/flush_headers.
        NOTE(review): once the header has been flushed this is a no-op, so the
        reverted header stays on disk until overwritten.
        """
        # we need to do this only if we have not flushed
        if self.headers_data:
            self.headers_data = self.headers_data[:-80]

    def flush_headers(self):
        """Write buffered headers at ``headers_offset * 80`` and clear the buffer."""
        if not self.headers_data:
            return
        with open(self.headers_filename, 'rb+') as f:
            f.seek(self.headers_offset * 80)
            f.write(self.headers_data)
        self.headers_data = ''

    def get_chunk(self, i):
        """Return chunk ``i`` as hex, reading it from disk on a cache miss."""
        # store them on disk; store the current chunk in memory
        with self.cache_lock:
            chunk = self.chunk_cache.get(i)
            if not chunk:
                chunk = self.read_chunk(i)
                self.chunk_cache[i] = chunk

        return chunk

    def get_mempool_transaction(self, txid):
        """Fetch and parse a mempool transaction; None if unavailable or unparsable."""
        try:
            raw_tx = self.bitcoind('getrawtransaction', [txid, 0])
        except BaseException:
            return None

        vds = deserialize.BCDataStream()
        vds.write(raw_tx.decode('hex'))
        try:
            return deserialize.parse_Transaction(vds, is_coinbase=False)
        except BaseException:
            print_log("ERROR: cannot parse", txid)
            return None

    def get_history(self, addr, cache_only=False):
        """Return the confirmed history of ``addr`` plus its mempool entries.

        Mempool entries are appended as ``{'tx_hash': txid, 'height': 0}``.
        With ``cache_only`` the cached value is returned, or -1 on a miss.
        Storage failures stop the server and are re-raised.
        """
        with self.cache_lock:
            hist = self.history_cache.get(addr)
        if hist is not None:
            return hist
        if cache_only:
            return -1

        with self.dblock:
            try:
                hist = self.storage.get_history(addr)
            except BaseException:
                self.shared.stop()
                raise
            if hist:
                is_known = True
            else:
                hist = []
                is_known = False

        # add memory pool
        with self.mempool_lock:
            for txid in self.mempool_hist.get(addr, []):
                hist.append({'tx_hash': txid, 'height': 0})

        # add something to distinguish between unused and empty addresses
        # NOTE(review): is_known is only True when hist was non-empty, so this
        # branch cannot currently fire; confirm what Storage returns for
        # known-but-empty addresses.
        if hist == [] and is_known:
            hist = ['*']

        with self.cache_lock:
            self.history_cache[addr] = hist
        return hist

    def get_status(self, addr, cache_only=False):
        """Return a sha256 hex digest summarising the history of ``addr``.

        Returns None for no history, '*' for the marker history, or -1 on a
        cache miss when ``cache_only`` is set.
        """
        tx_points = self.get_history(addr, cache_only)
        if cache_only and tx_points == -1:
            return -1

        if not tx_points:
            return None
        if tx_points == ['*']:
            return '*'
        status = ''.join(tx.get('tx_hash') + ':%d:' % tx.get('height')
                         for tx in tx_points)
        return hashlib.sha256(status).digest().encode('hex')

    def get_merkle(self, tx_hash, height):
        """Return the merkle branch proving ``tx_hash`` is in the block at ``height``.

        :returns: ``{"block_height", "merkle": [sibling hashes, hex], "pos"}``.
        :raises ValueError: if ``tx_hash`` is not in that block.
        """
        block_hash = self.bitcoind('getblockhash', [height])
        b = self.bitcoind('getblock', [block_hash])
        tx_list = b.get('tx')
        tx_pos = tx_list.index(tx_hash)

        merkle = list(map(hash_decode, tx_list))
        target_hash = hash_decode(tx_hash)
        s = []
        while len(merkle) != 1:
            if len(merkle) % 2:
                merkle.append(merkle[-1])  # odd level: pair the last hash with itself
            n = []
            for i in range(0, len(merkle), 2):
                left, right = merkle[i], merkle[i + 1]
                new_hash = Hash(left + right)
                if left == target_hash:
                    s.append(hash_encode(right))
                    target_hash = new_hash
                elif right == target_hash:
                    s.append(hash_encode(left))
                    target_hash = new_hash
                n.append(new_hash)
            merkle = n

        return {"block_height": height, "merkle": s, "pos": tx_pos}

    def add_to_history(self, addr, tx_hash, tx_pos, tx_height):
        """Insert an 80-byte history record for ``addr``, keeping height order.

        NOTE(review): relies on ``self.batch_list``, ``self.batch_txio`` and
        ``self.serialize_item``, none of which is defined in this class as
        shown; this looks like leftover code.
        """
        # keep it sorted
        s = self.serialize_item(tx_hash, tx_pos, tx_height) + 40 * chr(0)
        assert len(s) == 80

        serialized_hist = self.batch_list[addr]

        l = len(serialized_hist) // 80
        for i in range(l - 1, -1, -1):
            item = serialized_hist[80 * i:80 * (i + 1)]
            item_height = int(rev_hex(item[36:39].encode('hex')), 16)
            if item_height <= tx_height:
                serialized_hist = serialized_hist[0:80 * (i + 1)] + s + serialized_hist[80 * (i + 1):]
                break
        else:
            serialized_hist = s + serialized_hist

        self.batch_list[addr] = serialized_hist

        # backlink
        txo = (tx_hash + int_to_hex(tx_pos, 4)).decode('hex')
        self.batch_txio[txo] = addr

    def deserialize_block(self, block):
        """Parse the raw hex transactions of ``block``.

        Unparsable transactions are logged and skipped.

        :returns: ``(ordered txids, {txid: parsed tx})``.
        """
        txlist = block.get('tx')
        tx_hashes = []  # ordered txids
        txdict = {}     # deserialized tx
        is_coinbase = True
        for raw_tx in txlist:
            tx_hash = hash_encode(Hash(raw_tx.decode('hex')))
            vds = deserialize.BCDataStream()
            vds.write(raw_tx.decode('hex'))
            try:
                tx = deserialize.parse_Transaction(vds, is_coinbase)
            except BaseException:
                print_log("ERROR: cannot parse", tx_hash)
                continue
            tx_hashes.append(tx_hash)
            txdict[tx_hash] = tx
            is_coinbase = False
        return tx_hashes, txdict

    def import_block(self, block, block_hash, block_height, sync, revert=False):
        """Apply (or, with ``revert``, undo) a block's transactions in Storage.

        Also records undo info, the stored tip, and invalidates the caches of
        touched addresses. ``sync`` is currently unused here.
        """
        touched_addr = set()

        # deserialize transactions
        tx_hashes, txdict = self.deserialize_block(block)

        # undo info
        if revert:
            undo_info = self.storage.get_undo_info(block_height)
            tx_hashes.reverse()  # undo in reverse order
        else:
            undo_info = {}

        for txid in tx_hashes:  # must be ordered
            tx = txdict[txid]
            if not revert:
                undo = self.storage.import_transaction(txid, tx, block_height, touched_addr)
                undo_info[txid] = undo
            else:
                undo = undo_info.pop(txid)
                self.storage.revert_transaction(txid, tx, block_height, touched_addr, undo)

        if revert:
            assert undo_info == {}

        # add undo info
        if not revert:
            self.storage.write_undo_info(block_height, self.bitcoind_height, undo_info)

        # add the max
        self.storage.db_undo.put('height', repr((block_hash, block_height, self.storage.db_version)))

        for addr in touched_addr:
            self.invalidate_cache(addr)

        self.storage.update_hashes()

    def add_request(self, session, request):
        """Answer from cache when possible, otherwise queue the request."""
        # see if we can get if from cache. if not, add to queue
        if self.process(session, request, cache_only=True) == -1:
            self.queue.put((session, request))

    def do_subscribe(self, method, params, session):
        """Register ``session`` for block, header or address notifications."""
        with self.watch_lock:
            if method == 'blockchain.numblocks.subscribe':
                if session not in self.watch_blocks:
                    self.watch_blocks.append(session)

            elif method == 'blockchain.headers.subscribe':
                if session not in self.watch_headers:
                    self.watch_headers.append(session)

            elif method == 'blockchain.address.subscribe':
                address = params[0]
                l = self.watched_addresses.get(address)
                if l is None:
                    self.watched_addresses[address] = [session]
                elif session not in l:
                    l.append(session)

    def do_unsubscribe(self, method, params, session):
        """Remove ``session`` from the matching subscription list."""
        with self.watch_lock:
            if method == 'blockchain.numblocks.subscribe':
                if session in self.watch_blocks:
                    self.watch_blocks.remove(session)
            elif method == 'blockchain.headers.subscribe':
                if session in self.watch_headers:
                    self.watch_headers.remove(session)
            elif method == "blockchain.address.subscribe":
                addr = params[0]
                l = self.watched_addresses.get(addr)
                if not l:
                    return
                if session in l:
                    l.remove(session)
                if session in l:
                    # a duplicate registration should be impossible
                    print("error rc!!")
                    self.shared.stop()
                if l == []:
                    self.watched_addresses.pop(addr)

    def _address_call(self, params, func):
        """Call ``func(str(params[0]))`` and return ``(result, error)``.

        Exceptions become ``"<exception>: <address>"`` and are logged; the
        address reads ``None`` if ``params[0]`` could not be obtained.
        """
        address = None
        try:
            address = str(params[0])
            return func(address), None
        except BaseException as e:
            error = str(e) + ': ' + str(address)
            print_log("error:", error)
            return None, error

    def _int_param_call(self, params, func):
        """Call ``func(int(params[0]))`` and return ``(result, error)``.

        Exceptions become ``"<exception>: <int>"``, or ``"<exception>: <params
        repr>"`` if ``params[0]`` was not an int, and are logged.
        """
        value = None
        try:
            value = int(params[0])
            return func(value), None
        except BaseException as e:
            if value is None:
                error = str(e) + ': ' + repr(params)
            else:
                error = str(e) + ': %d' % value
            print_log("error:", error)
            return None, error

    def process(self, session, request, cache_only=False):
        """Dispatch one JSON-RPC request and push the response to ``session``.

        With ``cache_only``, returns -1 (and pushes nothing) when the answer
        is not available without hitting Storage or bitcoind.
        """
        message_id = request['id']
        method = request['method']
        params = request.get('params', [])
        result = None
        error = None

        if method == 'blockchain.numblocks.subscribe':
            result = self.storage.height

        elif method == 'blockchain.headers.subscribe':
            result = self.header

        elif method == 'blockchain.address.subscribe':
            result, error = self._address_call(
                params, lambda address: self.get_status(address, cache_only))

        elif method == 'blockchain.address.get_history':
            result, error = self._address_call(
                params, lambda address: self.get_history(address, cache_only))

        elif method == 'blockchain.address.get_balance':
            result, error = self._address_call(
                params, lambda address: self.storage.get_balance(address))

        elif method == 'blockchain.address.get_proof':
            result, error = self._address_call(
                params, lambda address: self.storage.get_proof(address))

        elif method == 'blockchain.address.listunspent':
            result, error = self._address_call(
                params, lambda address: self.storage.listunspent(address))

        elif method == 'blockchain.utxo.get_address':
            try:
                txid = str(params[0])
                pos = int(params[1])
                txi = (txid + int_to_hex(pos, 4)).decode('hex')
                result = self.storage.get_address(txi)
            except BaseException as e:
                error = str(e)
                print_log("error:", error, params)

        elif method == 'blockchain.block.get_header':
            if cache_only:
                result = -1
            else:
                result, error = self._int_param_call(params, self.get_header)

        elif method == 'blockchain.block.get_chunk':
            if cache_only:
                result = -1
            else:
                result, error = self._int_param_call(params, self.get_chunk)

        elif method == 'blockchain.transaction.broadcast':
            try:
                txo = self.bitcoind('sendrawtransaction', params)
                print_log("sent tx:", txo)
                result = txo
            except BaseException as e:
                result = str(e)  # do not send an error
                print_log("error:", result, params)

        elif method == 'blockchain.transaction.get_merkle':
            if cache_only:
                result = -1
            else:
                try:
                    tx_hash = params[0]
                    tx_height = params[1]
                    result = self.get_merkle(tx_hash, tx_height)
                except BaseException as e:
                    error = str(e) + ': ' + repr(params)
                    print_log("get_merkle error:", error)

        elif method == 'blockchain.transaction.get':
            try:
                tx_hash = params[0]
                result = self.bitcoind('getrawtransaction', [tx_hash, 0])
            except BaseException as e:
                error = str(e) + ': ' + repr(params)
                print_log("tx get error:", error)

        else:
            error = "unknown method:%s" % method

        if cache_only and result == -1:
            return -1

        if error:
            self.push_response(session, {'id': message_id, 'error': error})
        elif result != '':
            self.push_response(session, {'id': message_id, 'result': result})

    def getfullblock(self, block_hash):
        """Return ``getblock`` output with ``tx`` replaced by raw tx hex strings.

        All transactions are fetched in one batched JSON-RPC request. A
        per-tx error stops the server and raises BaseException (bitcoind
        needs txindex=1).
        """
        block = self.bitcoind('getblock', [block_hash])

        rawtxreq = [{"method": "getrawtransaction", "params": [txid], "id": i}
                    for i, txid in enumerate(block['tx'])]

        r = loads(self._bitcoind_post(dumps(rawtxreq)))
        rawtxdata = []
        for ir in r:
            if ir['error'] is not None:
                self.shared.stop()
                print_log("Error: make sure you run bitcoind with txindex=1; use -reindex if needed.")
                raise BaseException(ir['error'])
            rawtxdata.append(ir['result'])
        block['tx'] = rawtxdata
        return block

    def catch_up(self, sync=True):
        """Import blocks until Storage's last hash equals bitcoind's tip.

        When the next block does not extend the stored tip (or a simulated
        reorg triggers), the current block is reverted instead. Afterwards
        ``self.header`` is set for the stored tip, and storage is closed if
        the server was stopped. ``sync`` is forwarded to import_block and
        write_header.
        """
        prev_root_hash = None
        while not self.shared.stopped():

            self.mtime('')

            # are we done yet?
            info = self.bitcoind('getinfo')
            self.bitcoind_height = info.get('blocks')
            bitcoind_block_hash = self.bitcoind('getblockhash', [self.bitcoind_height])
            if self.storage.last_hash == bitcoind_block_hash:
                self.up_to_date = True
                break

            # not done..
            self.up_to_date = False
            next_block_hash = self.bitcoind('getblockhash', [self.storage.height + 1])
            next_block = self.getfullblock(next_block_hash)
            self.mtime('daemon')

            # fixme: this is unsafe, if we revert when the undo info is not yet written
            revert = (random.randint(1, 100) == 1) if self.test_reorgs else False

            if (next_block.get('previousblockhash') == self.storage.last_hash) and not revert:

                prev_root_hash = self.storage.get_root_hash()

                self.import_block(next_block, next_block_hash, self.storage.height + 1, sync)
                self.storage.height = self.storage.height + 1
                self.write_header(self.block2header(next_block), sync)
                self.storage.last_hash = next_block_hash
                self.mtime('import')

                if self.storage.height % 1000 == 0 and not sync:
                    t_daemon = self.mtimes.get('daemon')
                    t_import = self.mtimes.get('import')
                    print_log("catch_up: block %d (%.3fs %.3fs)" % (self.storage.height, t_daemon, t_import), self.storage.get_root_hash().encode('hex'))
                    self.mtimes['daemon'] = 0
                    self.mtimes['import'] = 0

            else:

                # revert current block
                block = self.getfullblock(self.storage.last_hash)
                print_log("blockchain reorg", self.storage.height, block.get('previousblockhash'), self.storage.last_hash)
                self.import_block(block, self.storage.last_hash, self.storage.height, sync, revert=True)
                self.pop_header()
                self.flush_headers()

                self.storage.height -= 1

                # read previous header from disk
                self.header = self.read_header(self.storage.height)
                self.storage.last_hash = self.hash_header(self.header)

                if prev_root_hash:
                    assert prev_root_hash == self.storage.get_root_hash()
                    prev_root_hash = None

        self.header = self.block2header(self.bitcoind('getblock', [self.storage.last_hash]))
        self.header['utxo_root'] = self.storage.get_root_hash().encode('hex')

        if self.shared.stopped():
            print_log("closing database")
            self.storage.close()

    def memorypool_update(self):
        """Sync mempool bookkeeping with bitcoind's ``getrawmempool``.

        Rebuilds the per-address mempool histories and invalidates the
        caches of addresses whose mempool entries changed.
        """
        mempool_hashes = set(self.bitcoind('getrawmempool'))
        touched_addresses = set()

        for tx_hash in mempool_hashes:
            if tx_hash in self.mempool_hashes:
                continue

            tx = self.get_mempool_transaction(tx_hash)
            if not tx:
                continue

            mpa = self.mempool_addresses.get(tx_hash, [])
            for x in tx.get('inputs'):
                # we assume that the input address can be parsed by deserialize(); this is true for Electrum transactions
                addr = x.get('address')
                if addr and addr not in mpa:
                    mpa.append(addr)
                    touched_addresses.add(addr)

            for x in tx.get('outputs'):
                addr = x.get('address')
                if addr and addr not in mpa:
                    mpa.append(addr)
                    touched_addresses.add(addr)

            self.mempool_addresses[tx_hash] = mpa
            self.mempool_hashes.add(tx_hash)

        # remove older entries from mempool_hashes
        self.mempool_hashes = mempool_hashes

        # remove deprecated entries from mempool_addresses
        for tx_hash, addresses in list(self.mempool_addresses.items()):
            if tx_hash not in self.mempool_hashes:
                self.mempool_addresses.pop(tx_hash)
                touched_addresses.update(addresses)

        # rebuild mempool histories
        new_mempool_hist = {}
        for tx_hash, addresses in self.mempool_addresses.items():
            for addr in addresses:
                h = new_mempool_hist.setdefault(addr, [])
                if tx_hash not in h:
                    h.append(tx_hash)

        with self.mempool_lock:
            self.mempool_hist = new_mempool_hist

        # invalidate cache for touched addresses
        for addr in touched_addresses:
            self.invalidate_cache(addr)

    def invalidate_cache(self, address):
        """Drop the cached history of ``address`` and queue subscriber notices."""
        with self.cache_lock:
            if address in self.history_cache:
                print_log("cache: invalidating", address)
                self.history_cache.pop(address)

        with self.watch_lock:
            sessions = self.watched_addresses.get(address)

        if sessions:
            # TODO: update cache here. if new value equals cached value, do not send notification
            self.address_queue.put((address, sessions))

    def close(self):
        """Join the current timer thread, then close the database.

        NOTE(review): main_iteration re-arms ``self.timer``, so a later timer
        may still be pending. catch_up also closes storage once the server
        is stopped; confirm that Storage.close tolerates a second call.
        """
        self.timer.join()
        print_log("Closing database...")
        self.storage.close()
        print_log("Database is closed")

    def main_iteration(self):
        """One timer tick: catch up, refresh mempool, notify subscribers, re-arm.

        Subscriber lists are snapshotted under ``watch_lock`` because
        do_subscribe/do_unsubscribe mutate them from other threads.
        """
        from Queue import Empty

        if self.shared.stopped():
            print_log("Stopping timer")
            return

        with self.dblock:
            t1 = time.time()
            self.catch_up()
            t2 = time.time()

        self.memorypool_update()

        if self.sent_height != self.storage.height:
            self.sent_height = self.storage.height
            with self.watch_lock:
                block_sessions = list(self.watch_blocks)
            for session in block_sessions:
                self.push_response(session, {
                        'id': None,
                        'method': 'blockchain.numblocks.subscribe',
                        'params': [self.storage.height],
                        })

        if self.sent_header != self.header:
            print_log("blockchain: %d (%.3fs)" % (self.storage.height, t2 - t1))
            self.sent_header = self.header
            with self.watch_lock:
                header_sessions = list(self.watch_headers)
            for session in header_sessions:
                self.push_response(session, {
                        'id': None,
                        'method': 'blockchain.headers.subscribe',
                        'params': [self.header],
                        })

        while True:
            try:
                addr, sessions = self.address_queue.get(False)
            except Empty:
                break

            status = self.get_status(addr)
            with self.watch_lock:
                sessions = list(sessions)
            for session in sessions:
                self.push_response(session, {
                        'id': None,
                        'method': 'blockchain.address.subscribe',
                        'params': [addr, status],
                        })

        # next iteration
        self.timer = threading.Timer(10, self.main_iteration)
        self.timer.start()
808