add memory pool results to getaddressbalance
[electrum-server.git] / backends / bitcoind / storage.py
1 import plyvel, ast, hashlib, traceback, os
2 from processor import print_log
3 from utils import *
4
5
6 """
7 Patricia tree for hashing unspents
8
9 """
10
11 DEBUG = 0
12 KEYLENGTH = 20 + 32 + 4   #56
13
14 class Storage(object):
15
16     def __init__(self, config, shared, test_reorgs):
17
18         self.dbpath = config.get('leveldb', 'path_fulltree')
19         if not os.path.exists(self.dbpath):
20             os.mkdir(self.dbpath)
21         self.pruning_limit = config.getint('leveldb', 'pruning_limit')
22         self.shared = shared
23         self.hash_list = {}
24         self.parents = {}
25
26         self.test_reorgs = test_reorgs
27         try:
28             self.db_utxo = plyvel.DB(os.path.join(self.dbpath,'utxo'), create_if_missing=True, compression=None)
29             self.db_addr = plyvel.DB(os.path.join(self.dbpath,'addr'), create_if_missing=True, compression=None)
30             self.db_hist = plyvel.DB(os.path.join(self.dbpath,'hist'), create_if_missing=True, compression=None)
31             self.db_undo = plyvel.DB(os.path.join(self.dbpath,'undo'), create_if_missing=True, compression=None)
32         except:
33             traceback.print_exc(file=sys.stdout)
34             self.shared.stop()
35
36         self.db_version = 2 # increase this when database needs to be updated
37         try:
38             self.last_hash, self.height, db_version = ast.literal_eval(self.db_undo.get('height'))
39             print_log("Database version", self.db_version)
40             print_log("Blockchain height", self.height)
41         except:
42             #traceback.print_exc(file=sys.stdout)
43             print_log('initializing database')
44             self.height = 0
45             self.last_hash = '000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f'
46             db_version = self.db_version
47             # write root
48             self.put_node('', {})
49
50         # check version
51         if self.db_version != db_version:
52             print_log("Your database '%s' is deprecated. Please create a new database"%self.dbpath)
53             self.shared.stop()
54             return
55
56
57         # compute root hash
58         d = self.get_node('')
59         self.root_hash, v = self.get_node_hash('',d,None)
60         print_log("UTXO tree root hash:", self.root_hash.encode('hex'))
61         print_log("Coins in database:", v)
62
63     # convert between bitcoin addresses and 20 bytes keys used for storage. 
64     def address_to_key(self, addr):
65         return bc_address_to_hash_160(addr)
66
67     def key_to_address(self, addr):
68         return hash_160_to_bc_address(addr)
69
70
71     def get_proof(self, addr):
72         key = self.address_to_key(addr)
73         i = self.db_utxo.iterator(start=key)
74         k, _ = i.next()
75
76         p = self.get_path(k) 
77         p.append(k)
78
79         out = []
80         for item in p:
81             v = self.db_utxo.get(item)
82             out.append((item.encode('hex'), v.encode('hex')))
83
84         return out
85
86
87     def get_balance(self, addr):
88         key = self.address_to_key(addr)
89         i = self.db_utxo.iterator(start=key)
90         k, _ = i.next()
91         if not k.startswith(key): 
92             return 0
93         p = self.get_parent(k)
94         d = self.get_node(p)
95         letter = k[len(p)]
96         return d[letter][1]
97
98
99     def listunspent(self, addr):
100         key = self.address_to_key(addr)
101
102         out = []
103         for k, v in self.db_utxo.iterator(start=key):
104             if not k.startswith(key):
105                 break
106             if len(k) == KEYLENGTH:
107                 txid = k[20:52].encode('hex')
108                 txpos = hex_to_int(k[52:56])
109                 h = hex_to_int(v[8:12])
110                 v = hex_to_int(v[0:8])
111                 out.append({'tx_hash': txid, 'tx_pos':txpos, 'height': h, 'value':v})
112
113         out.sort(key=lambda x:x['height'])
114         return out
115
116
117     def get_history(self, addr):
118         out = []
119
120         o = self.listunspent(addr)
121         for item in o:
122             out.append((item['tx_hash'], item['height']))
123
124         h = self.db_hist.get(addr)
125         
126         while h:
127             item = h[0:80]
128             h = h[80:]
129             txi = item[0:32].encode('hex')
130             hi = hex_to_int(item[36:40])
131             txo = item[40:72].encode('hex')
132             ho = hex_to_int(item[76:80])
133             out.append((txi, hi))
134             out.append((txo, ho))
135
136         # sort
137         out.sort(key=lambda x:x[1])
138
139         # uniqueness
140         out = set(out)
141
142         return map(lambda x: {'tx_hash':x[0], 'height':x[1]}, out)
143
144
145
146     def get_address(self, txi):
147         addr = self.db_addr.get(txi)
148         return self.key_to_address(addr) if addr else None
149
150
151     def get_undo_info(self, height):
152         s = self.db_undo.get("undo_info_%d" % (height % 100))
153         if s is None: print_log("no undo info for ", height)
154         return eval(s)
155
156
157     def write_undo_info(self, height, bitcoind_height, undo_info):
158         if height > bitcoind_height - 100 or self.test_reorgs:
159             self.db_undo.put("undo_info_%d" % (height % 100), repr(undo_info))
160
161
162     def common_prefix(self, word1, word2):
163         max_len = min(len(word1),len(word2))
164         for i in range(max_len):
165             if word2[i] != word1[i]:
166                 index = i
167                 break
168         else:
169             index = max_len
170         return word1[0:index]
171
172
173     def put_node(self, key, d, batch=None):
174         k = 0
175         serialized = ''
176         for i in range(256):
177             if chr(i) in d.keys():
178                 k += 1<<i
179                 h, v = d[chr(i)]
180                 if h is None: h = chr(0)*32
181                 vv = int_to_hex(v, 8).decode('hex')
182                 item = h + vv
183                 assert len(item) == 40
184                 serialized += item
185
186         k = "0x%0.64X" % k # 32 bytes
187         k = k[2:].decode('hex')
188         assert len(k) == 32
189         out = k + serialized
190         if batch:
191             batch.put(key, out)
192         else:
193             self.db_utxo.put(key, out) 
194
195
196     def get_node(self, key):
197
198         s = self.db_utxo.get(key)
199         if s is None: 
200             return 
201
202         #print "get node", key.encode('hex'), len(key), s.encode('hex')
203
204         k = int(s[0:32].encode('hex'), 16)
205         s = s[32:]
206         d = {}
207         for i in range(256):
208             if k % 2 == 1: 
209                 _hash = s[0:32]
210                 value = hex_to_int(s[32:40])
211                 d[chr(i)] = (_hash, value)
212                 s = s[40:]
213             k = k/2
214
215         #cache
216         return d
217
218
219     def add_address(self, target, value, height):
220         assert len(target) == KEYLENGTH
221
222         word = target
223         key = ''
224         path = [ '' ]
225         i = self.db_utxo.iterator()
226
227         while key != target:
228
229             items = self.get_node(key)
230
231             if word[0] in items.keys():
232   
233                 i.seek(key + word[0])
234                 new_key, _ = i.next()
235
236                 if target.startswith(new_key):
237                     # add value to the child node
238                     key = new_key
239                     word = target[len(key):]
240                     if key == target:
241                         break
242                     else:
243                         assert key not in path
244                         path.append(key)
245                 else:
246                     # prune current node and add new node
247                     prefix = self.common_prefix(new_key, target)
248                     index = len(prefix)
249
250                     ## get hash and value of new_key from parent (if it's a leaf)
251                     if len(new_key) == KEYLENGTH:
252                         parent_key = self.get_parent(new_key)
253                         parent = self.get_node(parent_key)
254                         z = parent[ new_key[len(parent_key)] ]
255                         self.put_node(prefix, { target[index]:(None,0), new_key[index]:z } )
256                     else:
257                         # if it is not a leaf, update the hash of new_key because skip_string changed
258                         h, v = self.get_node_hash(new_key, self.get_node(new_key), prefix)
259                         self.put_node(prefix, { target[index]:(None,0), new_key[index]:(h,v) } )
260
261                     path.append(prefix)
262                     self.parents[new_key] = prefix
263                     break
264
265             else:
266                 assert key in path
267                 items[ word[0] ] = (None,0)
268                 self.put_node(key,items)
269                 break
270
271         # write 
272         s = (int_to_hex(value, 8) + int_to_hex(height,4)).decode('hex')
273         self.db_utxo.put(target, s)
274         # the hash of a node is the txid
275         _hash = target[20:52]
276         self.update_node_hash(target, path, _hash, value)
277
278
279     def update_node_hash(self, node, path, _hash, value):
280         c = node
281         for x in path[::-1]:
282             self.parents[c] = x
283             c = x
284
285         self.hash_list[node] = (_hash, value)
286
287
288     def update_hashes(self):
289
290         nodes = {} # nodes to write
291
292         for i in range(KEYLENGTH, -1, -1):
293
294             for node in self.hash_list.keys():
295                 if len(node) != i: continue
296
297                 node_hash, node_value = self.hash_list.pop(node)
298
299                 # for each node, compute its hash, send it to the parent
300                 if node == '':
301                     self.root_hash = node_hash
302                     self.root_value = node_value
303                     break
304
305                 parent = self.parents[node]
306
307                 # read parent.. do this in add_address
308                 d = nodes.get(parent)
309                 if d is None:
310                     d = self.get_node(parent)
311                     assert d is not None
312
313                 letter = node[len(parent)]
314                 assert letter in d.keys()
315
316                 if i != KEYLENGTH and node_hash is None:
317                     d2 = self.get_node(node)
318                     node_hash, node_value = self.get_node_hash(node, d2, parent)
319
320                 assert node_hash is not None
321                 # write new value
322                 d[letter] = (node_hash, node_value)
323                 nodes[parent] = d
324
325                 # iterate
326                 grandparent = self.parents[parent] if parent != '' else None
327                 parent_hash, parent_value = self.get_node_hash(parent, d, grandparent)
328                 self.hash_list[parent] = (parent_hash, parent_value)
329
330         
331         # batch write modified nodes 
332         batch = self.db_utxo.write_batch()
333         for k, v in nodes.items():
334             self.put_node(k, v, batch)
335         batch.write()
336
337         # cleanup
338         assert self.hash_list == {}
339         self.parents = {}
340
341
342     def get_node_hash(self, x, d, parent):
343
344         # final hash
345         if x != '':
346             skip_string = x[len(parent)+1:]
347         else:
348             skip_string = ''
349
350         d2 = sorted(d.items())
351         values = map(lambda x: x[1][1], d2)
352         hashes = map(lambda x: x[1][0], d2)
353         value = sum( values )
354         _hash = self.hash( skip_string + ''.join(hashes) )
355         return _hash, value
356
357
358     def get_path(self, target):
359         word = target
360         key = ''
361         path = [ '' ]
362         i = self.db_utxo.iterator(start='')
363
364         while key != target:
365
366             i.seek(key + word[0])
367             try:
368                 new_key, _ = i.next()
369                 is_child = new_key.startswith(key + word[0])
370             except StopIteration:
371                 is_child = False
372
373             if is_child:
374   
375                 if target.startswith(new_key):
376                     # add value to the child node
377                     key = new_key
378                     word = target[len(key):]
379                     if key == target:
380                         break
381                     else:
382                         assert key not in path
383                         path.append(key)
384                 else:
385                     print_log('not in tree', self.db_utxo.get(key+word[0]), new_key.encode('hex'))
386                     return False
387             else:
388                 assert key in path
389                 break
390
391         return path
392
393
394     def delete_address(self, leaf):
395         path = self.get_path(leaf)
396         if path is False:
397             print_log("addr not in tree", leaf.encode('hex'), self.key_to_address(leaf[0:20]), self.db_utxo.get(leaf))
398             raise
399
400         s = self.db_utxo.get(leaf)
401         
402         self.db_utxo.delete(leaf)
403         if leaf in self.hash_list:
404             self.hash_list.pop(leaf)
405
406         parent = path[-1]
407         letter = leaf[len(parent)]
408         items = self.get_node(parent)
409         items.pop(letter)
410
411         # remove key if it has a single child
412         if len(items) == 1:
413             letter, v = items.items()[0]
414
415             self.db_utxo.delete(parent)
416             if parent in self.hash_list: 
417                 self.hash_list.pop(parent)
418
419             # we need the exact length for the iteration
420             i = self.db_utxo.iterator()
421             i.seek(parent+letter)
422             k, v = i.next()
423
424             # note: k is not necessarily a leaf
425             if len(k) == KEYLENGTH:
426                  _hash, value = k[20:52], hex_to_int(v[0:8])
427             else:
428                 _hash, value = None, None
429
430             self.update_node_hash(k, path[:-1], _hash, value)
431
432         else:
433             self.put_node(parent, items)
434             _hash, value = None, None
435             self.update_node_hash(parent, path[:-1], _hash, value)
436
437         return s
438
439
440     def get_children(self, x):
441         i = self.db_utxo.iterator()
442         l = 0
443         while l <256:
444             i.seek(x+chr(l))
445             k, v = i.next()
446             if k.startswith(x+chr(l)): 
447                 yield k, v
448                 l += 1
449             elif k.startswith(x): 
450                 yield k, v
451                 l = ord(k[len(x)]) + 1
452             else: 
453                 break
454
455
456
457
458     def get_parent(self, x):
459         """ return parent and skip string"""
460         i = self.db_utxo.iterator()
461         for j in range(len(x)):
462             p = x[0:-j-1]
463             i.seek(p)
464             k, v = i.next()
465             if x.startswith(k) and x!=k: 
466                 break
467         else: raise
468         return k
469
470         
471     def hash(self, x):
472         if DEBUG: return "hash("+x+")"
473         return Hash(x)
474
475
476     def get_root_hash(self):
477         return self.root_hash
478
479
480     def close(self):
481         self.db_utxo.close()
482         self.db_addr.close()
483         self.db_hist.close()
484         self.db_undo.close()
485
486
487     def add_to_history(self, addr, tx_hash, tx_pos, value, tx_height):
488         key = self.address_to_key(addr)
489         txo = (tx_hash + int_to_hex(tx_pos, 4)).decode('hex')
490
491         # write the new history
492         self.add_address(key + txo, value, tx_height)
493
494         # backlink
495         self.db_addr.put(txo, key)
496
497
498
499     def revert_add_to_history(self, addr, tx_hash, tx_pos, value, tx_height):
500         key = self.address_to_key(addr)
501         txo = (tx_hash + int_to_hex(tx_pos, 4)).decode('hex')
502
503         # delete
504         self.delete_address(key + txo)
505
506         # backlink
507         self.db_addr.delete(txo)
508
509
510     def get_utxo_value(self, addr, txi):
511         key = self.address_to_key(addr)
512         leaf = key + txi
513         s = self.db_utxo.get(leaf)
514         value = hex_to_int(s[0:8])
515         return value
516
517
518     def set_spent(self, addr, txi, txid, index, height, undo):
519         key = self.address_to_key(addr)
520         leaf = key + txi
521
522         s = self.delete_address(leaf)
523         value = hex_to_int(s[0:8])
524         in_height = hex_to_int(s[8:12])
525         undo[leaf] = value, in_height
526
527         # delete backlink txi-> addr
528         self.db_addr.delete(txi)
529
530         # add to history
531         s = self.db_hist.get(addr)
532         if s is None: s = ''
533         txo = (txid + int_to_hex(index,4) + int_to_hex(height,4)).decode('hex')
534         s += txi + int_to_hex(in_height,4).decode('hex') + txo
535         s = s[ -80*self.pruning_limit:]
536         self.db_hist.put(addr, s)
537
538
539
540     def revert_set_spent(self, addr, txi, undo):
541         key = self.address_to_key(addr)
542         leaf = key + txi
543
544         # restore backlink
545         self.db_addr.put(txi, key)
546
547         v, height = undo.pop(leaf)
548         self.add_address(leaf, v, height)
549
550         # revert add to history
551         s = self.db_hist.get(addr)
552         # s might be empty if pruning limit was reached
553         if not s:
554             return
555
556         assert s[-80:-44] == txi
557         s = s[:-80]
558         self.db_hist.put(addr, s)
559
560
561
562
563         
564
565     def import_transaction(self, txid, tx, block_height, touched_addr):
566
567         undo = { 'prev_addr':[] } # contains the list of pruned items for each address in the tx; also, 'prev_addr' is a list of prev addresses
568                 
569         prev_addr = []
570         for i, x in enumerate(tx.get('inputs')):
571             txi = (x.get('prevout_hash') + int_to_hex(x.get('prevout_n'), 4)).decode('hex')
572             addr = self.get_address(txi)
573             if addr is not None: 
574                 self.set_spent(addr, txi, txid, i, block_height, undo)
575                 touched_addr.add(addr)
576             prev_addr.append(addr)
577
578         undo['prev_addr'] = prev_addr 
579
580         # here I add only the outputs to history; maybe I want to add inputs too (that's in the other loop)
581         for x in tx.get('outputs'):
582             addr = x.get('address')
583             if addr is None: continue
584             self.add_to_history(addr, txid, x.get('index'), x.get('value'), block_height)
585             touched_addr.add(addr)
586
587         return undo
588
589
590     def revert_transaction(self, txid, tx, block_height, touched_addr, undo):
591         #print_log("revert tx", txid)
592         for x in reversed(tx.get('outputs')):
593             addr = x.get('address')
594             if addr is None: continue
595             self.revert_add_to_history(addr, txid, x.get('index'), x.get('value'), block_height)
596             touched_addr.add(addr)
597
598         prev_addr = undo.pop('prev_addr')
599         for i, x in reversed(list(enumerate(tx.get('inputs')))):
600             addr = prev_addr[i]
601             if addr is not None:
602                 txi = (x.get('prevout_hash') + int_to_hex(x.get('prevout_n'), 4)).decode('hex')
603                 self.revert_set_spent(addr, txi, undo)
604                 touched_addr.add(addr)
605
606         assert undo == {}
607