change estimated_fee to include 34 bytes per output instead of hard-coded 80 (issue...
[electrum-nvc.git] / lib / wallet.py
index 2d9eb69..3948a95 100644 (file)
@@ -29,6 +29,7 @@ import random
 import aes
 import Queue
 import time
+import math
 
 from util import print_msg, print_error, format_satoshis
 from bitcoin import *
@@ -72,6 +73,7 @@ class WalletStorage:
 
     def __init__(self, config):
         self.lock = threading.Lock()
+        self.config = config
         self.data = {}
         self.file_exists = False
         self.path = self.init_path(config)
@@ -150,7 +152,10 @@ class WalletStorage:
             os.chmod(self.path,stat.S_IREAD | stat.S_IWRITE)
 
 
-class Wallet:
+
+    
+
+class NewWallet:
 
     def __init__(self, storage):
 
@@ -159,7 +164,7 @@ class Wallet:
         self.gap_limit_for_change = 3 # constant
 
         # saved fields
-        self.seed_version          = storage.get('seed_version', SEED_VERSION)
+        self.seed_version          = storage.get('seed_version', NEW_SEED_VERSION)
 
         self.gap_limit             = storage.get('gap_limit', 5)
         self.use_change            = storage.get('use_change',True)
@@ -179,12 +184,6 @@ class Wallet:
 
         self.next_addresses = storage.get('next_addresses',{})
 
-        if self.seed_version not in [4, 6]:
-            msg = "This wallet seed is not supported."
-            if self.seed_version in [5]:
-                msg += "\nTo open this wallet, try 'git checkout seed_v%d'"%self.seed_version
-                print msg
-                sys.exit(1)
 
         # This attribute is set when wallet.start_threads is called.
         self.synchronizer = None
@@ -297,39 +296,19 @@ class Wallet:
 
 
     def init_seed(self, seed):
-        import mnemonic
+        import mnemonic, unicodedata
         
         if self.seed: 
             raise Exception("a seed exists")
 
+        self.seed_version = NEW_SEED_VERSION
+
         if not seed:
             self.seed = self.make_seed()
-            self.seed_version = SEED_VERSION
             return
 
-        # find out what kind of wallet we are
-        try:
-            seed.strip().decode('hex')
-            self.seed_version = 4
-            self.seed = str(seed)
-            return
-        except Exception:
-            pass
+        self.seed = unicodedata.normalize('NFC', unicode(seed.strip()))
 
-        words = seed.split()
-        try:
-            mnemonic.mn_decode(words)
-            uses_electrum_words = True
-        except Exception:
-            uses_electrum_words = False
-        
-        if uses_electrum_words and len(words) != 13:
-            self.seed_version = 4
-            self.seed = mnemonic.mn_decode(words)
-        else:
-            #assert is_seed(seed)
-            self.seed_version = SEED_VERSION
-            self.seed = seed
             
 
     def save_seed(self, password):
@@ -342,41 +321,26 @@ class Wallet:
         self.create_accounts(password)
 
 
-    def create_watching_only_wallet(self, params):
-        K0, c0 = params
-        if not K0:
-            return
-
-        if not c0:
-            self.seed_version = 4
-            self.storage.put('seed_version', self.seed_version, True)
-            self.create_old_account(K0)
-            return
-
-        cK0 = ""
+    def create_watching_only_wallet(self, K0, c0):
+        cK0 = "" #FIXME
         self.master_public_keys = {
             "m/0'/": (c0, K0, cK0),
             }
         self.storage.put('master_public_keys', self.master_public_keys, True)
         self.storage.put('seed_version', self.seed_version, True)
-        self.create_account('1','Main account')
+        self.create_account('1of1','Main account')
 
 
     def create_accounts(self, password):
         seed = pw_decode(self.seed, password)
-
-        if self.seed_version == 4:
-            mpk = OldAccount.mpk_from_seed(seed)
-            self.create_old_account(mpk)
-        else:
-            # create default account
-            self.create_master_keys('1', password)
-            self.create_account('1','Main account')
+        # create default account
+        self.create_master_keys('1of1', password)
+        self.create_account('1of1','Main account')
 
 
     def create_master_keys(self, account_type, password):
         master_k, master_c, master_K, master_cK = bip32_init(self.get_seed(password))
-        if account_type == '1':
+        if account_type == '1of1':
             k0, c0, K0, cK0 = bip32_private_derivation(master_k, master_c, "m/", "m/0'/")
             self.master_public_keys["m/0'/"] = (c0, K0, cK0)
             self.master_private_keys["m/0'/"] = pw_encode(k0, password)
@@ -402,7 +366,7 @@ class Wallet:
         self.storage.put('master_private_keys', self.master_private_keys, True)
 
     def has_master_public_keys(self, account_type):
-        if account_type == '1':
+        if account_type == '1of1':
             return "m/0'/" in self.master_public_keys
         elif account_type == '2of2':
             return set(["m/1'/", "m/2'/"]) <= set(self.master_public_keys.keys())
@@ -436,7 +400,7 @@ class Wallet:
 
 
     def account_id(self, account_type, i):
-        if account_type == '1':
+        if account_type == '1of1':
             return "m/0'/%d"%i
         elif account_type == '2of2':
             return "m/1'/%d & m/2'/%d"%(i,i)
@@ -456,7 +420,7 @@ class Wallet:
         return i
 
 
-    def new_account_address(self, account_type = '1'):
+    def new_account_address(self, account_type = '1of1'):
         i = self.num_accounts(account_type)
         k = self.account_id(account_type,i)
 
@@ -470,12 +434,12 @@ class Wallet:
         return k, addr
 
 
-    def next_account(self, account_type = '1'):
+    def next_account(self, account_type = '1of1'):
 
         i = self.num_accounts(account_type)
         account_id = self.account_id(account_type,i)
 
-        if account_type is '1':
+        if account_type is '1of1':
             master_c0, master_K0, _ = self.master_public_keys["m/0'/"]
             c0, K0, cK0 = bip32_public_derivation(master_c0.decode('hex'), master_K0.decode('hex'), "m/0'/", "m/0'/%d"%i)
             account = BIP32_Account({ 'c':c0, 'K':K0, 'cK':cK0 })
@@ -519,7 +483,7 @@ class Wallet:
 
 
 
-    def create_account(self, account_type = '1', name = None):
+    def create_account(self, account_type = '1of1', name = None):
         k, account = self.next_account(account_type)
         if k in self.pending_accounts:
             self.pending_accounts.pop(k)
@@ -531,12 +495,6 @@ class Wallet:
             self.set_label(k, name)
 
 
-    def create_old_account(self, mpk):
-        self.storage.put('master_public_key', mpk, True)
-        self.accounts[0] = OldAccount({'mpk':mpk, 0:[], 1:[]})
-        self.save_accounts()
-
-
     def save_accounts(self):
         d = {}
         for k, v in self.accounts.items():
@@ -601,11 +559,8 @@ class Wallet:
         return s[0] == 1
 
     def get_master_public_key(self):
-        if self.seed_version == 4:
-            return self.storage.get("master_public_key")
-        else:
-            c, K, cK = self.storage.get("master_public_keys")["m/0'/"]
-            return repr((c, K))
+        c, K, cK = self.storage.get("master_public_keys")["m/0'/"]
+        return repr((c, K))
 
     def get_master_private_key(self, account, password):
         k = self.master_private_keys.get(account)
@@ -638,6 +593,14 @@ class Wallet:
         raise Exception("Address not found", address)
 
 
+    def getpubkeys(self, addr):
+        assert is_valid(addr) and self.is_mine(addr)
+        account, sequence = self.get_address_index(addr)
+        if account != -1:
+            a = self.accounts[account]
+            return a.get_pubkeys( sequence )
+
+
     def get_roots(self, account):
         roots = []
         for a in account.split('&'):
@@ -682,25 +645,14 @@ class Wallet:
         return '&'.join(dd)
 
 
-
     def get_seed(self, password):
         s = pw_decode(self.seed, password)
-        if self.seed_version == 4:
-            seed = s
-            self.accounts[0].check_seed(seed)
-        else:
-            seed = mnemonic_hash(s)
+        seed = mnemonic_to_seed(s,'').encode('hex')
         return seed
-        
 
-    def get_mnemonic(self, password):
-        import mnemonic
-        s = pw_decode(self.seed, password)
-        if self.seed_version == 4:
-            return ' '.join(mnemonic.mn_encode(s))
-        else:
-            return s
 
+    def get_mnemonic(self, password):
+        return pw_decode(self.seed, password)
         
 
     def get_private_key(self, address, password):
@@ -753,23 +705,6 @@ class Wallet:
         for txin in tx.inputs:
             keyid = txin.get('KeyID')
             if keyid:
-
-                if self.seed_version == 4:
-                    m = re.match("old\(([0-9a-f]+),(\d+),(\d+)", keyid)
-                    if not m: continue
-                    mpk = m.group(1)
-                    if mpk != self.storage.get('master_public_key'): continue 
-                    for_change = int(m.group(2))
-                    num = int(m.group(3))
-                    account = self.accounts[0]
-                    addr = account.get_address(for_change, num)
-                    txin['address'] = addr # fixme: side effect
-                    pk = account.get_private_key(seed, (for_change, num))
-                    pubkey = public_key_from_private_key(pk)
-                    keypairs[pubkey] = pk
-                    continue
-
-
                 roots = []
                 for s in keyid.split('&'):
                     m = re.match("bip32\(([0-9a-f]+),([0-9a-f]+),(/\d+/\d+/\d+)", s)
@@ -844,6 +779,16 @@ class Wallet:
         return key.sign_message(message, compressed, address)
 
 
+
+    def decrypt_message(self, pubkey, message, password):
+        address = public_key_to_bc_address(pubkey.decode('hex'))
+        keys = self.get_private_key(address, password)
+        secret = keys[0]
+        ec = regenerate_key(secret)
+        decrypted = ec.decrypt_message(message)
+        return decrypted[0]
+
+
     def change_gap_limit(self, value):
         if value >= self.gap_limit:
             self.gap_limit = value
@@ -928,7 +873,7 @@ class Wallet:
 
 
     def create_pending_accounts(self):
-        for account_type in ['1','2of2','2of3']:
+        for account_type in ['1of1','2of2','2of3']:
             if not self.has_master_public_keys(account_type):
                 continue
             k, a = self.new_account_address(account_type)
@@ -1057,29 +1002,23 @@ class Wallet:
 
 
     def get_account_name(self, k):
-        if k == 0:
-            if self.seed_version == 4: 
-                name = 'Main account'
+        default = "Unnamed account"
+        m = re.match("m/0'/(\d+)", k)
+        if m:
+            num = m.group(1)
+            if num == '0':
+                default = "Main account"
             else:
-                name = 'Old account'
-        else:
-            default = "Unnamed account"
-            m = re.match("m/0'/(\d+)", k)
-            if m:
-                num = m.group(1)
-                if num == '0':
-                    default = "Main account"
-                else:
-                    default = "Account %s"%num
+                default = "Account %s"%num
                     
-            m = re.match("m/1'/(\d+) & m/2'/(\d+)", k)
-            if m:
-                num = m.group(1)
-                default = "2of2 account %s"%num
-            name = self.labels.get(k, default)
-
+        m = re.match("m/1'/(\d+) & m/2'/(\d+)", k)
+        if m:
+            num = m.group(1)
+            default = "2of2 account %s"%num
+        name = self.labels.get(k, default)
         return name
 
+
     def get_account_names(self):
         accounts = {}
         for k, account in self.accounts.items():
@@ -1088,6 +1027,7 @@ class Wallet:
             accounts[-1] = 'Imported keys'
         return accounts
 
+
     def get_account_addresses(self, a, include_change=True):
         if a is None:
             o = self.addresses(True)
@@ -1147,7 +1087,7 @@ class Wallet:
         return [x[1] for x in coins]
 
 
-    def choose_tx_inputs( self, amount, fixed_fee, domain = None ):
+    def choose_tx_inputs( self, amount, fixed_fee, num_outputs, domain = None ):
         """ todo: minimize tx size """
         total = 0
         fee = self.fee if fixed_fee is None else fixed_fee
@@ -1161,13 +1101,13 @@ class Wallet:
         inputs = []
 
         for item in coins:
-            if item.get('coinbase') and item.get('height') + COINBASE_MATURITY > self.network.blockchain.height:
+            if item.get('coinbase') and item.get('height') + COINBASE_MATURITY > self.network.blockchain.height():
                 continue
             addr = item.get('address')
             v = item.get('value')
             total += v
             inputs.append(item)
-            fee = self.estimated_fee(inputs) if fixed_fee is None else fixed_fee
+            fee = self.estimated_fee(inputs, num_outputs) if fixed_fee is None else fixed_fee
             if total >= amount + fee: break
         else:
             inputs = []
@@ -1180,10 +1120,9 @@ class Wallet:
             self.fee = fee
             self.storage.put('fee_per_kb', self.fee, True)
         
-    def estimated_fee(self, inputs):
-        estimated_size =  len(inputs) * 180 + 80     # this assumes non-compressed keys
-        fee = self.fee * int(round(estimated_size/1024.))
-        if fee == 0: fee = self.fee
+    def estimated_fee(self, inputs, num_outputs):
+        estimated_size =  len(inputs) * 180 + num_outputs * 34    # this assumes non-compressed keys
+        fee = self.fee * int(math.ceil(estimated_size/1000.))
         return fee
 
 
@@ -1341,9 +1280,9 @@ class Wallet:
 
     def make_unsigned_transaction(self, outputs, fee=None, change_addr=None, domain=None ):
         for address, x in outputs:
-            assert is_valid(address)
+            assert is_valid(address), "Address " + address + " is invalid!"
         amount = sum( map(lambda x:x[1], outputs) )
-        inputs, total, fee = self.choose_tx_inputs( amount, fee, domain )
+        inputs, total, fee = self.choose_tx_inputs( amount, fee, len(outputs), domain )
         if not inputs:
             raise ValueError("Not enough funds")
         self.add_input_info(inputs)
@@ -1763,3 +1702,176 @@ class WalletSynchronizer(threading.Thread):
                 # Updated gets called too many times from other places as well; if we use that signal we get the notification three times
                 self.network.trigger_callback("new_transaction") 
                 self.was_updated = False
+
+
+
+
+class OldWallet(NewWallet):
+
+    def init_seed(self, seed):
+        import mnemonic
+        
+        if self.seed: 
+            raise Exception("a seed exists")
+
+        if not seed:
+            seed = random_seed(128)
+
+        self.seed_version = OLD_SEED_VERSION
+
+        # see if seed was entered as hex
+        seed = seed.strip()
+        try:
+            assert seed
+            seed.decode('hex')
+            self.seed = str(seed)
+            return
+        except Exception:
+            pass
+
+        words = seed.split()
+        try:
+            mnemonic.mn_decode(words)
+        except Exception:
+            raise
+
+        self.seed = mnemonic.mn_decode(words)
+
+        if not self.seed:
+            raise Exception("Invalid seed")
+            
+
+
+    def get_master_public_key(self):
+        return self.storage.get("master_public_key")
+
+    def create_accounts(self, password):
+        seed = pw_decode(self.seed, password)
+        mpk = OldAccount.mpk_from_seed(seed)
+        self.create_account(mpk)
+
+    def create_account(self, mpk):
+        self.storage.put('master_public_key', mpk, True)
+        self.accounts[0] = OldAccount({'mpk':mpk, 0:[], 1:[]})
+        self.save_accounts()
+
+    def create_watching_only_wallet(self, K0):
+        self.seed_version = OLD_SEED_VERSION
+        self.storage.put('seed_version', self.seed_version, True)
+        self.create_account(K0)
+
+    def get_seed(self, password):
+        seed = pw_decode(self.seed, password)
+        self.accounts[0].check_seed(seed)
+        return seed
+
+    def get_mnemonic(self, password):
+        import mnemonic
+        s = pw_decode(self.seed, password)
+        return ' '.join(mnemonic.mn_encode(s))
+
+
+    def add_keypairs_from_KeyID(self, tx, keypairs, password):
+        # first check the provided password
+        seed = self.get_seed(password)
+        for txin in tx.inputs:
+            keyid = txin.get('KeyID')
+            if keyid:
+                m = re.match("old\(([0-9a-f]+),(\d+),(\d+)", keyid)
+                if not m: continue
+                mpk = m.group(1)
+                if mpk != self.storage.get('master_public_key'): continue 
+                for_change = int(m.group(2))
+                num = int(m.group(3))
+                account = self.accounts[0]
+                addr = account.get_address(for_change, num)
+                txin['address'] = addr # fixme: side effect
+                pk = account.get_private_key(seed, (for_change, num))
+                pubkey = public_key_from_private_key(pk)
+                keypairs[pubkey] = pk
+
+
+    def get_account_name(self, k):
+        assert k == 0
+        return 'Main account'
+
+
+
+
+# former WalletFactory
+class Wallet(object):
+
+    def __new__(self, storage):
+        config = storage.config
+        if config.get('bitkey', False):
+            # if user requested support for Bitkey device,
+            # import Bitkey driver
+            from wallet_bitkey import WalletBitkey
+            return WalletBitkey(config)
+
+        if not storage.file_exists:
+            seed_version = NEW_SEED_VERSION if config.get('bip32') is True else OLD_SEED_VERSION
+        else:
+            seed_version = storage.get('seed_version')
+
+        if seed_version == OLD_SEED_VERSION:
+            return OldWallet(storage)
+        elif seed_version == NEW_SEED_VERSION:
+            return NewWallet(storage)
+        else:
+            msg = "This wallet seed is not supported."
+            if seed_version in [5]:
+                msg += "\nTo open this wallet, try 'git checkout seed_v%d'"%seed_version
+            print msg
+            sys.exit(1)
+
+
+
+    @classmethod
+    def from_seed(self, seed, storage):
+        import mnemonic
+        if not seed:
+            return 
+
+        words = seed.strip().split()
+        try:
+            mnemonic.mn_decode(words)
+            uses_electrum_words = True
+        except Exception:
+            uses_electrum_words = False
+
+        try:
+            seed.decode('hex')
+            is_hex = True
+        except Exception:
+            is_hex = False
+         
+        if is_hex or (uses_electrum_words and len(words) != 13):
+            print "old style wallet", len(words), words
+            w = OldWallet(storage)
+            w.init_seed(seed) #hex
+        else:
+            #assert is_seed(seed)
+            w = NewWallet(storage)
+            w.init_seed(seed)
+
+        return w
+
+
+    @classmethod
+    def from_mpk(self, s, storage):
+        try:
+            mpk, chain = s.split(':')
+        except:
+            mpk = s
+            chain = False
+
+        if chain:
+            w = NewWallet(storage)
+            w.create_watching_only_wallet(mpk, chain)
+        else:
+            w = OldWallet(storage)
+            w.seed = ''
+            w.create_watching_only_wallet(mpk)
+
+        return w