use separata class for old wallets; decide with WalletFactory
authorThomasV <thomasv@gitorious>
Mon, 3 Feb 2014 05:26:03 +0000 (06:26 +0100)
committerThomasV <thomasv@gitorious>
Mon, 3 Feb 2014 05:26:03 +0000 (06:26 +0100)
electrum
gui/qt/installwizard.py
lib/wallet.py
lib/wallet_factory.py

index 4bc2c7c..8c100c1 100755 (executable)
--- a/electrum
+++ b/electrum
@@ -210,17 +210,6 @@ if __name__ == '__main__':
     # instanciate wallet for command-line
     storage = WalletStorage(config)
 
-    if cmd.requires_wallet:
-        wallet = Wallet(storage)
-    else:
-        wallet = None
-
-    if cmd.name not in ['create', 'restore'] and cmd.requires_wallet and not storage.file_exists:
-        print_msg("Error: Wallet file not found.")
-        print_msg("Type 'electrum create' to create a new wallet, or provide a path to a wallet with the -w option")
-        sys.exit(0)
-
-
 
     if cmd.name in ['create', 'restore']:
         if storage.file_exists:
@@ -236,35 +225,26 @@ if __name__ == '__main__':
         if not config.get('server'):
             config.set_key('server', pick_random_server())
 
-        fee = options.tx_fee if options.tx_fee else raw_input("fee (default:%s):" % (str(Decimal(wallet.fee)/100000000)))
-        gap = options.gap_limit if options.gap_limit else raw_input("gap limit (default 5):")
-
-        if fee:
-            wallet.set_fee(float(fee)*100000000)
-        if gap:
-            wallet.change_gap_limit(int(gap))
+        #fee = options.tx_fee if options.tx_fee else raw_input("fee (default:%s):" % (str(Decimal(wallet.fee)/100000000)))
+        #gap = options.gap_limit if options.gap_limit else raw_input("gap limit (default 5):")
+        #if fee:
+        #    wallet.set_fee(float(fee)*100000000)
+        #if gap:
+        #    wallet.change_gap_limit(int(gap))
 
         if cmd.name == 'restore':
             import getpass
             seed = getpass.getpass(prompt="seed:", stream=None) if options.concealed else raw_input("seed:")
-            try:
-                seed.decode('hex')
-            except Exception:
-                print_error("Warning: Not hex, trying decode.")
-                seed = mnemonic_decode(seed.split(' '))
-            if not seed:
-                sys.exit("Error: No seed")
-
-            wallet.init_seed(str(seed))
+            wallet = Wallet.from_seed(str(seed),storage)
+            if not wallet:
+                sys.exit("Error: Invalid seed")
             wallet.save_seed(password)
             if not options.offline:
                 network = Network(config)
                 network.start()
                 wallet.start_threads(network)
-
                 print_msg("Recovering wallet...")
                 wallet.restore(lambda x: x)
-
                 if wallet.is_found():
                     print_msg("Recovery successful")
                 else:
@@ -274,6 +254,7 @@ if __name__ == '__main__':
                 print_msg("Warning: This wallet was restored offline. It may contain more addresses than displayed.")
 
         else:
+            wallet = Wallet(storage)
             wallet.init_seed(None)
             wallet.save_seed(password)
             wallet.synchronize()
@@ -285,6 +266,19 @@ if __name__ == '__main__':
         # terminate
         sys.exit(0)
 
+
+    if cmd.name not in ['create', 'restore'] and cmd.requires_wallet and not storage.file_exists:
+        print_msg("Error: Wallet file not found.")
+        print_msg("Type 'electrum create' to create a new wallet, or provide a path to a wallet with the -w option")
+        sys.exit(0)
+
+
+    if cmd.requires_wallet:
+        wallet = Wallet(storage)
+    else:
+        wallet = None
+
+
     # important warning
     if cmd.name in ['dumpprivkey', 'dumpprivkeys']:
         print_msg("WARNING: ALL your private keys are secret.")
