/ test / functional / wallet_import_rescan.py
wallet_import_rescan.py
  1  #!/usr/bin/env python3
  2  # Copyright (c) 2014-2022 The Bitcoin Core developers
  3  # Distributed under the MIT software license, see the accompanying
  4  # file COPYING or http://www.opensource.org/licenses/mit-license.php.
  5  """Test wallet import RPCs.
  6  
  7  Test rescan behavior of importaddress, importpubkey, importprivkey, and
  8  importmulti RPCs with different types of keys and rescan options.
  9  
In the first part of the test, node 1 creates an address for each type of
import RPC call and node 0 sends BTC to it. Then other nodes import the
addresses, and the test makes listtransactions and listreceivedbyaddress calls
to confirm that the importing node either did or did not execute rescans
picking up the send transactions.

In the second part of the test, node 0 sends more BTC to each address, and the
test makes more listtransactions and listreceivedbyaddress calls to confirm
that the importing nodes pick up the new transactions regardless of whether
rescans happened previously.

Finally, parent and child transactions are returned to the mempool via a
simulated reorg, and the test confirms that an import with rescan also picks
up unconfirmed mempool transactions.
 20  """
 21  
 22  from test_framework.test_framework import BitcoinTestFramework
 23  from test_framework.address import (
 24      AddressType,
 25      ADDRESS_BCRT1_UNSPENDABLE,
 26  )
 27  from test_framework.messages import COIN
 28  from test_framework.util import (
 29      assert_equal,
 30      set_node_times,
 31  )
 32  
 33  import collections
 34  from decimal import Decimal
 35  import enum
 36  import itertools
 37  import random
 38  
 39  Call = enum.Enum("Call", "single multiaddress multiscript")
 40  Data = enum.Enum("Data", "address pub priv")
 41  Rescan = enum.Enum("Rescan", "no yes late_timestamp")
 42  
 43  
 44  class Variant(collections.namedtuple("Variant", "call data address_type rescan prune")):
 45      """Helper for importing one key and verifying scanned transactions."""
 46      def do_import(self, timestamp):
 47          """Call one key import RPC."""
 48          rescan = self.rescan == Rescan.yes
 49  
 50          assert_equal(self.address["solvable"], True)
 51          assert_equal(self.address["isscript"], self.address_type == AddressType.p2sh_segwit)
 52          assert_equal(self.address["iswitness"], self.address_type == AddressType.bech32)
 53          if self.address["isscript"]:
 54              assert_equal(self.address["embedded"]["isscript"], False)
 55              assert_equal(self.address["embedded"]["iswitness"], True)
 56  
 57          if self.call == Call.single:
 58              if self.data == Data.address:
 59                  response = self.node.importaddress(address=self.address["address"], label=self.label, rescan=rescan)
 60              elif self.data == Data.pub:
 61                  response = self.node.importpubkey(pubkey=self.address["pubkey"], label=self.label, rescan=rescan)
 62              elif self.data == Data.priv:
 63                  response = self.node.importprivkey(privkey=self.key, label=self.label, rescan=rescan)
 64              assert_equal(response, None)
 65  
 66          elif self.call in (Call.multiaddress, Call.multiscript):
 67              request = {
 68                  "scriptPubKey": {
 69                      "address": self.address["address"]
 70                  } if self.call == Call.multiaddress else self.address["scriptPubKey"],
 71                  "timestamp": timestamp + TIMESTAMP_WINDOW + (1 if self.rescan == Rescan.late_timestamp else 0),
 72                  "pubkeys": [self.address["pubkey"]] if self.data == Data.pub else [],
 73                  "keys": [self.key] if self.data == Data.priv else [],
 74                  "label": self.label,
 75                  "watchonly": self.data != Data.priv
 76              }
 77              if self.address_type == AddressType.p2sh_segwit and self.data != Data.address:
 78                  # We need solving data when providing a pubkey or privkey as data
 79                  request.update({"redeemscript": self.address['embedded']['scriptPubKey']})
 80              response = self.node.importmulti(
 81                  requests=[request],
 82                  rescan=self.rescan in (Rescan.yes, Rescan.late_timestamp),
 83              )
 84              assert_equal(response, [{"success": True}])
 85  
 86      def check(self, txid=None, amount=None, confirmation_height=None):
 87          """Verify that listtransactions/listreceivedbyaddress return expected values."""
 88  
 89          txs = self.node.listtransactions(label=self.label, count=10000, include_watchonly=True)
 90          current_height = self.node.getblockcount()
 91          assert_equal(len(txs), self.expected_txs)
 92  
 93          addresses = self.node.listreceivedbyaddress(minconf=0, include_watchonly=True, address_filter=self.address['address'])
 94  
 95          if self.expected_txs:
 96              assert_equal(len(addresses[0]["txids"]), self.expected_txs)
 97  
 98          if txid is not None:
 99              tx, = [tx for tx in txs if tx["txid"] == txid]
