fix: use address and not key in db_addr
[electrum-server.git] / backends / bitcoind / storage.py
1 import plyvel, ast, hashlib, traceback, os
2 from processor import print_log
3 from utils import *
4
5
6 """
7 Patricia tree for hashing unspents
8
9 """
10
11 DEBUG = 0
12 KEYLENGTH = 20 + 32 + 4   #56
13
14 class Storage(object):
15
16     def __init__(self, config, shared, test_reorgs):
17
18         self.dbpath = config.get('leveldb', 'path_fulltree')
19         if not os.path.exists(self.dbpath):
20             os.mkdir(self.dbpath)
21         self.pruning_limit = config.getint('leveldb', 'pruning_limit')
22         self.shared = shared
23         self.hash_list = {}
24         self.parents = {}
25
26         self.test_reorgs = test_reorgs
27         try:
28             self.db_utxo = plyvel.DB(os.path.join(self.dbpath,'utxo'), create_if_missing=True, compression=None)
29             self.db_addr = plyvel.DB(os.path.join(self.dbpath,'addr'), create_if_missing=True, compression=None)
30             self.db_hist = plyvel.DB(os.path.join(self.dbpath,'hist'), create_if_missing=True, compression=None)
31             self.db_undo = plyvel.DB(os.path.join(self.dbpath,'undo'), create_if_missing=True, compression=None)
32         except:
33             traceback.print_exc(file=sys.stdout)
34             self.shared.stop()
35
36         self.db_version = 2 # increase this when database needs to be updated
37         try:
38             self.last_hash, self.height, db_version = ast.literal_eval(self.db_undo.get('height'))
39             print_log("Database version", self.db_version)
40             print_log("Blockchain height", self.height)
41         except:
42             #traceback.print_exc(file=sys.stdout)
43             print_log('initializing database')
44             self.height = 0
45             self.last_hash = '000000000019d6689c085ae165831e934ff763ae46a2a6c172b3f1b60a8ce26f'
46             db_version = self.db_version
47             # write root
48             self.put_node('', {})
49
50         # check version
51         if self.db_version != db_version:
52             print_log("Your database '%s' is deprecated. Please create a new database"%self.dbpath)
53             self.shared.stop()
54             return
55
56
57         # compute root hash
58         d = self.get_node('')
59         self.root_hash, v = self.get_node_hash('',d,None)
60         print_log("UTXO tree root hash:", self.root_hash.encode('hex'))
61         print_log("Coins in database:", v)
62
63     # convert between bitcoin addresses and 20 bytes keys used for storage. 
64     def address_to_key(self, addr):
65         return bc_address_to_hash_160(addr)
66
67     def key_to_address(self, addr):
68         return hash_160_to_bc_address(addr)
69
70
71     def get_proof(self, addr):
72         key = self.address_to_key(addr)
73         i = self.db_utxo.iterator(start=key)
74         k, _ = i.next()
75
76         p = self.get_path(k) 
77         p.append(k)
78
79         out = []
80         for item in p:
81             v = self.db_utxo.get(item)
82             out.append((item.encode('hex'), v.encode('hex')))
83
84         return out
85
86
87     def get_balance(self, addr):
88         key = self.address_to_key(addr)
89         i = self.db_utxo.iterator(start=key)
90         k, _ = i.next()
91         if not k.startswith(key): 
92             return 0
93         p = self.get_parent(k)
94         d = self.get_node(p)
95         letter = k[len(p)]
96         return d[letter][1]
97
98
99     def listunspent(self, addr):
100         key = self.address_to_key(addr)
101
102         out = []
103         for k, v in self.db_utxo.iterator(start=key):
104             if not k.startswith(key):
105                 break
106             if len(k) == KEYLENGTH:
107                 txid = k[20:52].encode('hex')
108                 txpos = hex_to_int(k[52:56])
109                 h = hex_to_int(v[8:12])
110                 v = hex_to_int(v[0:8])
111                 out.append({'tx_hash': txid, 'tx_pos':txpos, 'height': h, 'value':v})
112
113         out.sort(key=lambda x:x['height'])
114         return out
115
116
117     def get_history(self, addr):
118         out = []
119
120         o = self.listunspent(addr)
121         for item in o:
122             out.append((item['tx_hash'], item['height']))
123
124         h = self.db_hist.get(addr)
125         
126         while h:
127             item = h[0:80]
128             h = h[80:]
129             txi = item[0:32].encode('hex')
130             hi = hex_to_int(item[36:40])
131             txo = item[40:72].encode('hex')
132             ho = hex_to_int(item[76:80])
133             out.append((txi, hi))
134             out.append((txo, ho))
135
136         # sort
137         out.sort(key=lambda x:x[1])
138
139         # uniqueness
140         out = set(out)
141
142         return map(lambda x: {'tx_hash':x[0], 'height':x[1]}, out)
143
144
145
146     def get_address(self, txi):
147         return self.db_addr.get(txi)
148
149
150     def get_undo_info(self, height):
151         s = self.db_undo.get("undo_info_%d" % (height % 100))
152         if s is None: print_log("no undo info for ", height)
153         return eval(s)
154
155
156     def write_undo_info(self, height, bitcoind_height, undo_info):
157         if height > bitcoind_height - 100 or self.test_reorgs:
158             self.db_undo.put("undo_info_%d" % (height % 100), repr(undo_info))
159
160
161     def common_prefix(self, word1, word2):
162         max_len = min(len(word1),len(word2))
163         for i in range(max_len):
164             if word2[i] != word1[i]:
165                 index = i
166                 break
167         else:
168             index = max_len
169         return word1[0:index]
170
171
172     def put_node(self, key, d, batch=None):
173         k = 0
174         serialized = ''
175         for i in range(256):
176             if chr(i) in d.keys():
177                 k += 1<<i
178                 h, v = d[chr(i)]
179                 if h is None: h = chr(0)*32
180                 vv = int_to_hex(v, 8).decode('hex')
181                 item = h + vv
182                 assert len(item) == 40
183                 serialized += item
184
185         k = "0x%0.64X" % k # 32 bytes
186         k = k[2:].decode('hex')
187         assert len(k) == 32
188         out = k + serialized
189         if batch:
190             batch.put(key, out)
191         else:
192             self.db_utxo.put(key, out) 
193
194
195     def get_node(self, key):
196
197         s = self.db_utxo.get(key)
198         if s is None: 
199             return 
200
201         #print "get node", key.encode('hex'), len(key), s.encode('hex')
202
203         k = int(s[0:32].encode('hex'), 16)
204         s = s[32:]
205         d = {}
206         for i in range(256):
207             if k % 2 == 1: 
208                 _hash = s[0:32]
209                 value = hex_to_int(s[32:40])
210                 d[chr(i)] = (_hash, value)
211                 s = s[40:]
212             k = k/2
213
214         #cache
215         return d
216
217
218     def add_address(self, target, value, height):
219         assert len(target) == KEYLENGTH
220
221         word = target
222         key = ''
223         path = [ '' ]
224         i = self.db_utxo.iterator()
225
226         while key != target:
227
228             items = self.get_node(key)
229
230             if word[0] in items.keys():
231   
232                 i.seek(key + word[0])
233                 new_key, _ = i.next()
234
235                 if target.startswith(new_key):
236                     # add value to the child node
237                     key = new_key
238                     word = target[len(key):]
239                     if key == target:
240                         break
241                     else:
242                         assert key not in path
243                         path.append(key)
244                 else:
245                     # prune current node and add new node
246                     prefix = self.common_prefix(new_key, target)
247                     index = len(prefix)
248
249                     ## get hash and value of new_key from parent (if it's a leaf)
250                     if len(new_key) == KEYLENGTH:
251                         parent_key = self.get_parent(new_key)
252                         parent = self.get_node(parent_key)
253                         z = parent[ new_key[len(parent_key)] ]
254                         self.put_node(prefix, { target[index]:(None,0), new_key[index]:z } )
255                     else:
256                         # if it is not a leaf, update the hash of new_key because skip_string changed
257                         h, v = self.get_node_hash(new_key, self.get_node(new_key), prefix)
258                         self.put_node(prefix, { target[index]:(None,0), new_key[index]:(h,v) } )
259
260                     path.append(prefix)
261                     self.parents[new_key] = prefix
262                     break
263
264             else:
265                 assert key in path
266                 items[ word[0] ] = (None,0)
267                 self.put_node(key,items)
268                 break
269
270         # write 
271         s = (int_to_hex(value, 8) + int_to_hex(height,4)).decode('hex')
272         self.db_utxo.put(target, s)
273         # the hash of a node is the txid
274         _hash = target[20:52]
275         self.update_node_hash(target, path, _hash, value)
276
277
278     def update_node_hash(self, node, path, _hash, value):
279         c = node
280         for x in path[::-1]:
281             self.parents[c] = x
282             c = x
283
284         self.hash_list[node] = (_hash, value)
285
286
287     def update_hashes(self):
288
289         nodes = {} # nodes to write
290
291         for i in range(KEYLENGTH, -1, -1):
292
293             for node in self.hash_list.keys():
294                 if len(node) != i: continue
295
296                 node_hash, node_value = self.hash_list.pop(node)
297
298                 # for each node, compute its hash, send it to the parent
299                 if node == '':
300                     self.root_hash = node_hash
301                     self.root_value = node_value
302                     break
303
304                 parent = self.parents[node]
305
306                 # read parent.. do this in add_address
307                 d = nodes.get(parent)
308                 if d is None:
309                     d = self.get_node(parent)
310                     assert d is not None
311
312                 letter = node[len(parent)]
313                 assert letter in d.keys()
314
315                 if i != KEYLENGTH and node_hash is None:
316                     d2 = self.get_node(node)
317                     node_hash, node_value = self.get_node_hash(node, d2, parent)
318
319                 assert node_hash is not None
320                 # write new value
321                 d[letter] = (node_hash, node_value)
322                 nodes[parent] = d
323
324                 # iterate
325                 grandparent = self.parents[parent] if parent != '' else None
326                 parent_hash, parent_value = self.get_node_hash(parent, d, grandparent)
327                 self.hash_list[parent] = (parent_hash, parent_value)
328
329         
330         # batch write modified nodes 
331         batch = self.db_utxo.write_batch()
332         for k, v in nodes.items():
333             self.put_node(k, v, batch)
334         batch.write()
335
336         # cleanup
337         assert self.hash_list == {}
338         self.parents = {}
339
340
341     def get_node_hash(self, x, d, parent):
342
343         # final hash
344         if x != '':
345             skip_string = x[len(parent)+1:]
346         else:
347             skip_string = ''
348
349         d2 = sorted(d.items())
350         values = map(lambda x: x[1][1], d2)
351         hashes = map(lambda x: x[1][0], d2)
352         value = sum( values )
353         _hash = self.hash( skip_string + ''.join(hashes) )
354         return _hash, value
355
356
357     def get_path(self, target):
358         word = target
359         key = ''
360         path = [ '' ]
361         i = self.db_utxo.iterator(start='')
362
363         while key != target:
364
365             i.seek(key + word[0])
366             try:
367                 new_key, _ = i.next()
368                 is_child = new_key.startswith(key + word[0])
369             except StopIteration:
370                 is_child = False
371
372             if is_child:
373   
374                 if target.startswith(new_key):
375                     # add value to the child node
376                     key = new_key
377                     word = target[len(key):]
378                     if key == target:
379                         break
380                     else:
381                         assert key not in path
382                         path.append(key)
383                 else:
384                     print_log('not in tree', self.db_utxo.get(key+word[0]), new_key.encode('hex'))
385                     return False
386             else:
387                 assert key in path
388                 break
389
390         return path
391
392
393     def delete_address(self, leaf):
394         path = self.get_path(leaf)
395         if path is False:
396             print_log("addr not in tree", leaf.encode('hex'), self.key_to_address(leaf[0:20]), self.db_utxo.get(leaf))
397             raise
398
399         s = self.db_utxo.get(leaf)
400         
401         self.db_utxo.delete(leaf)
402         if leaf in self.hash_list:
403             self.hash_list.pop(leaf)
404
405         parent = path[-1]
406         letter = leaf[len(parent)]
407         items = self.get_node(parent)
408         items.pop(letter)
409
410         # remove key if it has a single child
411         if len(items) == 1:
412             letter, v = items.items()[0]
413
414             self.db_utxo.delete(parent)
415             if parent in self.hash_list: 
416                 self.hash_list.pop(parent)
417
418             # we need the exact length for the iteration
419             i = self.db_utxo.iterator()
420             i.seek(parent+letter)
421             k, v = i.next()
422
423             # note: k is not necessarily a leaf
424             if len(k) == KEYLENGTH:
425                  _hash, value = k[20:52], hex_to_int(v[0:8])
426             else:
427                 _hash, value = None, None
428
429             self.update_node_hash(k, path[:-1], _hash, value)
430
431         else:
432             self.put_node(parent, items)
433             _hash, value = None, None
434             self.update_node_hash(parent, path[:-1], _hash, value)
435
436         return s
437
438
439     def get_children(self, x):
440         i = self.db_utxo.iterator()
441         l = 0
442         while l <256:
443             i.seek(x+chr(l))
444             k, v = i.next()
445             if k.startswith(x+chr(l)): 
446                 yield k, v
447                 l += 1
448             elif k.startswith(x): 
449                 yield k, v
450                 l = ord(k[len(x)]) + 1
451             else: 
452                 break
453
454
455
456
457     def get_parent(self, x):
458         """ return parent and skip string"""
459         i = self.db_utxo.iterator()
460         for j in range(len(x)):
461             p = x[0:-j-1]
462             i.seek(p)
463             k, v = i.next()
464             if x.startswith(k) and x!=k: 
465                 break
466         else: raise
467         return k
468
469         
470     def hash(self, x):
471         if DEBUG: return "hash("+x+")"
472         return Hash(x)
473
474
475     def get_root_hash(self):
476         return self.root_hash
477
478
479     def close(self):
480         self.db_utxo.close()
481         self.db_addr.close()
482         self.db_hist.close()
483         self.db_undo.close()
484
485
486     def add_to_history(self, addr, tx_hash, tx_pos, value, tx_height):
487         key = self.address_to_key(addr)
488         txo = (tx_hash + int_to_hex(tx_pos, 4)).decode('hex')
489
490         # write the new history
491         self.add_address(key + txo, value, tx_height)
492
493         # backlink
494         self.db_addr.put(txo, addr)
495
496
497
498     def revert_add_to_history(self, addr, tx_hash, tx_pos, value, tx_height):
499         key = self.address_to_key(addr)
500         txo = (tx_hash + int_to_hex(tx_pos, 4)).decode('hex')
501
502         # delete
503         self.delete_address(key + txo)
504
505         # backlink
506         self.db_addr.delete(txo)
507
508
509     def get_utxo_value(self, addr, txi):
510         key = self.address_to_key(addr)
511         leaf = key + txi
512         s = self.db_utxo.get(leaf)
513         value = hex_to_int(s[0:8])
514         return value
515
516
517     def set_spent(self, addr, txi, txid, index, height, undo):
518         key = self.address_to_key(addr)
519         leaf = key + txi
520
521         s = self.delete_address(leaf)
522         value = hex_to_int(s[0:8])
523         in_height = hex_to_int(s[8:12])
524         undo[leaf] = value, in_height
525
526         # delete backlink txi-> addr
527         self.db_addr.delete(txi)
528
529         # add to history
530         s = self.db_hist.get(addr)
531         if s is None: s = ''
532         txo = (txid + int_to_hex(index,4) + int_to_hex(height,4)).decode('hex')
533         s += txi + int_to_hex(in_height,4).decode('hex') + txo
534         s = s[ -80*self.pruning_limit:]
535         self.db_hist.put(addr, s)
536
537
538
539     def revert_set_spent(self, addr, txi, undo):
540         key = self.address_to_key(addr)
541         leaf = key + txi
542
543         # restore backlink
544         self.db_addr.put(txi, addr)
545
546         v, height = undo.pop(leaf)
547         self.add_address(leaf, v, height)
548
549         # revert add to history
550         s = self.db_hist.get(addr)
551         # s might be empty if pruning limit was reached
552         if not s:
553             return
554
555         assert s[-80:-44] == txi
556         s = s[:-80]
557         self.db_hist.put(addr, s)
558
559
560
561
562         
563
564     def import_transaction(self, txid, tx, block_height, touched_addr):
565
566         undo = { 'prev_addr':[] } # contains the list of pruned items for each address in the tx; also, 'prev_addr' is a list of prev addresses
567                 
568         prev_addr = []
569         for i, x in enumerate(tx.get('inputs')):
570             txi = (x.get('prevout_hash') + int_to_hex(x.get('prevout_n'), 4)).decode('hex')
571             addr = self.get_address(txi)
572             if addr is not None: 
573                 self.set_spent(addr, txi, txid, i, block_height, undo)
574                 touched_addr.add(addr)
575             prev_addr.append(addr)
576
577         undo['prev_addr'] = prev_addr 
578
579         # here I add only the outputs to history; maybe I want to add inputs too (that's in the other loop)
580         for x in tx.get('outputs'):
581             addr = x.get('address')
582             if addr is None: continue
583             self.add_to_history(addr, txid, x.get('index'), x.get('value'), block_height)
584             touched_addr.add(addr)
585
586         return undo
587
588
589     def revert_transaction(self, txid, tx, block_height, touched_addr, undo):
590         #print_log("revert tx", txid)
591         for x in reversed(tx.get('outputs')):
592             addr = x.get('address')
593             if addr is None: continue
594             self.revert_add_to_history(addr, txid, x.get('index'), x.get('value'), block_height)
595             touched_addr.add(addr)
596
597         prev_addr = undo.pop('prev_addr')
598         for i, x in reversed(list(enumerate(tx.get('inputs')))):
599             addr = prev_addr[i]
600             if addr is not None:
601                 txi = (x.get('prevout_hash') + int_to_hex(x.get('prevout_n'), 4)).decode('hex')
602                 self.revert_set_spent(addr, txi, undo)
603                 touched_addr.add(addr)
604
605         assert undo == {}
606