util.py
  1  #!/usr/bin/env python3
  2  # Copyright (c) 2014-2022 The Bitcoin Core developers
  3  # Distributed under the MIT software license, see the accompanying
  4  # file COPYING or http://www.opensource.org/licenses/mit-license.php.
  5  """Helpful routines for regression testing."""
  6  
  7  from base64 import b64encode
  8  from decimal import Decimal, ROUND_DOWN
  9  from subprocess import CalledProcessError
 10  import hashlib
 11  import inspect
 12  import json
 13  import logging
 14  import os
 15  import pathlib
 16  import platform
 17  import re
 18  import time
 19  
 20  from . import coverage
 21  from .authproxy import AuthServiceProxy, JSONRPCException
 22  from collections.abc import Callable
 23  from typing import Optional
 24  
# Module-level logger for these helpers. Being named "TestFramework.utils", its
# records propagate to the parent "TestFramework" logger (presumably configured
# by the test framework — confirm against its logging setup).
logger = logging.getLogger("TestFramework.utils")
 26  
 27  # Assert functions
 28  ##################
 29  
 30  
 31  def assert_approx(v, vexp, vspan=0.00001):
 32      """Assert that `v` is within `vspan` of `vexp`"""
 33      if isinstance(v, Decimal) or isinstance(vexp, Decimal):
 34          v=Decimal(v)
 35          vexp=Decimal(vexp)
 36          vspan=Decimal(vspan)
 37      if v < vexp - vspan:
 38          raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))
 39      if v > vexp + vspan:
 40          raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))
 41  
 42  
 43  def assert_fee_amount(fee, tx_size, feerate_BTC_kvB):
 44      """Assert the fee is in range."""
 45      assert isinstance(tx_size, int)
 46      target_fee = get_fee(tx_size, feerate_BTC_kvB)
 47      if fee < target_fee:
 48          raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee)))
 49      # allow the wallet's estimation to be at most 2 bytes off
 50      high_fee = get_fee(tx_size + 2, feerate_BTC_kvB)
 51      if fee > high_fee:
 52          raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee)))
 53  
 54  
 55  def summarise_dict_differences(thing1, thing2):
 56      if not isinstance(thing1, dict) or not isinstance(thing2, dict):
 57          return thing1, thing2
 58      d1, d2 = {}, {}
 59      for k in sorted(thing1.keys()):
 60          if k not in thing2:
 61              d1[k] = thing1[k]
 62          elif thing1[k] != thing2[k]:
 63              d1[k], d2[k] = summarise_dict_differences(thing1[k], thing2[k])
 64      for k in sorted(thing2.keys()):
 65          if k not in thing1:
 66              d2[k] = thing2[k]
 67      return d1, d2
 68  
 69  def assert_equal(thing1, thing2, *args):
 70      if thing1 != thing2 and not args and isinstance(thing1, dict) and isinstance(thing2, dict):
 71          d1,d2 = summarise_dict_differences(thing1, thing2)
 72          raise AssertionError("not(%s == %s)\n  in particular not(%s == %s)" % (thing1, thing2, d1, d2))
 73      if thing1 != thing2 or any(thing1 != arg for arg in args):
 74          raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args))
 75  
 76  
 77  def assert_greater_than(thing1, thing2):
 78      if thing1 <= thing2:
 79          raise AssertionError("%s <= %s" % (str(thing1), str(thing2)))
 80  
 81  
 82  def assert_greater_than_or_equal(thing1, thing2):
 83      if thing1 < thing2:
 84          raise AssertionError("%s < %s" % (str(thing1), str(thing2)))
 85  
 86  
 87  def assert_raises(exc, fun, *args, **kwds):
 88      assert_raises_message(exc, None, fun, *args, **kwds)
 89  
 90  
 91  def assert_raises_message(exc, message, fun, *args, **kwds):
 92      try:
 93          fun(*args, **kwds)
 94      except JSONRPCException:
 95          raise AssertionError("Use assert_raises_rpc_error() to test RPC failures")
 96      except exc as e:
 97          if message is not None and message not in e.error['message']:
 98              raise AssertionError(
 99                  "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format(
100                      message, e.error['message']))
101      except Exception as e:
102          raise AssertionError("Unexpected exception raised: " + type(e).__name__)
103      else:
104          raise AssertionError("No exception raised")
105  
106  
def assert_raises_process_error(returncode: int, output: str, fun: Callable, *args, **kwds):
    """Run a process-executing callable and assert that it fails as expected.

    ``fun(*args, **kwds)`` must raise CalledProcessError with exit status
    ``returncode`` and with ``output`` contained in the error's output.

    Args:
        returncode: the expected process return code.
        output: [a substring of] the expected process output.
        fun: the function to call. This should execute a process.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.

    Raises:
        AssertionError: if no CalledProcessError is raised, or if its return
            code or output differs from what is expected.
    """
    try:
        fun(*args, **kwds)
    except CalledProcessError as err:
        if err.returncode != returncode:
            raise AssertionError("Unexpected returncode %i" % err.returncode)
        if output in err.output:
            return
        raise AssertionError("Expected substring not found:" + err.output)
    raise AssertionError("No exception raised")