100              assert_equal(tx["label"], self.label)
101              assert_equal(tx["address"], self.address["address"])
102              assert_equal(tx["amount"], amount)
103              assert_equal(tx["category"], "receive")
104              assert_equal(tx["label"], self.label)
105              assert_equal(tx["txid"], txid)
106  
107              # If no confirmation height is given, the tx is still in the
108              # mempool.
109              confirmations = (1 + current_height - confirmation_height) if confirmation_height else 0
110              assert_equal(tx["confirmations"], confirmations)
111              if confirmations:
112                  assert "trusted" not in tx
113  
114              address, = [ad for ad in addresses if txid in ad["txids"]]
115              assert_equal(address["address"], self.address["address"])
116              assert_equal(address["amount"], self.amount_received)
117              assert_equal(address["confirmations"], confirmations)
118              # Verify the transaction is correctly marked watchonly depending on
119              # whether the transaction pays to an imported public key or
120              # imported private key. The test setup ensures that transaction
121              # inputs will not be from watchonly keys (important because
122              # involvesWatchonly will be true if either the transaction output
123              # or inputs are watchonly).
124              if self.data != Data.priv:
125                  assert_equal(address["involvesWatchonly"], True)
126              else:
127                  assert_equal("involvesWatchonly" not in address, True)
128  
129  
130  # List of Variants for each way a key or address could be imported.
131  IMPORT_VARIANTS = [Variant(*variants) for variants in itertools.product(Call, Data, AddressType, Rescan, (False, True))]
132  
133  # List of nodes to import keys to. Half the nodes will have pruning disabled,
134  # half will have it enabled. Different nodes will be used for imports that are
135  # expected to cause rescans, and imports that are not expected to cause
136  # rescans, in order to prevent rescans during later imports picking up
137  # transactions associated with earlier imports. This makes it easier to keep
138  # track of expected balances and transactions.
139  ImportNode = collections.namedtuple("ImportNode", "prune rescan")
140  IMPORT_NODES = [ImportNode(*fields) for fields in itertools.product((False, True), repeat=2)]
141  
# Rescans start at the earliest block up to 2 hours before the key timestamp.
TIMESTAMP_WINDOW = 2 * 60 * 60  # seconds

# Default lower bound (in BTC) for the random amounts drawn below.
AMOUNT_DUST = 0.00000546


def get_rand_amount(min_amount=AMOUNT_DUST):
    """Draw a random BTC amount between ``min_amount`` and 1.

    The draw is rounded to 8 decimal places and returned as a ``Decimal``.
    Because of that rounding the result can land marginally below
    ``min_amount``.
    """
    assert min_amount <= 1
    sample = random.uniform(min_amount, 1)
    eight_places = round(sample, 8)
    return Decimal(str(eight_places))
