add limit on sql requests
[electrum-server.git] / backends / abe / __init__.py
1 from Abe.util import hash_to_address, decode_check_address
2 from Abe.DataStore import DataStore as Datastore_class
3 from Abe import DataStore, readconf, BCDataStream,  deserialize, util, base58
4
5 import binascii
6
7 import thread, traceback, sys, urllib, operator
8 from json import dumps, loads
9 from Queue import Queue
10 import time, threading
11
12
13 SQL_LIMIT=200
14
15 class AbeStore(Datastore_class):
16
17     def __init__(self, config):
18         conf = DataStore.CONFIG_DEFAULTS
19         args, argv = readconf.parse_argv( [], conf)
20         args.dbtype = config.get('database','type')
21         if args.dbtype == 'sqlite3':
22             args.connect_args = { 'database' : config.get('database','database') }
23         elif args.dbtype == 'MySQLdb':
24             args.connect_args = { 'db' : config.get('database','database'), 'user' : config.get('database','username'), 'passwd' : config.get('database','password') }
25         elif args.dbtype == 'psycopg2':
26             args.connect_args = { 'database' : config.get('database','database') }
27
28         Datastore_class.__init__(self,args)
29
30         self.tx_cache = {}
31         self.bitcoind_url = 'http://%s:%s@%s:%s/' % ( config.get('bitcoind','user'), config.get('bitcoind','password'), config.get('bitcoind','host'), config.get('bitcoind','port'))
32
33         self.address_queue = Queue()
34
35         self.dblock = thread.allocate_lock()
36
37
38
39     def import_block(self, b, chain_ids=frozenset()):
40         #print "import block"
41         block_id = super(AbeStore, self).import_block(b, chain_ids)
42         for pos in xrange(len(b['transactions'])):
43             tx = b['transactions'][pos]
44             if 'hash' not in tx:
45                 tx['hash'] = util.double_sha256(tx['tx'])
46             tx_id = self.tx_find_id_and_value(tx)
47             if tx_id:
48                 self.update_tx_cache(tx_id)
49             else:
50                 print "error: import_block: no tx_id"
51         return block_id
52
53
54     def update_tx_cache(self, txid):
55         inrows = self.get_tx_inputs(txid, False)
56         for row in inrows:
57             _hash = self.binout(row[6])
58             if not _hash:
59                 #print "WARNING: missing tx_in for tx", txid
60                 continue
61
62             address = hash_to_address(chr(0), _hash)
63             if self.tx_cache.has_key(address):
64                 print "cache: invalidating", address
65                 self.tx_cache.pop(address)
66             self.address_queue.put(address)
67
68         outrows = self.get_tx_outputs(txid, False)
69         for row in outrows:
70             _hash = self.binout(row[6])
71             if not _hash:
72                 #print "WARNING: missing tx_out for tx", txid
73                 continue
74
75             address = hash_to_address(chr(0), _hash)
76             if self.tx_cache.has_key(address):
77                 print "cache: invalidating", address
78                 self.tx_cache.pop(address)
79             self.address_queue.put(address)
80
81     def safe_sql(self,sql, params=(), lock=True):
82
83         error = False
84         try:
85             if lock: self.dblock.acquire()
86             ret = self.selectall(sql,params)
87         except:
88             error = True
89         finally:
90             if lock: self.dblock.release()
91
92         if error: 
93             raise BaseException('sql error')
94
95         return ret
96             
97
98     def get_tx_outputs(self, tx_id, lock=True):
99         return self.safe_sql("""SELECT
100                 txout.txout_pos,
101                 txout.txout_scriptPubKey,
102                 txout.txout_value,
103                 nexttx.tx_hash,
104                 nexttx.tx_id,
105                 txin.txin_pos,
106                 pubkey.pubkey_hash
107               FROM txout
108               LEFT JOIN txin ON (txin.txout_id = txout.txout_id)
109               LEFT JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
110               LEFT JOIN tx nexttx ON (txin.tx_id = nexttx.tx_id)
111              WHERE txout.tx_id = %d 
112              ORDER BY txout.txout_pos
113         """%(tx_id), (), lock)
114
115     def get_tx_inputs(self, tx_id, lock=True):
116         return self.safe_sql(""" SELECT
117                 txin.txin_pos,
118                 txin.txin_scriptSig,
119                 txout.txout_value,
120                 COALESCE(prevtx.tx_hash, u.txout_tx_hash),
121                 prevtx.tx_id,
122                 COALESCE(txout.txout_pos, u.txout_pos),
123                 pubkey.pubkey_hash
124               FROM txin
125               LEFT JOIN txout ON (txout.txout_id = txin.txout_id)
126               LEFT JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
127               LEFT JOIN tx prevtx ON (txout.tx_id = prevtx.tx_id)
128               LEFT JOIN unlinked_txin u ON (u.txin_id = txin.txin_id)
129              WHERE txin.tx_id = %d
130              ORDER BY txin.txin_pos
131              """%(tx_id,), (), lock)
132
133
134     def get_address_out_rows(self, dbhash):
135         out = self.safe_sql(""" SELECT
136                 b.block_nTime,
137                 cc.chain_id,
138                 b.block_height,
139                 1,
140                 b.block_hash,
141                 tx.tx_hash,
142                 tx.tx_id,
143                 txin.txin_pos,
144                 -prevout.txout_value
145               FROM chain_candidate cc
146               JOIN block b ON (b.block_id = cc.block_id)
147               JOIN block_tx ON (block_tx.block_id = b.block_id)
148               JOIN tx ON (tx.tx_id = block_tx.tx_id)
149               JOIN txin ON (txin.tx_id = tx.tx_id)
150               JOIN txout prevout ON (txin.txout_id = prevout.txout_id)
151               JOIN pubkey ON (pubkey.pubkey_id = prevout.pubkey_id)
152              WHERE pubkey.pubkey_hash = ?
153                AND cc.in_longest = 1
154              LIMIT ? """, (dbhash,SQL_LIMIT))
155
156         if len(out)==SQL_LIMIT: 
157             raise BaseException('limit reached')
158         return out
159
160     def get_address_out_rows_memorypool(self, dbhash):
161         out = self.safe_sql(""" SELECT
162                 1,
163                 tx.tx_hash,
164                 tx.tx_id,
165                 txin.txin_pos,
166                 -prevout.txout_value
167               FROM tx 
168               JOIN txin ON (txin.tx_id = tx.tx_id)
169               JOIN txout prevout ON (txin.txout_id = prevout.txout_id)
170               JOIN pubkey ON (pubkey.pubkey_id = prevout.pubkey_id)
171              WHERE pubkey.pubkey_hash = ?
172              LIMIT ? """, (dbhash,SQL_LIMIT))
173
174         if len(out)==SQL_LIMIT: 
175             raise BaseException('limit reached')
176         return out
177
178     def get_address_in_rows(self, dbhash):
179         out = self.safe_sql(""" SELECT
180                 b.block_nTime,
181                 cc.chain_id,
182                 b.block_height,
183                 0,
184                 b.block_hash,
185                 tx.tx_hash,
186                 tx.tx_id,
187                 txout.txout_pos,
188                 txout.txout_value
189               FROM chain_candidate cc
190               JOIN block b ON (b.block_id = cc.block_id)
191               JOIN block_tx ON (block_tx.block_id = b.block_id)
192               JOIN tx ON (tx.tx_id = block_tx.tx_id)
193               JOIN txout ON (txout.tx_id = tx.tx_id)
194               JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
195              WHERE pubkey.pubkey_hash = ?
196                AND cc.in_longest = 1
197                LIMIT ? """, (dbhash,SQL_LIMIT))
198
199         if len(out)==SQL_LIMIT: 
200             raise BaseException('limit reached')
201         return out
202
203     def get_address_in_rows_memorypool(self, dbhash):
204         out = self.safe_sql( """ SELECT
205                 0,
206                 tx.tx_hash,
207                 tx.tx_id,
208                 txout.txout_pos,
209                 txout.txout_value
210               FROM tx
211               JOIN txout ON (txout.tx_id = tx.tx_id)
212               JOIN pubkey ON (pubkey.pubkey_id = txout.pubkey_id)
213              WHERE pubkey.pubkey_hash = ?
214              LIMIT ? """, (dbhash,SQL_LIMIT))
215
216         if len(out)==SQL_LIMIT: 
217             raise BaseException('limit reached')
218         return out
219
220     def get_history(self, addr):
221
222         cached_version = self.tx_cache.get( addr )
223         if cached_version is not None:
224             return cached_version
225
226         version, binaddr = decode_check_address(addr)
227         if binaddr is None:
228             return None
229
230         dbhash = self.binin(binaddr)
231         rows = []
232         rows += self.get_address_out_rows( dbhash )
233         rows += self.get_address_in_rows( dbhash )
234
235         txpoints = []
236         known_tx = []
237
238         for row in rows:
239             try:
240                 nTime, chain_id, height, is_in, blk_hash, tx_hash, tx_id, pos, value = row
241             except:
242                 print "cannot unpack row", row
243                 break
244             tx_hash = self.hashout_hex(tx_hash)
245             txpoint = {
246                     "timestamp":    int(nTime),
247                     "height":   int(height),
248                     "is_input":    int(is_in),
249                     "block_hash": self.hashout_hex(blk_hash),
250                     "tx_hash":  tx_hash,
251                     "tx_id":    int(tx_id),
252                     "index":      int(pos),
253                     "value":    int(value),
254                     }
255
256             txpoints.append(txpoint)
257             known_tx.append(self.hashout_hex(tx_hash))
258
259
260         # todo: sort them really...
261         txpoints = sorted(txpoints, key=operator.itemgetter("timestamp"))
262
263         # read memory pool
264         rows = []
265         rows += self.get_address_in_rows_memorypool( dbhash )
266         rows += self.get_address_out_rows_memorypool( dbhash )
267         address_has_mempool = False
268
269         current_id = self.new_id("tx")
270
271         for row in rows:
272             is_in, tx_hash, tx_id, pos, value = row
273             tx_hash = self.hashout_hex(tx_hash)
274             if tx_hash in known_tx:
275                 continue
276
277             # this means that pending transactions were added to the db, even if they are not returned by getmemorypool
278             address_has_mempool = True
279
280             # fixme: we need to detect transactions that became invalid
281             if current_id - tx_id > 10000:
282                 continue
283
284
285             #print "mempool", tx_hash
286             txpoint = {
287                     "timestamp":    0,
288                     "height":   0,
289                     "is_input":    int(is_in),
290                     "block_hash": 'mempool', 
291                     "tx_hash":  tx_hash,
292                     "tx_id":    int(tx_id),
293                     "index":      int(pos),
294                     "value":    int(value),
295                     }
296             txpoints.append(txpoint)
297
298
299         for txpoint in txpoints:
300             tx_id = txpoint['tx_id']
301             
302             txinputs = []
303             inrows = self.get_tx_inputs(tx_id)
304             for row in inrows:
305                 _hash = self.binout(row[6])
306                 if not _hash:
307                     #print "WARNING: missing tx_in for tx", tx_id, addr
308                     continue
309                 address = hash_to_address(chr(0), _hash)
310                 txinputs.append(address)
311             txpoint['inputs'] = txinputs
312             txoutputs = []
313             outrows = self.get_tx_outputs(tx_id)
314             for row in outrows:
315                 _hash = self.binout(row[6])
316                 if not _hash:
317                     #print "WARNING: missing tx_out for tx", tx_id, addr
318                     continue
319                 address = hash_to_address(chr(0), _hash)
320                 txoutputs.append(address)
321             txpoint['outputs'] = txoutputs
322
323             # for all unspent inputs, I want their scriptpubkey. (actually I could deduce it from the address)
324             if not txpoint['is_input']:
325                 # detect if already redeemed...
326                 for row in outrows:
327                     if row[6] == dbhash: break
328                 else:
329                     raise
330                 #row = self.get_tx_output(tx_id,dbhash)
331                 # pos, script, value, o_hash, o_id, o_pos, binaddr = row
332                 # if not redeemed, we add the script
333                 if row:
334                     if not row[4]: txpoint['raw_output_script'] = row[1]
335
336         # cache result
337         # do not cache mempool results because statuses are ambiguous
338         if not address_has_mempool:
339             self.tx_cache[addr] = txpoints
340         
341         return txpoints
342
343
344     def get_status(self,addr):
345         # get address status, i.e. the last block for that address.
346         tx_points = self.get_history(addr)
347         if not tx_points:
348             status = None
349         else:
350             lastpoint = tx_points[-1]
351             status = lastpoint['block_hash']
352             # this is a temporary hack; move it up once old clients have disappeared
353             if status == 'mempool': # and session['version'] != "old":
354                 status = status + ':%d'% len(tx_points)
355         return status
356
357
358
359     def memorypool_update(store):
360
361         ds = BCDataStream.BCDataStream()
362         postdata = dumps({"method": 'getmemorypool', 'params': [], 'id':'jsonrpc'})
363
364         respdata = urllib.urlopen(store.bitcoind_url, postdata).read()
365         r = loads(respdata)
366         if r['error'] != None:
367             return
368
369         v = r['result'].get('transactions')
370         for hextx in v:
371             ds.clear()
372             ds.write(hextx.decode('hex'))
373             tx = deserialize.parse_Transaction(ds)
374             tx['hash'] = util.double_sha256(tx['tx'])
375             tx_hash = store.hashin(tx['hash'])
376
377             if store.tx_find_id_and_value(tx):
378                 pass
379             else:
380                 tx_id = store.import_tx(tx, False)
381                 store.update_tx_cache(tx_id)
382     
383         store.commit()
384
385
386     def send_tx(self,tx):
387         postdata = dumps({"method": 'importtransaction', 'params': [tx], 'id':'jsonrpc'})
388         respdata = urllib.urlopen(self.bitcoind_url, postdata).read()
389         r = loads(respdata)
390         if r['error'] != None:
391             msg = r['error'].get('message')
392             out = "error: transaction rejected by memorypool: " + msg + "\n" + tx
393         else:
394             out = r['result']
395         return out
396
397
398     def main_iteration(store):
399         try:
400             store.dblock.acquire()
401             store.catch_up()
402             store.memorypool_update()
403             block_number = store.get_block_number(1)
404
405         except IOError:
406             print "IOError: cannot reach bitcoind"
407             block_number = 0
408         except:
409             traceback.print_exc(file=sys.stdout)
410             block_number = 0
411         finally:
412             store.dblock.release()
413
414         return block_number
415
416
417     def catch_up(store):
418         # if there is an exception, do rollback and then re-raise the exception
419         for dircfg in store.datadirs:
420             try:
421                 store.catch_up_dir(dircfg)
422             except Exception, e:
423                 store.log.exception("Failed to catch up %s", dircfg)
424                 store.rollback()
425                 raise e
426
427
428
429
430 from processor import Processor
431
432 class BlockchainProcessor(Processor):
433
434     def __init__(self, config):
435         Processor.__init__(self)
436         self.store = AbeStore(config)
437         self.block_number = -1
438         self.watched_addresses = []
439         threading.Timer(10, self.run_store_iteration).start()
440
441     def process(self, request):
442         #print "abe process", request
443
444         message_id = request['id']
445         method = request['method']
446         params = request.get('params',[])
447         result = None
448         error = None
449
450         if method == 'blockchain.numblocks.subscribe':
451             result = self.block_number
452
453         elif method == 'blockchain.address.subscribe':
454             try:
455                 address = params[0]
456                 result = self.store.get_status(address)
457                 self.watch_address(address)
458             except BaseException, e:
459                 error = str(e)
460
461         elif method == 'blockchain.address.get_history':
462             try:
463                 address = params[0]
464                 result = self.store.get_history( address ) 
465             except BaseException, e:
466                 error = str(e)
467
468         elif method == 'blockchain.transaction.broadcast':
469             txo = self.store.send_tx(params[0])
470             print "sent tx:", txo
471             result = txo 
472
473         else:
474             error = "unknown method:%s"%method
475
476
477         if error:
478             response = { 'id':message_id, 'error':error }
479             self.push_response(response)
480         elif result != '':
481             response = { 'id':message_id, 'result':result }
482             self.push_response(response)
483
484
485     def watch_address(self, addr):
486         if addr not in self.watched_addresses:
487             self.watched_addresses.append(addr)
488
489
490     def run_store_iteration(self):
491         if self.shared.stopped(): 
492             print "exit timer"
493             return
494         
495         block_number = self.store.main_iteration()
496         if self.block_number != block_number:
497             self.block_number = block_number
498             print "block number:", self.block_number
499             self.push_response({ 'method':'blockchain.numblocks.subscribe', 'params':[self.block_number] })
500
501         while True:
502             try:
503                 addr = self.store.address_queue.get(False)
504             except:
505                 break
506             if addr in self.watched_addresses:
507                 status = self.store.get_status( addr )
508                 self.push_response({ 'method':'blockchain.address.subscribe', 'params':[addr, status] })
509
510         threading.Timer(10, self.run_store_iteration).start()
511
512