130  
131  
def assert_raises_rpc_error(code: Optional[int], message: Optional[str], fun: Callable, *args, **kwds):
    """Run an RPC and verify that a specific JSONRPC exception code and message is raised.

    Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException
    and verifies that the error code and message are as expected. Throws AssertionError if
    no JSONRPCException was raised or if the error code/message are not as expected.

    Args:
        code: the error code returned by the RPC call (defined in src/rpc/protocol.h).
            Set to None if checking the error code is not required.
        message: [a substring of] the error string returned by the RPC call.
            Set to None if checking the error string is not required.
        fun: the function to call. This should be the name of an RPC.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    # Deliberately not an ``assert`` statement: ``python -O`` strips assert
    # statements entirely, which would also skip the RPC call and silently turn
    # this check into a no-op.
    if not try_rpc(code, message, fun, *args, **kwds):
        raise AssertionError("No exception raised")
149  
150  
def try_rpc(code, message, fun, *args, **kwds):
    """Call ``fun(*args, **kwds)`` and validate any JSONRPCException it raises.

    Args:
        code: expected JSONRPC error code, or None to accept any code.
        message: substring expected in the JSONRPC error message, or None to
            accept any message.
        fun: the callable to run (typically an RPC method).

    Returns:
        True if a JSONRPCException with a matching code/message was raised,
        False if ``fun`` returned normally.

    Raises:
        AssertionError: if the error code or message does not match, or if an
            exception other than JSONRPCException was raised.
    """
    try:
        fun(*args, **kwds)
    except JSONRPCException as e:
        # The RPC failed as anticipated; now make sure it failed the right way.
        code_mismatch = code is not None and code != e.error["code"]
        if code_mismatch:
            raise AssertionError("Unexpected JSONRPC error code %i" % e.error["code"])
        if message is not None and message not in e.error['message']:
            raise AssertionError(
                "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format(
                    message, e.error['message']))
        return True
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    return False
171  
172  
def assert_is_hex_string(string):
    """Fail unless ``string`` can be parsed by ``int(string, 16)``.

    Any exception raised by the parse is reported inside the AssertionError message.
    """
    try:
        int(string, 16)
    except Exception as parse_error:
        detail = "Couldn't interpret %r as hexadecimal; raised: %s" % (string, parse_error)
        raise AssertionError(detail)
178  
179  
def assert_is_hash_string(string, length=64):
    """Assert that ``string`` looks like a lowercase hex-encoded hash.

    Args:
        string: the value to check; must be a str.
        length: required number of characters (64 by default, i.e. 32 bytes
            of hex). A falsy value (0/None) disables the length check.

    Raises:
        AssertionError: if ``string`` is not a str, has the wrong length, or is
            empty or contains anything other than the characters 0-9 and a-f.
    """
    if not isinstance(string, str):
        raise AssertionError("Expected a string, got type %r" % type(string))
    elif length and len(string) != length:
        raise AssertionError("String of length %d expected; got %d" % (length, len(string)))
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, so re.match('[abcdef0-9]+$', 'ab\n') would succeed.
    elif not re.fullmatch('[abcdef0-9]+', string):
        raise AssertionError("String %r contains invalid characters for a hash." % string)
187  
188  
def assert_array_result(object_array, to_match, expected, should_not_find=False):
    """
    Check the objects in ``object_array`` that match ``to_match``.

    Args:
        object_array: list of dicts (e.g. JSON objects returned by an RPC).
        to_match: key/value pairs an item must have to be considered a match.
        expected: key/value pairs every matching item must also have.
        should_not_find: if True, assert that no item matches ``to_match``;
            ``expected`` must then be empty.

    Raises:
        AssertionError: in any of these cases:
            - a matching item disagrees with ``expected``;
            - no item matches (when should_not_find is False);
            - some item matches (when should_not_find is True).
    """
    if should_not_find:
        assert_equal(expected, {})
    num_matched = 0
    for item in object_array:
        if any(item[key] != value for key, value in to_match.items()):
            continue
        # Count matching items (not checked keys) so that a match is recorded
        # even when ``expected`` is empty.
        num_matched += 1
        for key, value in expected.items():
            if item[key] != value:
                raise AssertionError("%s : expected %s=%s" % (str(item), str(key), str(value)))
    if num_matched == 0 and not should_not_find:
        raise AssertionError("No objects matched %s" % (str(to_match)))
    if num_matched > 0 and should_not_find:
        raise AssertionError("Objects were found %s" % (str(to_match)))
217  
218  
219  # Utility functions
220  ###################
221  
222  
def check_json_precision():
    """Verify that a JSON round trip of a float BTC amount keeps satoshi precision.

    Raises:
        RuntimeError: if 20000000.00000003 BTC, encoded as a float with json.dumps
            and decoded with json.loads, no longer yields 2000000000000003 satoshis.
    """
    amount = Decimal("20000000.00000003")
    round_tripped = json.loads(json.dumps(float(amount)))
    if int(round_tripped * 1.0e8) == 2000000000000003:
        return
    raise RuntimeError("JSON encode/decode loses precision")
229  
230  
def count_bytes(hex_string):
    """Return the number of bytes that the hex-encoded ``hex_string`` decodes to."""
    decoded = bytes.fromhex(hex_string)
    return len(decoded)
233  
234  
def str_to_b64str(string):
    """Base64-encode the UTF-8 bytes of ``string`` and return the result as a str."""
    raw = string.encode('utf-8')
    return str(b64encode(raw), 'ascii')
237  
238  
def ceildiv(a, b):
    """
    Integer division of ``a`` by ``b`` rounded towards positive infinity.

    Relies on Python int ``//`` flooring, so ceil(a / b) == -((-a) // b).
    Types whose ``//`` truncates toward zero (e.g. decimal.Decimal) would give
    wrong answers, hence the int-only assertions.
    """
    assert isinstance(a, int)
    assert isinstance(b, int)
    negated_floor = (-a) // b
    return -negated_floor
248  
249  
def get_fee(tx_size, feerate_btc_kvb):
    """Return the fee in BTC for ``tx_size`` at ``feerate_btc_kvb`` (BTC per kvB).

    Reflects CFeeRate::GetFee:
    - the feerate is converted to an integer number of satoshis per kvB
      (truncated by int());
    - the fee is then rounded up to the next whole satoshi.
    """
    sats_per_btc = Decimal(1e8)
    # Work in integer satoshis to avoid float precision errors.
    feerate_sat_kvb = int(feerate_btc_kvb * sats_per_btc)
    fee_sat = ceildiv(feerate_sat_kvb * tx_size, 1000)
    return fee_sat / sats_per_btc
255  
256  
def satoshi_round(amount):
    """Truncate ``amount`` (anything Decimal() accepts) to 8 decimal places.

    Rounding is always toward zero (ROUND_DOWN), never to nearest.
    """
    one_satoshi = Decimal('0.00000001')
    as_decimal = Decimal(amount)
    return as_decimal.quantize(one_satoshi, rounding=ROUND_DOWN)
259  
260  
def wait_until_helper_internal(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0):
    """Sleep until the predicate resolves to be True.

    Warning: Note that this method is not recommended to be used in tests as it is
    not aware of the context of the test framework. Using the `wait_until()` members
    from `BitcoinTestFramework` or `P2PInterface` class ensures the timeout is
    properly scaled. Furthermore, `wait_until()` from `P2PInterface` class in
    `p2p.py` has a preset lock.

    Args:
        predicate: zero-argument callable, re-evaluated after each 0.05 second sleep.
        attempts: maximum number of times to evaluate ``predicate``.
        timeout: maximum number of seconds to wait, before scaling by
            ``timeout_factor``. If neither ``attempts`` nor ``timeout`` is
            given, a 60 second timeout is used.
        lock: optional lock held while ``predicate`` is evaluated.
        timeout_factor: multiplier applied to ``timeout``.

    Raises:
        AssertionError: if ``predicate`` never returned a truthy value within
            the attempt or time budget.
    """
    if attempts == float('inf') and timeout == float('inf'):
        timeout = 60
    timeout = timeout * timeout_factor
    attempt = 0
    # A monotonic clock cannot be moved by wall-clock adjustments (NTP, manual
    # changes), so the wait is neither cut short nor stretched.
    time_end = time.monotonic() + timeout

    while attempt < attempts and time.monotonic() < time_end:
        if lock:
            with lock:
                if predicate():
                    return
        else:
            if predicate():
                return
        attempt += 1
        time.sleep(0.05)

    # Print the cause of the timeout
    try:
        predicate_source = "'''\n" + inspect.getsource(predicate) + "'''"
    except (OSError, TypeError):
        # Source is unavailable for e.g. builtins, functools.partial objects or
        # code created via exec(); don't let that mask the real timeout.
        predicate_source = repr(predicate)
    logger.error("wait_until() failed. Predicate: {}".format(predicate_source))
    if attempt >= attempts:
        raise AssertionError("Predicate {} not true after {} attempts".format(predicate_source, attempts))
    elif time.monotonic() >= time_end:
        raise AssertionError("Predicate {} not true after {} seconds".format(predicate_source, timeout))
    raise RuntimeError('Unreachable')
295  
296  
def sha256sum_file(filename):
    """Return the raw 32-byte SHA-256 digest of the file at ``filename``.

    The file is read in 4096-byte chunks, so it is never loaded into memory at once.
    """
    digest = hashlib.sha256()
    with open(filename, 'rb') as fh:
        for chunk in iter(lambda: fh.read(4096), b''):
            digest.update(chunk)
    return digest.digest()
305  
306  
307  # RPC/P2P connection constants and functions
308  ############################################
309  
# The maximum number of nodes a single test can spawn
MAX_NODES = 12
# Don't assign rpc or p2p ports lower than this. Can be overridden through the
# TEST_RUNNER_PORT_MIN environment variable (default 11000).
PORT_MIN = int(os.getenv('TEST_RUNNER_PORT_MIN', default=11000))
# The number of ports to "reserve" for p2p and rpc, each: p2p_port() hands out
# ports starting at PORT_MIN, rpc_port() ports starting at PORT_MIN + PORT_RANGE.
PORT_RANGE = 5000
316  
317  
class PortSeed:
    """Holder for the per-process seed used to derive port numbers."""
    # Must be initialized with a unique integer for each process.
    # p2p_port() and rpc_port() add an offset of
    # (MAX_NODES * n) % (PORT_RANGE - 1 - MAX_NODES), so distinct seeds give
    # each process its own block of ports.
    n = None
321  
322  
def get_rpc_proxy(url: str, node_number: int, *, timeout: Optional[int]=None, coveragedir: Optional[str]=None) -> coverage.AuthServiceProxyWrapper:
    """
    Build an RPC client for one node, wrapped for RPC coverage logging.

    Args:
        url: URL of the RPC server to call
        node_number: the node number (or id) that this calls to

    Kwargs:
        timeout: HTTP timeout in seconds, truncated to an int. When None,
            AuthServiceProxy's own default is used.
        coveragedir: directory for RPC coverage log files. When falsy, the
            wrapper is given no coverage log file.

    Returns:
        coverage.AuthServiceProxyWrapper around an AuthServiceProxy for ``url``;
        a convenience object for making RPC calls.
    """
    kwargs = {} if timeout is None else {'timeout': int(timeout)}
    proxy = AuthServiceProxy(url, **kwargs)
    if coveragedir:
        logfile = coverage.get_filename(coveragedir, node_number)
    else:
        logfile = None
    return coverage.AuthServiceProxyWrapper(proxy, url, logfile)
346  
347  
def p2p_port(n):
    """Return the P2P port for node ``n`` (at most MAX_NODES) of this test process.

    A per-process offset derived from PortSeed.n keeps different processes'
    ports apart.
    """
    assert n <= MAX_NODES
    process_offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + n + process_offset
351  
352  
def rpc_port(n):
    """Return the RPC port for node ``n`` of this test process.

    Uses the same per-process offset as p2p_port(), shifted up by PORT_RANGE.
    Unlike p2p_port(), ``n`` is not checked against MAX_NODES here.
    """
    process_offset = (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)
    return PORT_MIN + PORT_RANGE + n + process_offset
355  
356  
def rpc_url(datadir, i, chain, rpchost):
    """Return an ``http://user:password@host:port`` URL for node ``i``'s RPC server.

    Credentials come from get_auth_cookie(). The default endpoint is
    127.0.0.1:rpc_port(i), and a non-empty ``rpchost`` overrides it:
    - containing exactly one ':', it is treated as "host:port";
    - otherwise the whole value is used as the host.
    """
    rpc_u, rpc_p = get_auth_cookie(datadir, chain)
    host, port = '127.0.0.1', rpc_port(i)
    if rpchost and rpchost.count(':') == 1:
        host, port = rpchost.split(':')
    elif rpchost:
        host = rpchost
    return "http://%s:%s@%s:%d" % (rpc_u, rpc_p, host, int(port))
368  
369  
370  # Node functions
371  ################
372  
373  
def initialize_datadir(dirname, n, chain, disable_autoconnect=True):
    """Create node ``n``'s datadir under ``dirname``, write its bitcoin.conf and return its path.

    The ``stderr`` and ``stdout`` subdirectories are created as well. Existing
    directories are reused.
    """
    datadir = get_datadir_path(dirname, n)
    os.makedirs(datadir, exist_ok=True)
    write_config(os.path.join(datadir, "bitcoin.conf"), n=n, chain=chain, disable_autoconnect=disable_autoconnect)
    for stream_dir in ('stderr', 'stdout'):
        os.makedirs(os.path.join(datadir, stream_dir), exist_ok=True)
    return datadir
382  
383  
def write_config(config_path, *, n, chain, extra_config="", disable_autoconnect=True):
    """Write a bitcoin.conf for test node ``n`` to ``config_path``, replacing any existing file.

    The file contains, in order:
    - the chain selection and its config section header;
    - the node's p2p/rpc ports;
    - a fixed set of test-friendly options;
    - ``connect=0`` when ``disable_autoconnect`` is true;
    - ``extra_config``, appended verbatim.
    """
    # The 'testnet3' chain subdirectory is selected with 'testnet=1' and
    # configured under a '[test]' section; other chains use their own name for both.
    if chain == 'testnet3':
        chain_arg, chain_section = 'testnet', 'test'
    else:
        chain_arg = chain_section = chain
    static_settings = (
        # Disable server-side timeouts to avoid intermittent issues
        "rpcservertimeout=99000",
        "rpcdoccheck=1",
        "fallbackfee=0.0002",
        "server=1",
        "keypool=1",
        "discover=0",
        "dnsseed=0",
        "fixedseeds=0",
        "listenonion=0",
        # Increase peertimeout to avoid disconnects while using mocktime.
        # peertimeout is measured in mock time, so setting it large enough to
        # cover any duration in mock time is sufficient. It can be overridden
        # in tests.
        "peertimeout=999999999",
        "printtoconsole=0",
        "upnp=0",
        "natpmp=0",
        "shrinkdebugfile=0",
        "deprecatedrpc=create_bdb",  # Required to run the tests
        # To improve SQLite wallet performance so that the tests don't timeout, use -unsafesqlitesync
        "unsafesqlitesync=1",
    )
    with open(config_path, 'w', encoding='utf8') as conf:
        if chain_arg:
            conf.write(f"{chain_arg}=1\n")
        if chain_section:
            conf.write(f"[{chain_section}]\n")
        conf.write("port=" + str(p2p_port(n)) + "\n")
        conf.write("rpcport=" + str(rpc_port(n)) + "\n")
        for setting in static_settings:
            conf.write(setting + "\n")
        if disable_autoconnect:
            conf.write("connect=0\n")
        conf.write(extra_config)
424  
425  
def get_datadir_path(dirname, n):
    """Return node ``n``'s datadir, ``<dirname>/node<n>``, as a pathlib.Path."""
    base = pathlib.Path(dirname)
    return base.joinpath("node{}".format(n))