153  
154  
class ImportRescanTest(BitcoinTestFramework):
    """Exercise the legacy-wallet import RPCs with and without rescans.

    Node 0 funds the test, node 1 owns the keys that get imported, and the
    remaining nodes (one per IMPORT_NODES entry) receive the imports.
    """

    def add_options(self, parser):
        self.add_wallet_options(parser, descriptors=False)

    def set_test_params(self):
        self.num_nodes = 2 + len(IMPORT_NODES)
        self.supports_cli = False
        self.rpc_timeout = 120
        # whitelist peers to speed up tx relay / mempool sync
        self.noban_tx_relay = True

    def skip_test_if_missing_module(self):
        self.skip_if_no_wallet()

    def setup_network(self):
        self.extra_args = [[] for _ in range(self.num_nodes)]
        for i, import_node in enumerate(IMPORT_NODES, 2):
            if import_node.prune:
                self.extra_args[i] += ["-prune=1"]

        self.add_nodes(self.num_nodes, extra_args=self.extra_args)

        # Import keys with pruning disabled
        self.start_nodes(extra_args=[[]] * self.num_nodes)
        self.import_deterministic_coinbase_privkeys()
        self.stop_nodes()

        self.start_nodes()
        for i in range(1, self.num_nodes):
            self.connect_nodes(i, 0)

    def _confirm_pending(self, pending):
        """Mine one block on node 0 and stamp each pending variant with it.

        Every variant in ``pending`` gets that block's height as
        ``confirmation_height`` and its header time as ``timestamp``; the list
        is then emptied so the caller can collect the next batch.
        """
        blockhash = self.generate(self.nodes[0], 1)[0]
        conf_height = self.nodes[0].getblockcount()
        timestamp = self.nodes[0].getblockheader(blockhash)["time"]
        for var in pending:
            var.confirmation_height = conf_height
            var.timestamp = timestamp
        pending.clear()

    def _import_node(self, variant):
        """Return the node that should receive *variant*'s import.

        The node is picked by the variant's prune flag and by whether a rescan
        is expected (only ``Rescan.yes`` counts), matching IMPORT_NODES.
        """
        expect_rescan = variant.rescan == Rescan.yes
        return self.nodes[2 + IMPORT_NODES.index(ImportNode(variant.prune, expect_rescan))]

    def run_test(self):
        # Create one transaction on node 0 with a unique amount for
        # each possible type of wallet import RPC. Sends are confirmed in
        # batches of ten; each batch remembers the height and time of the
        # block that confirmed it.
        last_variants = []
        for i, variant in enumerate(IMPORT_VARIANTS):
            if i % 10 == 0:
                self._confirm_pending(last_variants)
            variant.label = "label {} {}".format(i, variant)
            variant.address = self.nodes[1].getaddressinfo(self.nodes[1].getnewaddress(
                label=variant.label,
                address_type=variant.address_type.value,
            ))
            variant.key = self.nodes[1].dumpprivkey(variant.address["address"])
            variant.initial_amount = get_rand_amount()
            variant.initial_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.initial_amount)
            last_variants.append(variant)

        self._confirm_pending(last_variants)

        # Generate a block further in the future (past the rescan window).
        assert_equal(self.nodes[0].getrawmempool(), [])
        set_node_times(
            self.nodes,
            self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"] + TIMESTAMP_WINDOW + 1,
        )
        self.generate(self.nodes[0], 1)

        # For each variation of wallet key import, invoke the import RPC and
        # check the results from listtransactions and listreceivedbyaddress.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run import for variant {}'.format(variant))
            expect_rescan = variant.rescan == Rescan.yes
            variant.node = self._import_node(variant)
            variant.do_import(variant.timestamp)
            if expect_rescan:
                variant.amount_received = variant.initial_amount
                variant.expected_txs = 1
                variant.check(variant.initial_txid, variant.initial_amount, variant.confirmation_height)
            else:
                variant.amount_received = 0
                variant.expected_txs = 0
                variant.check()

        # Create new transactions sending to each address. Each batch of ten
        # is confirmed by the next block mined after it, hence the "+ 1".
        for i, variant in enumerate(IMPORT_VARIANTS):
            if i % 10 == 0:
                self.generate(self.nodes[0], 1)
                conf_height = self.nodes[0].getblockcount() + 1
            variant.sent_amount = get_rand_amount()
            variant.sent_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.sent_amount)
            variant.confirmation_height = conf_height
        self.generate(self.nodes[0], 1)

        assert_equal(self.nodes[0].getrawmempool(), [])
        self.sync_all()

        # Check the latest results from listtransactions and listreceivedbyaddress.
        for variant in IMPORT_VARIANTS:
            self.log.info('Run check for variant {}'.format(variant))
            variant.amount_received += variant.sent_amount
            variant.expected_txs += 1
            variant.check(variant.sent_txid, variant.sent_amount, variant.confirmation_height)

        self.log.info('Test that the mempool is rescanned as well if the rescan parameter is set to true')

        # The late timestamp and pruned variants are not necessary when testing mempool rescan
        mempool_variants = [variant for variant in IMPORT_VARIANTS if variant.rescan != Rescan.late_timestamp and not variant.prune]
        # Reference time for the mempool imports. The only block mined below is
        # invalidated again before the imports, so this remains the tip's time.
        timestamp = self.nodes[0].getblockheader(self.nodes[0].getbestblockhash())["time"]

        # Create one transaction on node 0 with a unique amount for
        # each possible type of wallet import RPC.
        for i, variant in enumerate(mempool_variants):
            variant.label = "mempool label {} {}".format(i, variant)
            variant.address = self.nodes[1].getaddressinfo(self.nodes[1].getnewaddress(
                label=variant.label,
                address_type=variant.address_type.value,
            ))
            variant.key = self.nodes[1].dumpprivkey(variant.address["address"])
            # Ensure output is large enough to pay for fees: conservatively assuming txsize of
            # 500 vbytes and feerate of 20 sats/vbytes
            variant.initial_amount = get_rand_amount(min_amount=((500 * 20 / COIN) + AMOUNT_DUST))
            variant.initial_txid = self.nodes[0].sendtoaddress(variant.address["address"], variant.initial_amount)
            variant.confirmation_height = 0
            variant.timestamp = timestamp

        # Mine a block so these parents are confirmed
        assert_equal(len(self.nodes[0].getrawmempool()), len(mempool_variants))
        self.sync_mempools()
        block_to_disconnect = self.generate(self.nodes[0], 1)[0]
        assert_equal(len(self.nodes[0].getrawmempool()), 0)

        # For each variant, create an unconfirmed child transaction from initial_txid, sending all
        # the funds to an unspendable address. Importantly, no change output is created so the
        # transaction can't be recognized using its outputs. The wallet rescan needs to know the
        # inputs of the transaction to detect it, so the parent must be processed before the child.
        # An equivalent test for descriptors exists in wallet_rescan_unconfirmed.py.
        unspent_txid_map = {utxo["txid"]: utxo for utxo in self.nodes[1].listunspent()}
        for variant in mempool_variants:
            # Send full amount, subtracting fee from outputs, to ensure no change is created.
            child = self.nodes[1].send(
                add_to_wallet=False,
                inputs=[unspent_txid_map[variant.initial_txid]],
                outputs=[{ADDRESS_BCRT1_UNSPENDABLE: variant.initial_amount}],
                subtract_fee_from_outputs=[0]
            )
            variant.child_txid = child["txid"]
            variant.amount_received = 0
            self.nodes[0].sendrawtransaction(child["hex"])

        # Mempools should contain the child transactions for each variant.
        assert_equal(len(self.nodes[0].getrawmempool()), len(mempool_variants))
        self.sync_mempools()

        # Mock a reorg so the parent transactions are added back to the mempool
        for node in self.nodes:
            node.invalidateblock(block_to_disconnect)
            # Mempools should now contain the parent and child for each variant.
            assert_equal(len(node.getrawmempool()), 2 * len(mempool_variants))

        # For each variation of wallet key import, invoke the import RPC and
        # check the results from listtransactions and listreceivedbyaddress.
        for variant in mempool_variants:
            self.log.info('Run import for mempool variant {}'.format(variant))
            expect_rescan = variant.rescan == Rescan.yes
            variant.node = self._import_node(variant)
            variant.do_import(variant.timestamp)
            if expect_rescan:
                # Ensure both transactions were rescanned. This would raise a JSONRPCError if the
                # transactions were not identified as belonging to the wallet.
                assert_equal(variant.node.gettransaction(variant.initial_txid)['confirmations'], 0)
                assert_equal(variant.node.gettransaction(variant.child_txid)['confirmations'], 0)
                variant.amount_received = variant.initial_amount
                variant.expected_txs = 1
                variant.check(variant.initial_txid, variant.initial_amount, 0)
            else:
                variant.amount_received = 0
                variant.expected_txs = 0
                variant.check()
338  
339  
if __name__ == "__main__":
    # Run the test when this file is executed as a script.
    test = ImportRescanTest()
    test.main()