index 83adc46..0852763 100644 (file)
@@ -258,13 +258,13 @@ class InstallWizard(QDialog):
         if not action: 
             return
 
-        wallet = Wallet(self.storage)
-        gap = self.config.get('gap_limit', 5)
-        if gap != 5:
-            wallet.gap_limit = gap
-            wallet.storage.put('gap_limit', gap, True)
+        #gap = self.config.get('gap_limit', 5)
+        #if gap != 5:
+        #    wallet.gap_limit = gap
+        #    wallet.storage.put('gap_limit', gap, True)
 
         if action == 'create':
+            wallet = Wallet(self.storage)
             wallet.init_seed(None)
             if not self.show_seed(wallet):
                 return
@@ -276,23 +276,14 @@ class InstallWizard(QDialog):
                 wallet.synchronize()  # generate first addresses offline
             self.waiting_dialog(create)
 
-
         elif action == 'restore':
             seed = self.seed_dialog()
             if not seed:
                 return
-            try:
-                wallet.init_seed(seed)
-            except Exception:
-                import traceback
-                traceback.print_exc(file=sys.stdout)
-                QMessageBox.warning(None, _('Error'), _('Incorrect seed'), _('OK'))
-                return
-
+            wallet = Wallet.from_seed(seed, self.storage)
             ok, old_password, password = self.password_dialog(wallet)
             wallet.save_seed(password)
 
-
         elif action == 'watching':
             mpk = self.mpk_dialog()
             if not mpk:
index f17c08f..ed03b4f 100644 (file)
@@ -308,29 +308,6 @@ class Wallet:
         self.seed = unicodedata.normalize('NFC', unicode(seed.strip()))
         return
 
-        # find out what kind of wallet we are
-        try:
-            seed.strip().decode('hex')
-            self.seed_version = 4
-            self.seed = str(seed)
-            return
-        except Exception:
-            pass
-
-        words = seed.split()
-        try:
-            mnemonic.mn_decode(words)
-            uses_electrum_words = True
-        except Exception:
-            uses_electrum_words = False
-        
-        if uses_electrum_words and len(words) != 13:
-            self.seed_version = 4
-            self.seed = mnemonic.mn_decode(words)
-        else:
-            #assert is_seed(seed)
-            self.seed_version = SEED_VERSION
-            self.seed = seed
             
 
     def save_seed(self, password):
@@ -343,18 +320,8 @@ class Wallet:
         self.create_accounts(password)
 
 
-    def create_watching_only_wallet(self, params):
-        K0, c0 = params
-        if not K0:
-            return
-
-        if not c0:
-            self.seed_version = 4
-            self.storage.put('seed_version', self.seed_version, True)
-            self.create_old_account(K0)
-            return
-
-        cK0 = ""
+    def create_watching_only_wallet(self, K0, c0):
+        cK0 = "" #FIXME
         self.master_public_keys = {
             "m/0'/": (c0, K0, cK0),
             }
@@ -365,14 +332,9 @@ class Wallet:
 
     def create_accounts(self, password):
         seed = pw_decode(self.seed, password)
-
-        if self.seed_version == 4:
-            mpk = OldAccount.mpk_from_seed(seed)
-            self.create_old_account(mpk)
-        else:
-            # create default account
-            self.create_master_keys('1', password)
-            self.create_account('1','Main account')
+        # create default account
+        self.create_master_keys('1', password)
+        self.create_account('1','Main account')
 
 
     def create_master_keys(self, account_type, password):
@@ -532,12 +494,6 @@ class Wallet:
             self.set_label(k, name)
 
 
-    def create_old_account(self, mpk):
-        self.storage.put('master_public_key', mpk, True)
-        self.accounts[0] = OldAccount({'mpk':mpk, 0:[], 1:[]})
-        self.save_accounts()
-
-
     def save_accounts(self):
         d = {}
         for k, v in self.accounts.items():
@@ -602,11 +558,8 @@ class Wallet:
         return s[0] == 1
 
     def get_master_public_key(self):