428  
429  
def get_temp_default_datadir(temp_dir: pathlib.Path) -> tuple[dict, pathlib.Path]:
    """Point GetDefaultDataDir() at ``temp_dir`` via environment variables.

    Returns a pair ``(env, datadir)``:
    - ``env`` holds the OS-specific variable to set (APPDATA on Windows, HOME
      elsewhere);
    - ``datadir`` is the full path GetDefaultDataDir() would then return.
    """
    system = platform.system()
    if system == "Windows":
        return dict(APPDATA=str(temp_dir)), temp_dir / "Bitcoin"
    env = dict(HOME=str(temp_dir))
    relative = "Library/Application Support/Bitcoin" if system == "Darwin" else ".bitcoin"
    return env, temp_dir / relative
444  
445  
def append_config(datadir, options):
    """Append each entry of ``options`` as its own line to ``<datadir>/bitcoin.conf``."""
    conf_path = os.path.join(datadir, "bitcoin.conf")
    with open(conf_path, 'a', encoding='utf8') as conf:
        conf.writelines(option + "\n" for option in options)
450  
451  
def get_auth_cookie(datadir, chain):
    """Return the ``(user, password)`` RPC credentials for the node in ``datadir``.

    Credentials are looked up in two places:
    - ``rpcuser=`` / ``rpcpassword=`` lines in ``<datadir>/bitcoin.conf`` are
      read first;
    - if ``<datadir>/<chain>/.cookie`` can be read (format "user:password"),
      its values override them.

    Raises:
        ValueError: if no user or no password was found.
    """
    user = None
    password = None
    conf_path = os.path.join(datadir, "bitcoin.conf")
    if os.path.isfile(conf_path):
        with open(conf_path, 'r', encoding='utf8') as f:
            for line in f:
                # Split only on the first '=' so values containing '=' (e.g.
                # base64 passwords with '=' padding) are not truncated.
                if line.startswith("rpcuser="):
                    assert user is None  # Ensure that there is only one rpcuser line
                    user = line.split("=", 1)[1].strip("\n")
                if line.startswith("rpcpassword="):
                    assert password is None  # Ensure that there is only one rpcpassword line
                    password = line.split("=", 1)[1].strip("\n")
    try:
        with open(os.path.join(datadir, chain, ".cookie"), 'r', encoding="ascii") as f:
            userpass = f.read()
            # Likewise split only on the first ':' so the password keeps any ':'.
            split_userpass = userpass.split(':', 1)
            user = split_userpass[0]
            password = split_userpass[1]
    except OSError:
        # No readable cookie file: fall back to the bitcoin.conf credentials.
        pass
    if user is None or password is None:
        raise ValueError("No RPC credentials")
    return user, password
