move private key methods from wallet to accounts
[electrum-nvc.git] / lib / wallet.py
index 0e2f6b7..ecd0ea7 100644 (file)
@@ -45,30 +45,6 @@ DUST_THRESHOLD = 5430
 # internal ID for imported account
 IMPORTED_ACCOUNT = '/x'
 
-# AES encryption
-EncodeAES = lambda secret, s: base64.b64encode(aes.encryptData(secret,s))
-DecodeAES = lambda secret, e: aes.decryptData(secret, base64.b64decode(e))
-
-def pw_encode(s, password):
-    if password:
-        secret = Hash(password)
-        return EncodeAES(secret, s)
-    else:
-        return s
-
-def pw_decode(s, password):
-    if password is not None:
-        secret = Hash(password)
-        try:
-            d = DecodeAES(secret, s)
-        except Exception:
-            raise Exception('Invalid password')
-        return d
-    else:
-        return s
-
-
-
 
 
 from version import *
@@ -249,7 +225,26 @@ class Abstract_Wallet:
         self.accounts = {}
         self.imported_keys = self.storage.get('imported_keys',{})
         if self.imported_keys:
-            self.accounts['/x'] = ImportedAccount(self.imported_keys)
+            print_error("cannot load imported keys")
+
+        d = self.storage.get('accounts', {})
+        for k, v in d.items():
+            if k == 0:
+                v['mpk'] = self.storage.get('master_public_key')
+                self.accounts[k] = OldAccount(v)
+            elif v.get('imported'):
+                self.accounts[k] = ImportedAccount(v)
+            elif v.get('xpub3'):
+                self.accounts[k] = BIP32_Account_2of3(v)
+            elif v.get('xpub2'):
+                self.accounts[k] = BIP32_Account_2of2(v)
+            elif v.get('xpub'):
+                self.accounts[k] = BIP32_Account(v)
+            elif v.get('pending'):
+                self.accounts[k] = PendingAccount(v)
+            else:
+                print_error("cannot load account", v)
+
 
     def synchronize(self):
         pass
@@ -257,14 +252,9 @@ class Abstract_Wallet:
     def can_create_accounts(self):
         return False
 
-    def check_password(self, password):
-        raise
-
-
     def set_up_to_date(self,b):
         with self.lock: self.up_to_date = b
 
-
     def is_up_to_date(self):
         with self.lock: return self.up_to_date
 
@@ -274,21 +264,27 @@ class Abstract_Wallet:
         while not self.is_up_to_date(): 
             time.sleep(0.1)
 
+    def is_imported(self, addr):
+        account = self.accounts.get(IMPORTED_ACCOUNT)
+        if account: 
+            return addr in account.get_addresses(0)
+        else:
+            return False
 
     def import_key(self, sec, password):
-        self.check_password(password)
         try:
-            address = address_from_private_key(sec)
+            pubkey = public_key_from_private_key(sec)
+            address = public_key_to_bc_address(pubkey.decode('hex'))
         except Exception:
             raise Exception('Invalid private key')
 
         if self.is_mine(address):
             raise Exception('Address already in wallet')
         
-        # store the originally requested keypair into the imported keys table
-        self.imported_keys[address] = pw_encode(sec, password )
-        self.storage.put('imported_keys', self.imported_keys, True)
-        self.accounts[IMPORTED_ACCOUNT] = ImportedAccount(self.imported_keys)
+        if self.accounts.get(IMPORTED_ACCOUNT) is None:
+            self.accounts[IMPORTED_ACCOUNT] = ImportedAccount({'imported':{}})
+        self.accounts[IMPORTED_ACCOUNT].add(address, pubkey, sec, password)
+        self.save_accounts()
         
         if self.synchronizer:
             self.synchronizer.subscribe_to_addresses([address])
@@ -296,13 +292,11 @@ class Abstract_Wallet:
         
 
     def delete_imported_key(self, addr):
-        if addr in self.imported_keys:
-            self.imported_keys.pop(addr)
-            self.storage.put('imported_keys', self.imported_keys, True)
-            if self.imported_keys:
-                self.accounts[IMPORTED_ACCOUNT] = ImportedAccount(self.imported_keys)
-            else:
-                self.accounts.pop(IMPORTED_ACCOUNT)
+        account = self.accounts[IMPORTED_ACCOUNT]
+        account.remove(addr)
+        if not account.get_addresses(0):
+            self.accounts.pop(IMPORTED_ACCOUNT)
+        self.save_accounts()
 
 
     def set_label(self, name, text = None):