-        if self.seed_version == 4:
-            return self.storage.get("master_public_key")
-        else:
-            c, K, cK = self.storage.get("master_public_keys")["m/0'/"]
-            return repr((c, K))
+        c, K, cK = self.storage.get("master_public_keys")["m/0'/"]
+        return repr((c, K))
 
     def get_master_private_key(self, account, password):
         k = self.master_private_keys.get(account)
@@ -685,21 +638,12 @@ class Wallet:
 
     def get_seed(self, password):
         s = pw_decode(self.seed, password)
-        if self.seed_version == 4:
-            seed = s
-            self.accounts[0].check_seed(seed)
-        else:
-            seed = mnemonic_to_seed(s,'').encode('hex')
+        seed = mnemonic_to_seed(s,'').encode('hex')
         return seed
-        
+
 
     def get_mnemonic(self, password):
-        import mnemonic
-        s = pw_decode(self.seed, password)
-        if self.seed_version == 4:
-            return ' '.join(mnemonic.mn_encode(s))
-        else:
-            return s
+        return pw_decode(self.seed, password)
         
 
     def get_private_key(self, address, password):
@@ -752,23 +696,6 @@ class Wallet:
         for txin in tx.inputs:
             keyid = txin.get('KeyID')
             if keyid:
-
-                if self.seed_version == 4:
-                    m = re.match("old\(([0-9a-f]+),(\d+),(\d+)", keyid)
-                    if not m: continue
-                    mpk = m.group(1)
-                    if mpk != self.storage.get('master_public_key'): continue 
-                    for_change = int(m.group(2))
-                    num = int(m.group(3))
-                    account = self.accounts[0]
-                    addr = account.get_address(for_change, num)
-                    txin['address'] = addr # fixme: side effect
-                    pk = account.get_private_key(seed, (for_change, num))
-                    pubkey = public_key_from_private_key(pk)
-                    keypairs[pubkey] = pk
-                    continue
-
-
                 roots = []
                 for s in keyid.split('&'):
                     m = re.match("bip32\(([0-9a-f]+),([0-9a-f]+),(/\d+/\d+/\d+)", s)
@@ -1056,29 +983,23 @@ class Wallet:
 
 
     def get_account_name(self, k):
-        if k == 0:
-            if self.seed_version == 4: 
-                name = 'Main account'
+        default = "Unnamed account"
+        m = re.match("m/0'/(\d+)", k)
+        if m:
+            num = m.group(1)
+            if num == '0':
+                default = "Main account"
             else:
-                name = 'Old account'
-        else:
-            default = "Unnamed account"
-            m = re.match("m/0'/(\d+)", k)
-            if m:
-                num = m.group(1)
-                if num == '0':
-                    default = "Main account"
-                else:
-                    default = "Account %s"%num
+                default = "Account %s"%num
                     
-            m = re.match("m/1'/(\d+) & m/2'/(\d+)", k)
-            if m:
-                num = m.group(1)
-                default = "2of2 account %s"%num
-            name = self.labels.get(k, default)
-
+        m = re.match("m/1'/(\d+) & m/2'/(\d+)", k)
+        if m:
+            num = m.group(1)
+            default = "2of2 account %s"%num
+        name = self.labels.get(k, default)
         return name
 
+
     def get_account_names(self):
         accounts = {}
         for k, account in self.accounts.items():
@@ -1087,6 +1008,7 @@ class Wallet:
             accounts[-1] = 'Imported keys'
         return accounts
 
+
     def get_account_addresses(self, a, include_change=True):
         if a is None:
             o = self.addresses(True)
@@ -1764,3 +1686,88 @@ class WalletSynchronizer(threading.Thread):
                 self.was_updated = False
 
 
+
+
+class OldWallet(Wallet):
+
+    def init_seed(self, seed):
+        import mnemonic
+        
+        if self.seed: 
+            raise Exception("a seed exists")
+
+        if not seed:
+            raise
+
+        self.seed_version = 4
+
+        # see if seed was entered as hex
+        try:
+            seed.strip().decode('hex')
+            self.seed = str(seed)
+            return
+        except Exception:
+            pass
+
+        words = seed.split()
+        try:
+            mnemonic.mn_decode(words)
+        except Exception:
+            raise
+
+        self.seed = mnemonic.mn_decode(words)
+            
+
+    def get_master_public_key(self):
+        return self.storage.get("master_public_key")
+
+    def create_accounts(self, password):
+        seed = pw_decode(self.seed, password)
+        mpk = OldAccount.mpk_from_seed(seed)
+        self.create_account(mpk)
+
+    def create_account(self, mpk):
+        self.storage.put('master_public_key', mpk, True)
+        self.accounts[0] = OldAccount({'mpk':mpk, 0:[], 1:[]})
+        self.save_accounts()
+
+    def create_watching_only_wallet(self, K0):
+        self.seed_version = 4
+        self.storage.put('seed_version', self.seed_version, True)
+        self.create_account(K0)
+
+    def get_seed(self, password):
+        seed = pw_decode(self.seed, password)
+        self.accounts[0].check_seed(seed)
+        return seed
+
+    def get_mnemonic(self, password):
+        import mnemonic
+        s = pw_decode(self.seed, password)
+        return ' '.join(mnemonic.mn_encode(s))
+
+
+    def add_keypairs_from_KeyID(self, tx, keypairs, password):
+        # first check the provided password
+        seed = self.get_seed(password)
+        for txin in tx.inputs:
+            keyid = txin.get('KeyID')
+            if keyid:
+                m = re.match("old\(([0-9a-f]+),(\d+),(\d+)", keyid)
+                if not m: continue
+                mpk = m.group(1)
+                if mpk != self.storage.get('master_public_key'): continue 
+                for_change = int(m.group(2))
+                num = int(m.group(3))
+                account = self.accounts[0]
+                addr = account.get_address(for_change, num)
+                txin['address'] = addr # fixme: side effect
+                pk = account.get_private_key(seed, (for_change, num))
+                pubkey = public_key_from_private_key(pk)
+                keypairs[pubkey] = pk
+
+
+    def get_account_name(self, k):
+        assert k == 0
+        return 'Main account'
+
index 5f82eca..71cf3b1 100644 (file)
@@ -1,11 +1,53 @@
+from version import SEED_VERSION
+from wallet import OldWallet, Wallet
+
 class WalletFactory(object):
-    def __new__(cls, config):
-        if config.get('bitkey', False):
+    def __new__(self, storage):
+
+        if storage.get('bitkey', False):
             # if user requested support for Bitkey device,
             # import Bitkey driver
             from wallet_bitkey import WalletBitkey
             return WalletBitkey(config)
         
-        # Load standard wallet
-        from wallet import Wallet
-        return Wallet(config)
+        seed_version = storage.get('seed_version', SEED_VERSION)
+        if seed_version not in [4, 6]:
+            msg = "This wallet seed is not supported."
+            if seed_version in [5]:
+                msg += "\nTo open this wallet, try 'git checkout seed_v%d'"%seed_version
+                print msg
+                sys.exit(1)
+
+
+        if seed_version == 4:
+            return OldWallet(storage)
+        else:
+            return Wallet(storage)
+
+
+    @classmethod
+    def from_seed(self, seed, storage):
+        import mnemonic
+        words = seed.strip().split()
+        try:
+            mnemonic.mn_decode(words)
+            uses_electrum_words = True
+        except Exception:
+            uses_electrum_words = False
+
+        try:
+            seed.decode('hex')
+            is_hex = True
+        except Exception:
+            is_hex = False
+         
+        if is_hex or (uses_electrum_words and len(words) != 13):
+            print "old style wallet", len(words), words
+            w = OldWallet(storage)
+            w.init_seed(seed) #hex
+        else:
+            #assert is_seed(seed)
+            w = Wallet(storage)
+            w.init_seed(seed)
+
+        return w