475  
476  
def delete_cookie_file(datadir, chain):
    """Delete ``<datadir>/<chain>/.cookie`` if it exists as a regular file; otherwise do nothing."""
    cookie_path = os.path.join(datadir, chain, ".cookie")
    if not os.path.isfile(cookie_path):
        return
    logger.debug("Deleting leftover cookie file")
    os.remove(cookie_path)
482  
483  
def softfork_active(node, key):
    """Return the 'active' flag of deployment ``key`` as reported by node.getdeploymentinfo()."""
    deployments = node.getdeploymentinfo()['deployments']
    return deployments[key]['active']
487  
488  
def set_node_times(nodes, t):
    """Call ``setmocktime(t)`` on each node in ``nodes``, in iteration order."""
    for each_node in nodes:
        each_node.setmocktime(t)
492  
493  
def check_node_connections(*, node, num_in, num_out):
    """Assert that ``node`` reports exactly ``num_in`` inbound and ``num_out`` outbound connections."""
    netinfo = node.getnetworkinfo()
    for field, wanted in (("connections_in", num_in), ("connections_out", num_out)):
        assert_equal(netinfo[field], wanted)
498  
499  
500  # Transaction/Block functions
501  #############################
502  
503  
def gen_return_txouts():
    """Return OP_RETURN txouts that can be appended to a transaction to make it large.

    Helper for constructing large transactions. The list currently holds a
    single zero-value output whose script is OP_RETURN followed by a 67437-byte
    data element. The serialized size of all txouts is asserted to be 67456
    bytes, i.e. about 66k vbytes.
    """
    from .messages import CTxOut
    from .script import CScript, OP_RETURN
    filler = b'\x01' * 67437
    txouts = [CTxOut(nValue=0, scriptPubKey=CScript([OP_RETURN, filler]))]
    assert_equal(sum(len(txout.serialize()) for txout in txouts), 67456)
    return txouts