@@ -368,35 +362,15 @@ class Abstract_Wallet:
     def getpubkeys(self, addr):
         assert is_valid(addr) and self.is_mine(addr)
         account, sequence = self.get_address_index(addr)
-        if account != IMPORTED_ACCOUNT:
-            a = self.accounts[account]
-            return a.get_pubkeys( sequence )
-
+        a = self.accounts[account]
+        return a.get_pubkeys( sequence )
 
 
     def get_private_key(self, address, password):
         if self.is_watching_only():
             return []
-
-        out = []
-        if address in self.imported_keys.keys():
-            self.check_password(password)
-            out.append( pw_decode( self.imported_keys[address], password ) )
-        else:
-            seed = self.get_seed(password)
-            account_id, sequence = self.get_address_index(address)
-            account = self.accounts[account_id]
-            xpubs = account.get_master_pubkeys()
-            roots = [k for k, v in self.master_public_keys.iteritems() if v in xpubs]
-            for root in roots:
-                xpriv = self.get_master_private_key(root, password)
-                if not xpriv:
-                    continue
-                _, _, _, c, k = deserialize_xkey(xpriv)
-                pk = bip32_private_key( sequence, k, c )
-                out.append(pk)
-                    
-        return out
+        account_id, sequence = self.get_address_index(address)
+        return self.accounts[account_id].get_private_key(sequence, self, password)
 
 
     def get_public_keys(self, address):
@@ -414,9 +388,6 @@ class Abstract_Wallet:
                 pubkey = public_key_from_private_key(sec)
                 keypairs[ pubkey ] = sec
 
-                # this is needed because we don't store imported pubkeys
-                if address in self.imported_keys.keys():
-                    txin['redeemPubkey'] = pubkey
 
 
     def add_keypairs_from_KeyID(self, tx, keypairs, password):
@@ -891,8 +862,6 @@ class Abstract_Wallet:
 
     def add_input_info(self, txin):
         address = txin['address']
-        if address in self.imported_keys.keys():
-            return
         account_id, sequence = self.get_address_index(address)
         account = self.accounts[account_id]
         txin['KeyID'] = account.get_keyID(sequence)
@@ -941,12 +910,10 @@ class Abstract_Wallet:
             self.seed = pw_encode( decoded, new_password)
             self.storage.put('seed', self.seed, True)
 
-        for k in self.imported_keys.keys():
-            a = self.imported_keys[k]
-            b = pw_decode(a, old_password)
-            c = pw_encode(b, new_password)
-            self.imported_keys[k] = c
-        self.storage.put('imported_keys', self.imported_keys, True)
+        imported_account = self.accounts.get(IMPORTED_ACCOUNT)
+        if imported_account: 
+            imported_account.update_password(old_password, new_password)
+            self.save_accounts()
 
         for k, v in self.master_private_keys.items():
             b = pw_decode(v, old_password)
@@ -1097,15 +1064,27 @@ class Abstract_Wallet:
     def get_accounts(self):
         return self.accounts
 
+    def save_accounts(self):
+        d = {}
+        for k, v in self.accounts.items():
+            d[k] = v.dump()
+        self.storage.put('accounts', d, True)
+
+    
 
 class Imported_Wallet(Abstract_Wallet):
 
     def __init__(self, storage):
         Abstract_Wallet.__init__(self, storage)
+        a = self.accounts.get(IMPORTED_ACCOUNT)
+        if not a:
+            self.accounts[IMPORTED_ACCOUNT] = ImportedAccount({'imported':{}})
+
 
     def is_watching_only(self):
-        n = self.imported_keys.values()
-        return n == [''] * len(n)
+        acc = self.accounts[IMPORTED_ACCOUNT]
+        n = acc.keypairs.values()
+        return n == [(None, None)] * len(n)
 
     def has_seed(self):
         return False
@@ -1114,12 +1093,7 @@ class Imported_Wallet(Abstract_Wallet):
         return False
 
     def check_password(self, password):
