65bd28455ebc63555ed63f16b252357430374bbf
[electrum-server.git] / backends / abe / __init__.py
1 from Abe.util import hash_to_address, decode_check_address
2 from Abe.DataStore import DataStore as Datastore_class
3 from Abe import DataStore, readconf, BCDataStream,  deserialize, util, base58
4
5 import binascii
6
7 import thread, traceback, sys, urllib, operator
8 from json import dumps, loads
9 from Queue import Queue
10 import time, threading
11
12
13 class AbeStore(Datastore_class):
14
15     def __init__(self, config):
16         conf = DataStore.CONFIG_DEFAULTS
17         args, argv = readconf.parse_argv( [], conf)
18         args.dbtype = config.get('database','type')
19         if args.dbtype == 'sqlite3':
20             args.connect_args = { 'database' : config.get('database','database') }
21         elif args.dbtype == 'MySQLdb':
22             args.connect_args = { 'db' : config.get('database','database'), 'user' : config.get('database','username'), 'passwd' : config.get('database','password') }
23         elif args.dbtype == 'psycopg2':
24             args.connect_args = { 'database' : config.get('database','database') }
25
26         coin = config.get('server', 'coin')
27         self.addrtype = 0
28         if coin == 'litecoin':
29             print 'Litecoin settings:'
30             datadir = config.get('server','datadir')
31             print '  datadir = ' + datadir
32             args.datadir = [{"dirname":datadir,"chain":"Litecoin","code3":"LTC","address_version":"\u0030"}]
33             print '  addrtype = 48'
34             self.addrtype = 48
35
36         Datastore_class.__init__(self,args)
37
38         # Use 1 (Bitcoin) if chain_id is not sent
39         self.chain_id = self.datadirs[0]["chain_id"] or 1
40         print 'Coin chain_id = %d' % self.chain_id
41
42         self.sql_limit = int( config.get('database','limit') )
43
44         self.tx_cache = {}
45         self.bitcoind_url = 'http://%s:%s@%s:%s/' % ( config.get('bitcoind','user'), config.get('bitcoind','password'), config.get('bitcoind','host'), config.get('bitcoind','port'))
46
47         self.address_queue = Queue()
48
49         self.dblock = thread.allocate_lock()
50         self.last_tx_id = 0
51         self.known_mempool_hashes = []
52
53     
54     def import_tx(self, tx, is_coinbase):
55         tx_id = super(AbeStore, self).import_tx(tx, is_coinbase)
56         self.last_tx_id = tx_id
57         return tx_id
58         
59
60
61
62     def import_block(self, b, chain_ids=frozenset()):
63         #print "import block"
64         block_id = super(AbeStore, self).import_block(b, chain_ids)
65         for pos in xrange(len(b['transactions'])):
66             tx = b['transactions'][pos]
67             if 'hash' not in tx:
68                 tx['hash'] = util.double_sha256(tx['tx'])
69             tx_id = self.tx_find_id_and_value(tx)
70             if tx_id:
71                 self.update_tx_cache(tx_id)
72             else:
73                 print "error: import_block: no tx_id"
74         return block_id
75
76
77     def update_tx_cache(self, txid):
78         inrows = self.get_tx_inputs(txid, False)
79         for row in inrows:
80             _hash = self.binout(row[6])
81             if not _hash:
82                 #print "WARNING: missing tx_in for tx", txid
83                 continue
84
85             address = hash_to_address(chr(self.addrtype), _hash)
86             if self.tx_cache.has_key(address):
87                 print "cache: invalidating", address
88                 self.tx_cache.pop(address)
89             self.address_queue.put(address)
90
91         outrows = self.get_tx_outputs(txid, False)
92         for row in outrows:
93             _hash = self.binout(row[6])
94             if not _hash:
95                 #print "WARNING: missing tx_out for tx", txid
96                 continue
97
98             address = hash_to_address(chr(self.addrtype), _hash)
99             if self.tx_cache.has_key(address):
100                 print "cache: invalidating", address
101                 self.tx_cache.pop(address)
102             self.address_queue.put(address)
103
104     def safe_sql(self,sql, params=(), lock=True):
105
106         error = False
107         try:
108             if lock: self.dblock.acquire()
109             ret = self.selectall(sql,params)
110         except:
111             error = True
112             traceback.print_exc(file=sys.stdout)
113         finally:
114             if lock: self.dblock.release()
115
116         if error: 
117             raise BaseException('sql error')
118
119         return ret
120             
121
122     def get_tx_outputs(self, tx_id, lock=True):
123         return self.safe_sql("""SELECT
124                 txout.txout_pos,
125                 txout.txout_scriptPubKey,
126                 txout.txout_value,
127                 nexttx.tx_hash,
128                 nexttx.tx_id,
129                 txin.txin_pos,
130                 pubkey.pubkey_hash
131               FROM txout
132               LEFT JOIN txin ON (txin.txout_id = txout.txout_id)
133               LEFT JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
134               LEFT JOIN tx nexttx ON (txin.tx_id = nexttx.tx_id)
135              WHERE txout.tx_id = %d 
136              ORDER BY txout.txout_pos
137         """%(tx_id), (), lock)
138
139     def get_tx_inputs(self, tx_id, lock=True):
140         return self.safe_sql(""" SELECT
141                 txin.txin_pos,
142                 txin.txin_scriptSig,
143                 txout.txout_value,
144                 COALESCE(prevtx.tx_hash, u.txout_tx_hash),
145                 prevtx.tx_id,
146                 COALESCE(txout.txout_pos, u.txout_pos),
147                 pubkey.pubkey_hash
148               FROM txin
149               LEFT JOIN txout ON (txout.txout_id = txin.txout_id)
150               LEFT JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
151               LEFT JOIN tx prevtx ON (txout.tx_id = prevtx.tx_id)
152               LEFT JOIN unlinked_txin u ON (u.txin_id = txin.txin_id)
153              WHERE txin.tx_id = %d
154              ORDER BY txin.txin_pos
155              """%(tx_id,), (), lock)
156
157
158     def get_address_out_rows(self, dbhash):
159         out = self.safe_sql(""" SELECT
160                 b.block_nTime,
161                 cc.chain_id,
162                 b.block_height,
163                 1,
164                 b.block_hash,
165                 tx.tx_hash,
166                 tx.tx_id,
167                 txin.txin_pos,
168                 -prevout.txout_value
169               FROM chain_candidate cc
170               JOIN block b ON (b.block_id = cc.block_id)
171               JOIN block_tx ON (block_tx.block_id = b.block_id)
172               JOIN tx ON (tx.tx_id = block_tx.tx_id)
173               JOIN txin ON (txin.tx_id = tx.tx_id)
174               JOIN txout prevout ON (txin.txout_id = prevout.txout_id)
175               JOIN pubkey ON (pubkey.pubkey_id = prevout.pubkey_id)
176              WHERE pubkey.pubkey_hash = ?
177                AND cc.chain_id = ?
178                AND cc.in_longest = 1
179              LIMIT ? """, (dbhash, self.chain_id, self.sql_limit))
180
181         if len(out)==self.sql_limit: 
182             raise BaseException('limit reached')
183         return out
184
185     def get_address_out_rows_memorypool(self, dbhash):
186         out = self.safe_sql(""" SELECT
187                 1,
188                 tx.tx_hash,
189                 tx.tx_id,
190                 txin.txin_pos,
191                 -prevout.txout_value
192               FROM tx 
193               JOIN txin ON (txin.tx_id = tx.tx_id)
194               JOIN txout prevout ON (txin.txout_id = prevout.txout_id)
195               JOIN pubkey ON (pubkey.pubkey_id = prevout.pubkey_id)
196              WHERE pubkey.pubkey_hash = ?
197              LIMIT ? """, (dbhash,self.sql_limit))
198
199         if len(out)==self.sql_limit: 
200             raise BaseException('limit reached')
201         return out
202
203     def get_address_in_rows(self, dbhash):
204         out = self.safe_sql(""" SELECT
205                 b.block_nTime,
206                 cc.chain_id,
207                 b.block_height,
208                 0,
209                 b.block_hash,
210                 tx.tx_hash,
211                 tx.tx_id,
212                 txout.txout_pos,
213                 txout.txout_value
214               FROM chain_candidate cc
215               JOIN block b ON (b.block_id = cc.block_id)
216               JOIN block_tx ON (block_tx.block_id = b.block_id)
217               JOIN tx ON (tx.tx_id = block_tx.tx_id)
218               JOIN txout ON (txout.tx_id = tx.tx_id)
219               JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
220              WHERE pubkey.pubkey_hash = ?
221                AND cc.chain_id = ?
222                AND cc.in_longest = 1
223                LIMIT ? """, (dbhash, self.chain_id, self.sql_limit))
224
225         if len(out)==self.sql_limit: 
226             raise BaseException('limit reached')
227         return out
228
229     def get_address_in_rows_memorypool(self, dbhash):
230         out = self.safe_sql( """ SELECT
231                 0,
232                 tx.tx_hash,
233                 tx.tx_id,
234                 txout.txout_pos,
235                 txout.txout_value
236               FROM tx
237               JOIN txout ON (txout.tx_id = tx.tx_id)
238               JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
239              WHERE pubkey.pubkey_hash = ?
240              LIMIT ? """, (dbhash,self.sql_limit))
241
242         if len(out)==self.sql_limit: 
243             raise BaseException('limit reached')
244         return out
245
246     def get_history(self, addr):
247
248         cached_version = self.tx_cache.get( addr )
249         if cached_version is not None:
250             return cached_version
251
252         version, binaddr = decode_check_address(addr)
253         if binaddr is None:
254             return None
255
256         dbhash = self.binin(binaddr)
257         rows = []
258         rows += self.get_address_out_rows( dbhash )
259         rows += self.get_address_in_rows( dbhash )
260
261         txpoints = []
262         known_tx = []
263
264         for row in rows:
265             try:
266                 nTime, chain_id, height, is_in, blk_hash, tx_hash, tx_id, pos, value = row
267             except:
268                 print "cannot unpack row", row
269                 break
270             tx_hash = self.hashout_hex(tx_hash)
271             txpoint = {
272                     "timestamp":    int(nTime),
273                     "height":   int(height),
274                     "is_input":    int(is_in),
275                     "block_hash": self.hashout_hex(blk_hash),
276                     "tx_hash":  tx_hash,
277                     "tx_id":    int(tx_id),
278                     "index":      int(pos),
279                     "value":    int(value),
280                     }
281
282             txpoints.append(txpoint)
283             known_tx.append(self.hashout_hex(tx_hash))
284
285
286         # todo: sort them really...
287         txpoints = sorted(txpoints, key=operator.itemgetter("timestamp"))
288
289         # read memory pool
290         rows = []
291         rows += self.get_address_in_rows_memorypool( dbhash )
292         rows += self.get_address_out_rows_memorypool( dbhash )
293         address_has_mempool = False
294
295         for row in rows:
296             is_in, tx_hash, tx_id, pos, value = row
297             tx_hash = self.hashout_hex(tx_hash)
298             if tx_hash in known_tx:
299                 continue
300
301             # discard transactions that are too old
302             if self.last_tx_id - tx_id > 50000:
303                 print "discarding tx id", tx_id
304                 continue
305
306             # this means that pending transactions were added to the db, even if they are not returned by getmemorypool
307             address_has_mempool = True
308
309             #print "mempool", tx_hash
310             txpoint = {
311                     "timestamp":    0,
312                     "height":   0,
313                     "is_input":    int(is_in),
314                     "block_hash": 'mempool', 
315                     "tx_hash":  tx_hash,
316                     "tx_id":    int(tx_id),
317                     "index":      int(pos),
318                     "value":    int(value),
319                     }
320             txpoints.append(txpoint)
321
322
323         for txpoint in txpoints:
324             tx_id = txpoint['tx_id']
325             
326             txinputs = []
327             inrows = self.get_tx_inputs(tx_id)
328             for row in inrows:
329                 _hash = self.binout(row[6])
330                 if not _hash:
331                     #print "WARNING: missing tx_in for tx", tx_id, addr
332                     continue
333                 address = hash_to_address(chr(self.addrtype), _hash)
334                 txinputs.append(address)
335             txpoint['inputs'] = txinputs
336             txoutputs = []
337             outrows = self.get_tx_outputs(tx_id)
338             for row in outrows:
339                 _hash = self.binout(row[6])
340                 if not _hash:
341                     #print "WARNING: missing tx_out for tx", tx_id, addr
342                     continue
343                 address = hash_to_address(chr(self.addrtype), _hash)
344                 txoutputs.append(address)
345             txpoint['outputs'] = txoutputs
346
347             # for all unspent inputs, I want their scriptpubkey. (actually I could deduce it from the address)
348             if not txpoint['is_input']:
349                 # detect if already redeemed...
350                 for row in outrows:
351                     if row[6] == dbhash: break
352                 else:
353                     raise
354                 #row = self.get_tx_output(tx_id,dbhash)
355                 # pos, script, value, o_hash, o_id, o_pos, binaddr = row
356                 # if not redeemed, we add the script
357                 if row:
358                     if not row[4]: txpoint['raw_output_script'] = row[1]
359
360             txpoint.pop('tx_id')
361
362         # cache result
363         # do not cache mempool results because statuses are ambiguous
364         if not address_has_mempool:
365             self.tx_cache[addr] = txpoints
366         
367         return txpoints
368
369
370     def get_status(self,addr):
371         # get address status, i.e. the last block for that address.
372         tx_points = self.get_history(addr)
373         if not tx_points:
374             status = None
375         else:
376             lastpoint = tx_points[-1]
377             status = lastpoint['block_hash']
378             # this is a temporary hack; move it up once old clients have disappeared
379             if status == 'mempool': # and session['version'] != "old":
380                 status = status + ':%d'% len(tx_points)
381         return status
382
383
384     def get_block_header(self, block_height):
385         out = self.safe_sql("""
386             SELECT
387                 block_hash,
388                 block_version,
389                 block_hashMerkleRoot,
390                 block_nTime,
391                 block_nBits,
392                 block_nNonce,
393                 block_height,
394                 prev_block_hash,
395                 block_id
396               FROM chain_summary
397              WHERE block_height = %d AND in_longest = 1"""%block_height)
398
399         if not out: raise BaseException("block not found")
400         row = out[0]
401         (block_hash, block_version, hashMerkleRoot, nTime, nBits, nNonce, height,prev_block_hash, block_id) \
402             = ( self.hashout_hex(row[0]), int(row[1]), self.hashout_hex(row[2]), int(row[3]), int(row[4]), int(row[5]), int(row[6]), self.hashout_hex(row[7]), int(row[8]) )
403
404         out = {"block_height":block_height, "version":block_version, "prev_block_hash":prev_block_hash, 
405                 "merkle_root":hashMerkleRoot, "timestamp":nTime, "bits":nBits, "nonce":nNonce}
406         return out
407         
408
409     def get_tx_merkle(self, tx_hash):
410
411         out = self.safe_sql("""
412              SELECT block_tx.block_id FROM tx 
413              JOIN block_tx on tx.tx_id = block_tx.tx_id 
414              JOIN chain_summary on chain_summary.block_id = block_tx.block_id
415              WHERE tx_hash='%s' AND in_longest = 1"""%tx_hash)
416         block_id = out[0]
417
418         # get block height
419         out = self.safe_sql("SELECT block_height FROM chain_summary WHERE block_id = %d AND in_longest = 1"%block_id)
420
421         if not out: raise BaseException("block not found")
422         block_height = int(out[0][0])
423
424         merkle = []
425         # list all tx in block
426         for row in self.safe_sql("""
427             SELECT DISTINCT tx_id, tx_pos, tx_hash
428               FROM txin_detail
429              WHERE block_id = ?
430              ORDER BY tx_pos""", (block_id,)):
431             tx_id, tx_pos, tx_h = row
432             merkle.append(tx_h)
433
434         # find subset.
435         # TODO: do not compute this on client request, better store the hash tree of each block in a database...
436         import hashlib
437         encode = lambda x: x[::-1].encode('hex')
438         decode = lambda x: x.decode('hex')[::-1]
439         Hash = lambda x: hashlib.sha256(hashlib.sha256(x).digest()).digest()
440
441         merkle = map(decode, merkle)
442         target_hash = decode(tx_hash)
443
444         s = []
445         while len(merkle) != 1:
446             if len(merkle)%2: merkle.append( merkle[-1] )
447             n = []
448             while merkle:
449                 new_hash = Hash( merkle[0] + merkle[1] )
450                 if merkle[0] == target_hash:
451                     s.append( "L" + encode(merkle[1]))
452                     target_hash = new_hash
453                 elif merkle[1] == target_hash:
454                     s.append( "R" + encode(merkle[0]))
455                     target_hash = new_hash
456                 n.append( new_hash )
457                 merkle = merkle[2:]
458             merkle = n
459
460         # send result
461         return {"block_height":block_height,"merkle":s}
462
463
464
465
466     def memorypool_update(store):
467
468         ds = BCDataStream.BCDataStream()
469         postdata = dumps({"method": 'getrawmempool', 'params': [], 'id':'jsonrpc'})
470         respdata = urllib.urlopen(store.bitcoind_url, postdata).read()
471         r = loads(respdata)
472         if r['error'] != None:
473             print r['error']
474             return
475
476         mempool_hashes = r.get('result')
477         for tx_hash in mempool_hashes:
478
479             if tx_hash in store.known_mempool_hashes: continue
480             store.known_mempool_hashes.append(tx_hash)
481
482             postdata = dumps({"method": 'getrawtransaction', 'params': [tx_hash], 'id':'jsonrpc'})
483             respdata = urllib.urlopen(store.bitcoind_url, postdata).read()
484             r = loads(respdata)
485             if r['error'] != None:
486                 continue
487             hextx = r.get('result')
488             ds.clear()
489             ds.write(hextx.decode('hex'))
490             tx = deserialize.parse_Transaction(ds)
491             tx['hash'] = util.double_sha256(tx['tx'])
492                 
493             if store.tx_find_id_and_value(tx):
494                 pass
495             else:
496                 tx_id = store.import_tx(tx, False)
497                 store.update_tx_cache(tx_id)
498                 #print tx_hash
499
500         store.commit()
501         store.known_mempool_hashes = mempool_hashes
502
503
504     def send_tx(self,tx):
505         postdata = dumps({"method": 'sendrawtransaction', 'params': [tx], 'id':'jsonrpc'})
506         respdata = urllib.urlopen(self.bitcoind_url, postdata).read()
507         r = loads(respdata)
508         if r['error'] != None:
509             msg = r['error'].get('message')
510             out = "error: transaction rejected by memorypool: " + msg + "\n" + tx
511         else:
512             out = r['result']
513         return out
514
515
516     def main_iteration(store):
517         with store.dblock:
518             store.catch_up()
519             store.memorypool_update()
520             block_number = store.get_block_number(store.chain_id)
521             return block_number
522
523
524
525
526     def catch_up(store):
527         # if there is an exception, do rollback and then re-raise the exception
528         for dircfg in store.datadirs:
529             try:
530                 store.catch_up_dir(dircfg)
531             except Exception, e:
532                 store.log.exception("Failed to catch up %s", dircfg)
533                 store.rollback()
534                 raise e
535
536
537
538
539 from processor import Processor
540
541 class BlockchainProcessor(Processor):
542
543     def __init__(self, config):
544         Processor.__init__(self)
545         self.store = AbeStore(config)
546         self.block_number = -1
547         self.watched_addresses = []
548
549         # catch_up first
550         n = self.store.main_iteration()
551         print "blockchain: %d blocks"%n
552
553         threading.Timer(10, self.run_store_iteration).start()
554
555     def process(self, request):
556         #print "abe process", request
557
558         message_id = request['id']
559         method = request['method']
560         params = request.get('params',[])
561         result = None
562         error = None
563
564         if method == 'blockchain.numblocks.subscribe':
565             result = self.block_number
566
567         elif method == 'blockchain.address.subscribe':
568             try:
569                 address = params[0]
570                 result = self.store.get_status(address)
571                 self.watch_address(address)
572             except BaseException, e:
573                 error = str(e) + ': ' + address
574                 print "error:", error
575
576         elif method == 'blockchain.address.get_history':
577             try:
578                 address = params[0]
579                 result = self.store.get_history( address ) 
580             except BaseException, e:
581                 error = str(e) + ': ' + address
582                 print "error:", error
583
584         elif method == 'blockchain.block.get_header':
585             try:
586                 height = params[0]
587                 result = self.store.get_block_header( height ) 
588             except BaseException, e:
589                 error = str(e) + ': %d'% height
590                 print "error:", error
591
592         elif method == 'blockchain.transaction.broadcast':
593             txo = self.store.send_tx(params[0])
594             print "sent tx:", txo
595             result = txo 
596
597         elif method == 'blockchain.transaction.get_merkle':
598             try:
599                 tx_hash = params[0]
600                 result = self.store.get_tx_merkle(tx_hash ) 
601             except BaseException, e:
602                 error = str(e) + ': ' + tx_hash
603                 print "error:", error
604
605         else:
606             error = "unknown method:%s"%method
607
608
609         if error:
610             response = { 'id':message_id, 'error':error }
611             self.push_response(response)
612         elif result != '':
613             response = { 'id':message_id, 'result':result }
614             self.push_response(response)
615
616
617     def watch_address(self, addr):
618         if addr not in self.watched_addresses:
619             self.watched_addresses.append(addr)
620
621
622     def run_store_iteration(self):
623         
624         try:
625             block_number = self.store.main_iteration()
626         except:
627             traceback.print_exc(file=sys.stdout)
628             print "terminating"
629             self.shared.stop()
630
631         if self.shared.stopped(): 
632             print "exit timer"
633             return
634
635         if self.block_number != block_number:
636             self.block_number = block_number
637             print "block number:", self.block_number
638             self.push_response({ 'id': None, 'method':'blockchain.numblocks.subscribe', 'params':[self.block_number] })
639
640         while True:
641             try:
642                 addr = self.store.address_queue.get(False)
643             except:
644                 break
645             if addr in self.watched_addresses:
646                 status = self.store.get_status( addr )
647                 self.push_response({ 'id': None, 'method':'blockchain.address.subscribe', 'params':[addr, status] })
648
649         threading.Timer(10, self.run_store_iteration).start()
650
651