513  
514  
def create_lots_of_big_transactions(mini_wallet, node, fee, tx_batch_size, txouts, utxos=None):
    """Create, check and broadcast ``tx_batch_size`` large self-transfers.

    For each transaction:
    - it is built with ``mini_wallet.create_self_transfer``, spending the UTXO
      popped off the end of ``utxos``; when ``utxos`` is None,
      utxo_to_spend=None is passed (presumably letting the wallet choose);
    - its outputs are extended with ``txouts`` (see gen_return_txouts());
    - testmempoolaccept must report a base fee equal to ``fee``;
    - it is submitted with sendrawtransaction.

    Note: ``utxos`` is consumed in place (one pop() per transaction).

    Returns:
        list of the txids returned by sendrawtransaction, in creation order.
    """
    use_internal_utxos = utxos is None
    txids = []
    for _ in range(tx_batch_size):
        spend = None if use_internal_utxos else utxos.pop()
        tx = mini_wallet.create_self_transfer(utxo_to_spend=spend, fee=fee)["tx"]
        tx.vout.extend(txouts)
        raw_hex = tx.serialize().hex()
        accept_result = node.testmempoolaccept([raw_hex])[0]
        assert_equal(accept_result['fees']['base'], fee)
        txids.append(node.sendrawtransaction(raw_hex))
    return txids
530  
531  
def mine_large_block(test_framework, mini_wallet, node):
    """Submit 14 large (~66k) transactions to ``node`` and mine one block.

    14 such transactions come close to the 1MB block limit. Each pays a fee of
    100 times the node's relay fee, as reported by getnetworkinfo.
    """
    padding_txouts = gen_return_txouts()
    per_tx_fee = 100 * node.getnetworkinfo()["relayfee"]
    batch_size = 14
    create_lots_of_big_transactions(mini_wallet, node, per_tx_fee, batch_size, padding_txouts)
    test_framework.generate(node, 1)
539  
540  
def find_vout_for_address(node, txid, addr):
    """
    Locate the vout index of the given transaction sending to the
    given address. Raises runtime error exception if not found.

    Outputs whose decoded scriptPubKey has no "address" field (such as
    OP_RETURN outputs) are skipped instead of aborting with KeyError.
    """
    tx = node.getrawtransaction(txid, True)
    for i, vout in enumerate(tx["vout"]):
        script_pub_key = vout["scriptPubKey"]
        if "address" in script_pub_key and script_pub_key["address"] == addr:
            return i
    raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr))