// src/wallet/test/walletload_tests.cpp
  1  // Copyright (c) 2022 The Bitcoin Core developers
  2  // Distributed under the MIT software license, see the accompanying
  3  // file COPYING or https://www.opensource.org/licenses/mit-license.php.
  4  
  5  #include <wallet/test/util.h>
  6  #include <wallet/wallet.h>
  7  #include <test/util/logging.h>
  8  #include <test/util/setup_common.h>
  9  
 10  #include <boost/test/unit_test.hpp>
 11  
 12  namespace wallet {
 13  
 14  BOOST_AUTO_TEST_SUITE(walletload_tests)
 15  
 16  class DummyDescriptor final : public Descriptor {
 17  private:
 18      std::string desc;
 19  public:
 20      explicit DummyDescriptor(const std::string& descriptor) : desc(descriptor) {};
 21      ~DummyDescriptor() = default;
 22  
 23      std::string ToString(bool compat_format) const override { return desc; }
 24      std::optional<OutputType> GetOutputType() const override { return OutputType::UNKNOWN; }
 25  
 26      bool IsRange() const override { return false; }
 27      bool IsSolvable() const override { return false; }
 28      bool IsSingleType() const override { return true; }
 29      bool ToPrivateString(const SigningProvider& provider, std::string& out) const override { return false; }
 30      bool ToNormalizedString(const SigningProvider& provider, std::string& out, const DescriptorCache* cache = nullptr) const override { return false; }
 31      bool Expand(int pos, const SigningProvider& provider, std::vector<CScript>& output_scripts, FlatSigningProvider& out, DescriptorCache* write_cache = nullptr) const override { return false; };
 32      bool ExpandFromCache(int pos, const DescriptorCache& read_cache, std::vector<CScript>& output_scripts, FlatSigningProvider& out) const override { return false; }
 33      void ExpandPrivate(int pos, const SigningProvider& provider, FlatSigningProvider& out) const override {}
 34      std::optional<int64_t> ScriptSize() const override { return {}; }
 35      std::optional<int64_t> MaxSatisfactionWeight(bool) const override { return {}; }
 36      std::optional<int64_t> MaxSatisfactionElems() const override { return {}; }
 37  };
 38  
 39  BOOST_FIXTURE_TEST_CASE(wallet_load_descriptors, TestingSetup)
 40  {
 41      std::unique_ptr<WalletDatabase> database = CreateMockableWalletDatabase();
 42      {
 43          // Write unknown active descriptor
 44          WalletBatch batch(*database, false);
 45          std::string unknown_desc = "trx(tpubD6NzVbkrYhZ4Y4S7m6Y5s9GD8FqEMBy56AGphZXuagajudVZEnYyBahZMgHNCTJc2at82YX6s8JiL1Lohu5A3v1Ur76qguNH4QVQ7qYrBQx/86'/1'/0'/0/*)#8pn8tzdt";
 46          WalletDescriptor wallet_descriptor(std::make_shared<DummyDescriptor>(unknown_desc), 0, 0, 0, 0);
 47          BOOST_CHECK(batch.WriteDescriptor(uint256(), wallet_descriptor));
 48          BOOST_CHECK(batch.WriteActiveScriptPubKeyMan(static_cast<uint8_t>(OutputType::UNKNOWN), uint256(), false));
 49      }
 50  
 51      {
 52          // Now try to load the wallet and verify the error.
 53          const std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", std::move(database)));
 54          BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::UNKNOWN_DESCRIPTOR);
 55      }
 56  
 57      // Test 2
 58      // Now write a valid descriptor with an invalid ID.
 59      // As the software produces another ID for the descriptor, the loading process must be aborted.
 60      database = CreateMockableWalletDatabase();
 61  
 62      // Verify the error
 63      bool found = false;
 64      DebugLogHelper logHelper("The descriptor ID calculated by the wallet differs from the one in DB", [&](const std::string* s) {
 65          found = true;
 66          return false;
 67      });
 68  
 69      {
 70          // Write valid descriptor with invalid ID
 71          WalletBatch batch(*database, false);
 72          std::string desc = "wpkh([d34db33f/84h/0h/0h]xpub6DJ2dNUysrn5Vt36jH2KLBT2i1auw1tTSSomg8PhqNiUtx8QX2SvC9nrHu81fT41fvDUnhMjEzQgXnQjKEu3oaqMSzhSrHMxyyoEAmUHQbY/0/*)#cjjspncu";
 73          WalletDescriptor wallet_descriptor(std::make_shared<DummyDescriptor>(desc), 0, 0, 0, 0);
 74          BOOST_CHECK(batch.WriteDescriptor(uint256::ONE, wallet_descriptor));
 75      }
 76  
 77      {
 78          // Now try to load the wallet and verify the error.
 79          const std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", std::move(database)));
 80          BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::CORRUPT);
 81          BOOST_CHECK(found); // The error must be logged
 82      }
 83  }
 84  
 85  bool HasAnyRecordOfType(WalletDatabase& db, const std::string& key)
 86  {
 87      std::unique_ptr<DatabaseBatch> batch = db.MakeBatch(false);
 88      BOOST_CHECK(batch);
 89      std::unique_ptr<DatabaseCursor> cursor = batch->GetNewCursor();
 90      BOOST_CHECK(cursor);
 91      while (true) {
 92          DataStream ssKey{};
 93          DataStream ssValue{};
 94          DatabaseCursor::Status status = cursor->Next(ssKey, ssValue);
 95          assert(status != DatabaseCursor::Status::FAIL);
 96          if (status == DatabaseCursor::Status::DONE) break;
 97          std::string type;
 98          ssKey >> type;
 99          if (type == key) return true;
100      }
101      return false;
102  }
103  
104  template<typename... Args>
105  SerializeData MakeSerializeData(const Args&... args)
106  {
107      DataStream s{};
108      SerializeMany(s, args...);
109      return {s.begin(), s.end()};
110  }
111  
112  
113  BOOST_FIXTURE_TEST_CASE(wallet_load_ckey, TestingSetup)
114  {
115      SerializeData ckey_record_key;
116      SerializeData ckey_record_value;
117      MockableData records;
118  
119      {
120          // Context setup.
121          // Create and encrypt legacy wallet
122          std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase()));
123          LOCK(wallet->cs_wallet);
124          auto legacy_spkm = wallet->GetOrCreateLegacyScriptPubKeyMan();
125          BOOST_CHECK(legacy_spkm->SetupGeneration(true));
126  
127          // Retrieve a key
128          CTxDestination dest = *Assert(legacy_spkm->GetNewDestination(OutputType::LEGACY));
129          CKeyID key_id = GetKeyForDestination(*legacy_spkm, dest);
130          CKey first_key;
131          BOOST_CHECK(legacy_spkm->GetKey(key_id, first_key));
132  
133          // Encrypt the wallet
134          BOOST_CHECK(wallet->EncryptWallet("encrypt"));
135          wallet->Flush();
136  
137          // Store a copy of all the records
138          records = GetMockableDatabase(*wallet).m_records;
139  
140          // Get the record for the retrieved key
141          ckey_record_key = MakeSerializeData(DBKeys::CRYPTED_KEY, first_key.GetPubKey());
142          ckey_record_value = records.at(ckey_record_key);
143      }
144  
145      {
146          // First test case:
147          // Erase all the crypted keys from db and unlock the wallet.
148          // The wallet will only re-write the crypted keys to db if any checksum is missing at load time.
149          // So, if any 'ckey' record re-appears on db, then the checksums were not properly calculated, and we are re-writing
150          // the records every time that 'CWallet::Unlock' gets called, which is not good.
151  
152          // Load the wallet and check that is encrypted
153          std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records)));
154          BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::LOAD_OK);
155          BOOST_CHECK(wallet->IsCrypted());
156          BOOST_CHECK(HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY));
157  
158          // Now delete all records and check that the 'Unlock' function doesn't re-write them
159          BOOST_CHECK(wallet->GetLegacyScriptPubKeyMan()->DeleteRecords());
160          BOOST_CHECK(!HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY));
161          BOOST_CHECK(wallet->Unlock("encrypt"));
162          BOOST_CHECK(!HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY));
163      }
164  
165      {
166          // Second test case:
167          // Verify that loading up a 'ckey' with no checksum triggers a complete re-write of the crypted keys.
168  
169          // Cut off the 32 byte checksum from a ckey record
170          records[ckey_record_key].resize(ckey_record_value.size() - 32);
171  
172          // Load the wallet and check that is encrypted
173          std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records)));
174          BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::LOAD_OK);
175          BOOST_CHECK(wallet->IsCrypted());
176          BOOST_CHECK(HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY));
177  
178          // Now delete all ckey records and check that the 'Unlock' function re-writes them
179          // (this is because the wallet, at load time, found a ckey record with no checksum)
180          BOOST_CHECK(wallet->GetLegacyScriptPubKeyMan()->DeleteRecords());
181          BOOST_CHECK(!HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY));
182          BOOST_CHECK(wallet->Unlock("encrypt"));
183          BOOST_CHECK(HasAnyRecordOfType(wallet->GetDatabase(), DBKeys::CRYPTED_KEY));
184      }
185  
186      {
187          // Third test case:
188          // Verify that loading up a 'ckey' with an invalid checksum throws an error.
189  
190          // Cut off the 32 byte checksum from a ckey record
191          records[ckey_record_key].resize(ckey_record_value.size() - 32);
192          // Fill in the checksum space with 0s
193          records[ckey_record_key].resize(ckey_record_value.size());
194  
195          std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records)));
196          BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::CORRUPT);
197      }
198  
199      {
200          // Fourth test case:
201          // Verify that loading up a 'ckey' with an invalid pubkey throws an error
202          CPubKey invalid_key;
203          BOOST_CHECK(!invalid_key.IsValid());
204          SerializeData key = MakeSerializeData(DBKeys::CRYPTED_KEY, invalid_key);
205          records[key] = ckey_record_value;
206  
207          std::shared_ptr<CWallet> wallet(new CWallet(m_node.chain.get(), "", CreateMockableWalletDatabase(records)));
208          BOOST_CHECK_EQUAL(wallet->LoadWallet(), DBErrors::CORRUPT);
209      }
210  }
211  
212  BOOST_AUTO_TEST_SUITE_END()
213  } // namespace wallet