Refactored ConnectInputs, so valid-transaction-checks are done before ECDSA-verifying...
authorGavin Andresen <gavinandresen@gmail.com>
Wed, 11 Jan 2012 01:18:00 +0000 (20:18 -0500)
committerGavin Andresen <gavinandresen@gmail.com>
Fri, 13 Jan 2012 15:22:24 +0000 (10:22 -0500)
src/main.cpp
src/main.h
src/test/transaction_tests.cpp

index 6858336..3a43f5e 100644 (file)
@@ -275,26 +275,22 @@ bool CTransaction::IsStandard() const
 // expensive-to-check-upon-redemption script like:
 //   DUP CHECKSIG DROP ... repeated 100 times... OP_1
 //
-bool CTransaction::AreInputsStandard(const std::map<uint256, std::pair<CTxIndex, CTransaction> >& mapInputs) const
+bool CTransaction::AreInputsStandard(const MapPrevTx& mapInputs) const
 {
     if (fTestNet)
         return true; // Allow non-standard on testnet
 
+    if (IsCoinBase())
+        return true; // Coinbases are allowed to have any input
+
     for (int i = 0; i < vin.size(); i++)
     {
-        COutPoint prevout = vin[i].prevout;
-
-        std::map<uint256, std::pair<CTxIndex, CTransaction> >::const_iterator mi = mapInputs.find(prevout.hash);
-        if (mi == mapInputs.end())
-            return false;
-
-        const CTransaction& txPrev = (mi->second).second;
-        assert(prevout.n < txPrev.vout.size());
+        const CTxOut& prev = GetOutputFor(vin[i], mapInputs);
 
         vector<vector<unsigned char> > vSolutions;
         txnouttype whichType;
         // get the scriptPubKey corresponding to this input:
-        const CScript& prevScript = txPrev.vout[prevout.n].scriptPubKey;
+        const CScript& prevScript = prev.scriptPubKey;
         if (!Solver(prevScript, whichType, vSolutions))
             return false;
         if (whichType == TX_SCRIPTHASH)
@@ -494,7 +490,7 @@ bool CTransaction::AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs, bool* pfMi
 
     if (fCheckInputs)
     {
-        map<uint256, pair<CTxIndex, CTransaction> > mapInputs;
+        MapPrevTx mapInputs;
         map<uint256, CTxIndex> mapUnused;
         if (!FetchInputs(txdb, mapUnused, false, false, mapInputs))
         {
@@ -507,27 +503,20 @@ bool CTransaction::AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs, bool* pfMi
         if (!AreInputsStandard(mapInputs))
             return error("AcceptToMemoryPool() : nonstandard transaction input");
 
-        // Check against previous transactions
-        int64 nFees = 0;
-        int nSigOps = 0;
-        if (!ConnectInputs(mapInputs, mapUnused, CDiskTxPos(1,1,1), pindexBest, nFees, false, false, nSigOps))
-        {
-            if (pfMissingInputs)
-                *pfMissingInputs = true;
-            return error("AcceptToMemoryPool() : ConnectInputs failed %s", hash.ToString().substr(0,10).c_str());
-        }
+        int64 nFees = GetValueIn(mapInputs)-GetValueOut();
+        int nSigOps = GetSigOpCount(mapInputs);
+        unsigned int nSize = ::GetSerializeSize(*this, SER_NETWORK);
+
+        // Don't accept it if it can't get into a block
+        if (nFees < GetMinFee(1000, true, GMF_RELAY))
+            return error("AcceptToMemoryPool() : not enough fees");
 
         // Checking ECDSA signatures is a CPU bottleneck, so to avoid denial-of-service
         // attacks disallow transactions with more than one SigOp per 65 bytes.
         // 65 bytes because that is the minimum size of an ECDSA signature
-        unsigned int nSize = ::GetSerializeSize(*this, SER_NETWORK);
         if (nSigOps > nSize / 65 || nSize < 100)
             return error("AcceptToMemoryPool() : transaction with out-of-bounds SigOpCount");
 
-        // Don't accept it if it can't get into a block
-        if (nFees < GetMinFee(1000, true, GMF_RELAY))
-            return error("AcceptToMemoryPool() : not enough fees");
-
         // Continuously rate-limit free transactions
         // This mitigates 'penny-flooding' -- sending thousands of free transactions just to
         // be annoying or make other's transactions take longer to confirm.
@@ -552,6 +541,15 @@ bool CTransaction::AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs, bool* pfMi
                 dFreeCount += nSize;
             }
         }
+
+        // Check against previous transactions
+        // This is done last to help prevent CPU exhaustion denial-of-service attacks.
+        if (!ConnectInputs(mapInputs, mapUnused, CDiskTxPos(1,1,1), pindexBest, false, false))
+        {
+            if (pfMissingInputs)
+                *pfMissingInputs = true;
+            return error("AcceptToMemoryPool() : ConnectInputs failed %s", hash.ToString().substr(0,10).c_str());
+        }
     }
 
     // Store transaction in memory
@@ -925,7 +923,7 @@ bool CTransaction::DisconnectInputs(CTxDB& txdb)
 
 
 bool CTransaction::FetchInputs(CTxDB& txdb, const map<uint256, CTxIndex>& mapTestPool,
-                               bool fBlock, bool fMiner, map<uint256, pair<CTxIndex, CTransaction> >& inputsRet)
+                               bool fBlock, bool fMiner, MapPrevTx& inputsRet)
 {
     if (IsCoinBase())
         return true; // Coinbase transactions have no inputs to fetch.
@@ -978,6 +976,7 @@ bool CTransaction::FetchInputs(CTxDB& txdb, const map<uint256, CTxIndex>& mapTes
     for (int i = 0; i < vin.size(); i++)
     {
         const COutPoint prevout = vin[i].prevout;
+        assert(inputsRet.count(prevout.hash) != 0);
         const CTxIndex& txindex = inputsRet[prevout.hash].first;
         const CTransaction& txPrev = inputsRet[prevout.hash].second;
         if (prevout.n >= txPrev.vout.size() || prevout.n >= txindex.vSpent.size())
@@ -987,9 +986,49 @@ bool CTransaction::FetchInputs(CTxDB& txdb, const map<uint256, CTxIndex>& mapTes
     return true;
 }
 
-bool CTransaction::ConnectInputs(map<uint256, pair<CTxIndex, CTransaction> > inputs,
+const CTxOut& CTransaction::GetOutputFor(const CTxIn& input, const MapPrevTx& inputs) const
+{
+    MapPrevTx::const_iterator mi = inputs.find(input.prevout.hash);
+    if (mi == inputs.end())
+        throw std::runtime_error("CTransaction::GetOutputFor() : prevout.hash not found");
+
+    const CTransaction& txPrev = (mi->second).second;
+    if (input.prevout.n >= txPrev.vout.size())
+        throw std::runtime_error("CTransaction::GetOutputFor() : prevout.n out of range");
+
+    return txPrev.vout[input.prevout.n];
+}
+
+int64 CTransaction::GetValueIn(const MapPrevTx& inputs) const
+{
+    if (IsCoinBase())
+        return 0;
+
+    int64 nResult = 0;
+    for (int i = 0; i < vin.size(); i++)
+    {
+        nResult += GetOutputFor(vin[i], inputs).nValue;
+    }
+    return nResult;
+
+}
+
+int CTransaction::GetSigOpCount(const MapPrevTx& inputs) const
+{
+    if (IsCoinBase())
+        return 0;
+
+    int nSigOps = 0;
+    for (int i = 0; i < vin.size(); i++)
+    {
+        nSigOps += GetOutputFor(vin[i], inputs).scriptPubKey.GetSigOpCount(vin[i].scriptSig);
+    }
+    return nSigOps;
+}
+
+bool CTransaction::ConnectInputs(MapPrevTx inputs,
                                  map<uint256, CTxIndex>& mapTestPool, const CDiskTxPos& posThisTx,
-                                 const CBlockIndex* pindexBlock, int64& nFees, bool fBlock, bool fMiner, int& nSigOpsRet, int64 nMinFee)
+                                 const CBlockIndex* pindexBlock, bool fBlock, bool fMiner)
 {
     // Take over previous transactions' spent pointers
     // fBlock is true when this is called from AcceptBlock when a new best-block is added to the blockchain
@@ -998,6 +1037,7 @@ bool CTransaction::ConnectInputs(map<uint256, pair<CTxIndex, CTransaction> > inp
     if (!IsCoinBase())
     {
         int64 nValueIn = 0;
+        int64 nFees = 0;
         for (int i = 0; i < vin.size(); i++)
         {
             COutPoint prevout = vin[i].prevout;
@@ -1014,6 +1054,17 @@ bool CTransaction::ConnectInputs(map<uint256, pair<CTxIndex, CTransaction> > inp
                     if (pindex->nBlockPos == txindex.pos.nBlockPos && pindex->nFile == txindex.pos.nFile)
                         return error("ConnectInputs() : tried to spend coinbase at depth %d", pindexBlock->nHeight - pindex->nHeight);
 
+            // Check for conflicts (double-spend)
+            // This doesn't trigger the DoS code on purpose; if it did, it would make it easier
+            // for an attacker to attempt to split the network.
+            if (!txindex.vSpent[prevout.n].IsNull())
+                return fMiner ? false : error("ConnectInputs() : %s prev tx already used at %s", GetHash().ToString().substr(0,10).c_str(), txindex.vSpent[prevout.n].ToString().c_str());
+
+            // Check for negative or overflow input values
+            nValueIn += txPrev.vout[prevout.n].nValue;
+            if (!MoneyRange(txPrev.vout[prevout.n].nValue) || !MoneyRange(nValueIn))
+                return DoS(100, error("ConnectInputs() : txin values out of range"));
+
             bool fStrictPayToScriptHash = true;
             if (fBlock)
             {
@@ -1038,20 +1089,6 @@ bool CTransaction::ConnectInputs(map<uint256, pair<CTxIndex, CTransaction> > inp
                     return DoS(100,error("ConnectInputs() : %s VerifySignature failed", GetHash().ToString().substr(0,10).c_str()));
             }
 
-            // Check for conflicts (double-spend)
-            // This doesn't trigger the DoS code on purpose; if it did, it would make it easier
-            // for an attacker to attempt to split the network.
-            if (!txindex.vSpent[prevout.n].IsNull())
-                return fMiner ? false : error("ConnectInputs() : %s prev tx already used at %s", GetHash().ToString().substr(0,10).c_str(), txindex.vSpent[prevout.n].ToString().c_str());
-
-            // Check for negative or overflow input values
-            nValueIn += txPrev.vout[prevout.n].nValue;
-            if (!MoneyRange(txPrev.vout[prevout.n].nValue) || !MoneyRange(nValueIn))
-                return DoS(100, error("ConnectInputs() : txin values out of range"));
-
-            // Calculate sigOps accurately:
-            nSigOpsRet += txPrev.vout[prevout.n].scriptPubKey.GetSigOpCount(vin[i].scriptSig);
-
             // Mark outpoints as spent
             txindex.vSpent[prevout.n] = posThisTx;
 
@@ -1069,8 +1106,6 @@ bool CTransaction::ConnectInputs(map<uint256, pair<CTxIndex, CTransaction> > inp
         int64 nTxFee = nValueIn - GetValueOut();
         if (nTxFee < 0)
             return DoS(100, error("ConnectInputs() : %s nTxFee < 0", GetHash().ToString().substr(0,10).c_str()));
-        if (nTxFee < nMinFee)
-            return false;
         nFees += nTxFee;
         if (!MoneyRange(nFees))
             return DoS(100, error("ConnectInputs() : nFees out of range"));
@@ -1176,20 +1211,27 @@ bool CBlock::ConnectBlock(CTxDB& txdb, CBlockIndex* pindex)
         CDiskTxPos posThisTx(pindex->nFile, pindex->nBlockPos, nTxPos);
         nTxPos += ::GetSerializeSize(tx, SER_DISK);
 
-        map<uint256, pair<CTxIndex, CTransaction> > mapInputs;
-        if (!tx.FetchInputs(txdb, mapQueuedChanges, true, false, mapInputs))
-            return false;
+        MapPrevTx mapInputs;
+        if (!tx.IsCoinBase())
+        {
+            if (!tx.FetchInputs(txdb, mapQueuedChanges, true, false, mapInputs))
+                return false;
 
-        int nTxOps = 0;
-        if (!tx.ConnectInputs(mapInputs, mapQueuedChanges, posThisTx, pindex, nFees, true, false, nTxOps))
-            return false;
+            int nTxOps = tx.GetSigOpCount(mapInputs);
+            nSigOps += nTxOps;
+            if (nSigOps > MAX_BLOCK_SIGOPS)
+                return DoS(100, error("ConnectBlock() : too many sigops"));
+            // There is a different MAX_BLOCK_SIGOPS check in AcceptBlock();
+            // a block must satisfy both to make it into the best-chain
+            // (AcceptBlock() is always called before ConnectBlock())
 
-        nSigOps += nTxOps;
-        if (nSigOps > MAX_BLOCK_SIGOPS)
-            return DoS(100, error("ConnectBlock() : too many sigops"));
-        // There is a different MAX_BLOCK_SIGOPS check in AcceptBlock();
-        // a block must satisfy both to make it into the best-chain
-        // (AcceptBlock() is always called before ConnectBlock())
+            nFees += tx.GetValueIn(mapInputs)-tx.GetValueOut();
+        }
+
+        // It seems wrong that ConnectInputs must be called on the coinbase transaction
+        // (which has no inputs) : TODO: refactor the code at the end of ConnectInputs out...
+        if (!tx.ConnectInputs(mapInputs, mapQueuedChanges, posThisTx, pindex, true, false))
+            return false;
     }
 
     // Write queued txindex changes
@@ -3031,15 +3073,20 @@ CBlock* CreateNewBlock(CReserveKey& reservekey)
             // Connecting shouldn't fail due to dependency on other memory pool transactions
             // because we're already processing them in order of dependency
             map<uint256, CTxIndex> mapTestPoolTmp(mapTestPool);
-            map<uint256, pair<CTxIndex, CTransaction> > mapInputs;
+            MapPrevTx mapInputs;
             if (!tx.FetchInputs(txdb, mapTestPoolTmp, false, true, mapInputs))
                 continue;
 
-            int nTxSigOps2 = 0;
-            if (!tx.ConnectInputs(mapInputs, mapTestPoolTmp, CDiskTxPos(1,1,1), pindexPrev, nFees, false, true, nTxSigOps2, nMinFee))
+            int64 nFees = tx.GetValueIn(mapInputs)-tx.GetValueOut();
+            if (nFees < nMinFee)
                 continue;
+
+            int nTxSigOps2 = tx.GetSigOpCount(mapInputs);
             if (nBlockSigOps2 + nTxSigOps2 >= MAX_BLOCK_SIGOPS)
                 continue;
+
+            if (!tx.ConnectInputs(mapInputs, mapTestPoolTmp, CDiskTxPos(1,1,1), pindexPrev, false, true))
+                continue;
             swap(mapTestPool, mapTestPoolTmp);
 
             // Added
index 6f71fd2..678c14c 100644 (file)
@@ -402,6 +402,8 @@ enum GetMinFee_mode
     GMF_SEND,
 };
 
+typedef std::map<uint256, std::pair<CTxIndex, CTransaction> > MapPrevTx;
+
 //
 // The basic transaction that is broadcasted on the network and contained in
 // blocks.  A transaction can contain multiple inputs and outputs.
@@ -502,11 +504,36 @@ public:
         return (vin.size() == 1 && vin[0].prevout.IsNull());
     }
 
+    /** Check for standard transaction types
+        @return True if all outputs (scriptPubKeys) use only standard transaction forms
+    */
     bool IsStandard() const;
-    bool AreInputsStandard(const std::map<uint256, std::pair<CTxIndex, CTransaction> >& mapInputs) const;
 
+    /** Check for standard transaction types
+        @param[in] mapInputs   Map of previous transactions that have outputs we're spending
+        @return True if all inputs (scriptSigs) use only standard transaction forms
+        @see CTransaction::FetchInputs
+    */
+    bool AreInputsStandard(const MapPrevTx& mapInputs) const;
+
+    /** Count ECDSA signature operations the old-fashioned (pre-0.6) way
+        @return number of sigops this transaction's outputs will produce when spent
+        @see CTransaction::FetchInputs
+    */
     int GetLegacySigOpCount() const;
 
+    /** Count ECDSA signature operations the new (0.6-and-later) way
+        This is a better measure of how expensive it is to process this transaction.
+
+        @param[in] mapInputs   Map of previous transactions that have outputs we're spending
+        @return maximum number of sigops required to validate this transaction's inputs
+        @see CTransaction::FetchInputs
+     */
+    int GetSigOpCount(const MapPrevTx& mapInputs) const;
+
+    /** Amount of bitcoins spent by this transaction.
+        @return sum of all outputs (note: does not include fees)
+     */
     int64 GetValueOut() const
     {
         int64 nValueOut = 0;
@@ -519,6 +546,16 @@ public:
         return nValueOut;
     }
 
+    /** Amount of bitcoins coming in to this transaction
+        Note that lightweight clients may not know anything besides the hash of previous transactions,
+        so may not be able to calculate this.
+
+        @param[in] mapInputs   Map of previous transactions that have outputs we're spending
+        @return        Sum of value of all inputs (scriptSigs)
+        @see CTransaction::FetchInputs
+     */
+    int64 GetValueIn(const MapPrevTx& mapInputs) const;
+
     static bool AllowFree(double dPriority)
     {
         // Large (in bytes) low-priority (new, small-coin) transactions
@@ -634,17 +671,39 @@ public:
     bool ReadFromDisk(COutPoint prevout);
     bool DisconnectInputs(CTxDB& txdb);
 
-    // Fetch from memory and/or disk. inputsRet keys are transaction hashes.
+    /** Fetch from memory and/or disk. inputsRet keys are transaction hashes.
+
+     @param[in] txdb   Transaction database
+     @param[in] mapTestPool    List of pending changes to the transaction index database
+     @param[in] fBlock True if being called to add a new best-block to the chain
+     @param[in] fMiner True if being called by CreateNewBlock
+     @param[out] inputsRet     Pointers to this transaction's inputs
+     @return   Returns true if all inputs are in txdb or mapTestPool
+     */
     bool FetchInputs(CTxDB& txdb, const std::map<uint256, CTxIndex>& mapTestPool,
-                     bool fBlock, bool fMiner, std::map<uint256, std::pair<CTxIndex, CTransaction> >& inputsRet);
-    bool ConnectInputs(std::map<uint256, std::pair<CTxIndex, CTransaction> > inputs,
+                     bool fBlock, bool fMiner, MapPrevTx& inputsRet);
+
+    /** Sanity check previous transactions, then, if all checks succeed,
+        mark them as spent by this transaction.
+
+        @param[in] inputs      Previous transactions (from FetchInputs)
+        @param[out] mapTestPool        Keeps track of inputs that need to be updated on disk
+        @param[in] posThisTx   Position of this transaction on disk
+        @param[in] pindexBlock
+        @param[in] fBlock      true if called from ConnectBlock
+        @param[in] fMiner      true if called from CreateNewBlock
+        @return Returns true if all checks succeed
+     */
+    bool ConnectInputs(MapPrevTx inputs,
                        std::map<uint256, CTxIndex>& mapTestPool, const CDiskTxPos& posThisTx,
-                       const CBlockIndex* pindexBlock, int64& nFees, bool fBlock, bool fMiner, int& nSigOpsRet, int64 nMinFee=0);
+                       const CBlockIndex* pindexBlock, bool fBlock, bool fMiner);
     bool ClientConnectInputs();
     bool CheckTransaction() const;
     bool AcceptToMemoryPool(CTxDB& txdb, bool fCheckInputs=true, bool* pfMissingInputs=NULL);
     bool AcceptToMemoryPool(bool fCheckInputs=true, bool* pfMissingInputs=NULL);
+
 protected:
+    const CTxOut& GetOutputFor(const CTxIn& input, const MapPrevTx& inputs) const;
     bool AddToMemoryPoolUnchecked();
 public:
     bool RemoveFromMemoryPool();
index 3268343..33765ca 100644 (file)
@@ -22,4 +22,89 @@ BOOST_AUTO_TEST_CASE(basic_transaction_tests)
     BOOST_CHECK_MESSAGE(!tx.CheckTransaction(), "Transaction with duplicate txins should be invalid.");
 }
 
+//
+// Helper: create two dummy transactions, each with
+// two outputs.  The first has 11 and 50 CENT outputs,
+// the second 21 and 22 CENT outputs.
+//
+static std::vector<CTransaction>
+SetupDummyInputs(CBasicKeyStore& keystoreRet, MapPrevTx& inputsRet)
+{
+    std::vector<CTransaction> dummyTransactions;
+    dummyTransactions.resize(2);
+
+    // Add some keys to the keystore:
+    CKey key[4];
+    for (int i = 0; i < 4; i++)
+    {
+        key[i].MakeNewKey();
+        keystoreRet.AddKey(key[i]);
+    }
+
+    // Create some dummy input transactions
+    dummyTransactions[0].vout.resize(2);
+    dummyTransactions[0].vout[0].nValue = 11*CENT;
+    dummyTransactions[0].vout[0].scriptPubKey.SetBitcoinAddress(key[0].GetPubKey());
+    dummyTransactions[0].vout[1].nValue = 50*CENT;
+    dummyTransactions[0].vout[1].scriptPubKey.SetBitcoinAddress(key[1].GetPubKey());
+    inputsRet[dummyTransactions[0].GetHash()] = make_pair(CTxIndex(), dummyTransactions[0]);
+
+    dummyTransactions[1].vout.resize(2);
+    dummyTransactions[1].vout[0].nValue = 21*CENT;
+    dummyTransactions[1].vout[0].scriptPubKey.SetBitcoinAddress(key[2].GetPubKey());
+    dummyTransactions[1].vout[1].nValue = 22*CENT;
+    dummyTransactions[1].vout[1].scriptPubKey.SetBitcoinAddress(key[3].GetPubKey());
+    inputsRet[dummyTransactions[1].GetHash()] = make_pair(CTxIndex(), dummyTransactions[1]);
+
+    return dummyTransactions;
+}
+
+BOOST_AUTO_TEST_CASE(test_Get)
+{
+    CBasicKeyStore keystore;
+    MapPrevTx dummyInputs;
+    std::vector<CTransaction> dummyTransactions = SetupDummyInputs(keystore, dummyInputs);
+
+    CTransaction t1;
+    t1.vin.resize(3);
+    t1.vin[0].prevout.hash = dummyTransactions[0].GetHash();
+    t1.vin[0].prevout.n = 1;
+    t1.vin[1].prevout.hash = dummyTransactions[1].GetHash();;
+    t1.vin[1].prevout.n = 0;
+    t1.vin[2].prevout.hash = dummyTransactions[1].GetHash();;
+    t1.vin[2].prevout.n = 1;
+    t1.vout.resize(2);
+    t1.vout[0].nValue = 90*CENT;
+    t1.vout[0].scriptPubKey << OP_1;
+
+    BOOST_CHECK(t1.AreInputsStandard(dummyInputs));
+    BOOST_CHECK_EQUAL(t1.GetSigOpCount(dummyInputs), 3);
+    BOOST_CHECK_EQUAL(t1.GetValueIn(dummyInputs), (50+21+22)*CENT);
+}
+
+BOOST_AUTO_TEST_CASE(test_GetThrow)
+{
+    CBasicKeyStore keystore;
+    MapPrevTx dummyInputs;
+    std::vector<CTransaction> dummyTransactions = SetupDummyInputs(keystore, dummyInputs);
+
+    MapPrevTx missingInputs;
+
+    CTransaction t1;
+    t1.vin.resize(3);
+    t1.vin[0].prevout.hash = dummyTransactions[0].GetHash();
+    t1.vin[0].prevout.n = 0;
+    t1.vin[1].prevout.hash = dummyTransactions[1].GetHash();;
+    t1.vin[1].prevout.n = 0;
+    t1.vin[2].prevout.hash = dummyTransactions[1].GetHash();;
+    t1.vin[2].prevout.n = 1;
+    t1.vout.resize(2);
+    t1.vout[0].nValue = 90*CENT;
+    t1.vout[0].scriptPubKey << OP_1;
+
+    BOOST_CHECK_THROW(t1.AreInputsStandard(missingInputs), runtime_error);
+    BOOST_CHECK_THROW(t1.GetSigOpCount(missingInputs), runtime_error);
+    BOOST_CHECK_THROW(t1.GetValueIn(missingInputs), runtime_error);
+}
+
 BOOST_AUTO_TEST_SUITE_END()