-        if self.imported_keys:
-            k, v = self.imported_keys.items()[0]
-            sec = pw_decode(v, password)
-            address = address_from_private_key(sec)
-            assert address == k
-
+        self.accounts[IMPORTED_ACCOUNT].get_private_key((0,0), self, password)
 
 
 
@@ -1137,9 +1111,6 @@ class Deterministic_Wallet(Abstract_Wallet):
     def is_watching_only(self):
         return not self.has_seed()
 
-    def check_password(self, password):
-        self.get_seed(password)
-
     def add_seed(self, seed, password):
         if self.seed: 
             raise Exception("a seed exists")
@@ -1157,12 +1128,10 @@ class Deterministic_Wallet(Abstract_Wallet):
         self.create_master_keys(password)
 
     def get_seed(self, password):
-        s = pw_decode(self.seed, password)
-        seed = mnemonic_to_seed(s,'').encode('hex')
-        return seed
+        return pw_decode(self.seed, password)
 
     def get_mnemonic(self, password):
-        return pw_decode(self.seed, password)
+        return self.get_seed(password)
         
     def change_gap_limit(self, value):
         if value >= self.gap_limit:
@@ -1324,32 +1293,6 @@ class Deterministic_Wallet(Abstract_Wallet):
         self.save_accounts()
 
 
-    def save_accounts(self):
-        d = {}
-        for k, v in self.accounts.items():
-            d[k] = v.dump()
-        self.storage.put('accounts', d, True)
-
-    
-
-    def load_accounts(self):
-        Abstract_Wallet.load_accounts(self)
-        d = self.storage.get('accounts', {})
-        for k, v in d.items():
-            if k == 0:
-                v['mpk'] = self.storage.get('master_public_key')
-                self.accounts[k] = OldAccount(v)
-            elif v.get('xpub3'):
-                self.accounts[k] = BIP32_Account_2of3(v)
-            elif v.get('xpub2'):
-                self.accounts[k] = BIP32_Account_2of2(v)
-            elif v.get('xpub'):
-                self.accounts[k] = BIP32_Account(v)
-            elif v.get('pending'):
-                self.accounts[k] = PendingAccount(v)
-            else:
-                print_error("cannot load account", v)
-
 
     def account_is_pending(self, k):
         return type(self.accounts.get(k)) == PendingAccount
@@ -1393,6 +1336,10 @@ class NewWallet(Deterministic_Wallet):
         xpriv = pw_decode( k, password)
         return xpriv
 
+    def check_password(self, password):
+        xpriv = self.get_master_private_key( "m/", password )
+        xpub = self.master_public_keys["m/"]
+        assert deserialize_xkey(xpriv)[3] == deserialize_xkey(xpub)[3]
 
     def create_watching_only_wallet(self, xpub):
         self.storage.put('seed_version', self.seed_version, True)
@@ -1611,9 +1558,12 @@ class OldWallet(Deterministic_Wallet):
 
     def get_seed(self, password):
         seed = pw_decode(self.seed, password)
-        self.accounts[0].check_seed(seed)
         return seed
 
+    def check_password(self, password):
+        seed = pw_decode(self.seed, password)
+        self.accounts[0].check_seed(seed)
+
     def get_mnemonic(self, password):
         import mnemonic
         s = pw_decode(self.seed, password)
@@ -1641,21 +1591,6 @@ class OldWallet(Deterministic_Wallet):
 
 
 
-    def get_private_key(self, address, password):
-        if self.is_watching_only():
-            return []
-
-        out = []
-        if address in self.imported_keys.keys():
-            self.check_password(password)
-            out.append( pw_decode( self.imported_keys[address], password ) )
-        else:
-            seed = self.get_seed(password)
-            account_id, sequence = self.get_address_index(address)
-            pk = self.accounts[0].get_private_key(seed, sequence)
-            out.append(pk)
-        return out
-
     def check_pending_accounts(self):
         pass
 
@@ -1757,8 +1692,8 @@ class Wallet(object):
     def from_address(self, text, storage):
         w = Imported_Wallet(storage)
         for x in text.split():
-            w.imported_keys[x] = ''
-        w.storage.put('imported_keys', w.imported_keys, True)
+            w.accounts[IMPORTED_ACCOUNT].add(x, None, None, None)
+        w.save_accounts()
         return w
 
     @classmethod