util.py
#!/usr/bin/env python3
# Copyright (c) 2014-2022 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Helpful routines for regression testing."""

from base64 import b64encode
from decimal import Decimal, ROUND_DOWN
from subprocess import CalledProcessError
import hashlib
import inspect
import json
import logging
import os
import pathlib
import platform
import re
import time

from . import coverage
from .authproxy import AuthServiceProxy, JSONRPCException
from collections.abc import Callable
from typing import Optional

logger = logging.getLogger("TestFramework.utils")

# Assert functions
##################


def assert_approx(v, vexp, vspan=0.00001):
    """Assert that `v` is within `vspan` of `vexp`"""
    # If either operand is a Decimal, do the whole comparison in Decimal so
    # that Decimal and float values are never mixed in one expression.
    if isinstance(v, Decimal) or isinstance(vexp, Decimal):
        v = Decimal(v)
        vexp = Decimal(vexp)
        vspan = Decimal(vspan)
    if v < vexp - vspan:
        raise AssertionError("%s < [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))
    if v > vexp + vspan:
        raise AssertionError("%s > [%s..%s]" % (str(v), str(vexp - vspan), str(vexp + vspan)))


def assert_fee_amount(fee, tx_size, feerate_BTC_kvB):
    """Assert the fee is in range.

    The fee must be at least get_fee(tx_size, feerate_BTC_kvB) and at most the
    fee for a transaction two bytes larger.
    """
    assert isinstance(tx_size, int)
    target_fee = get_fee(tx_size, feerate_BTC_kvB)
    if fee < target_fee:
        raise AssertionError("Fee of %s BTC too low! (Should be %s BTC)" % (str(fee), str(target_fee)))
    # allow the wallet's estimation to be at most 2 bytes off
    high_fee = get_fee(tx_size + 2, feerate_BTC_kvB)
    if fee > high_fee:
        raise AssertionError("Fee of %s BTC too high! (Should be %s BTC)" % (str(fee), str(target_fee)))


def summarise_dict_differences(thing1, thing2):
    """Return a pair (d1, d2) holding only the parts of two dicts that differ.

    Keys present in only one dict are copied into the corresponding result.
    Keys present in both with unequal values are compared recursively. If
    either argument is not a dict the arguments are returned unchanged.
    """
    if not isinstance(thing1, dict) or not isinstance(thing2, dict):
        return thing1, thing2
    d1, d2 = {}, {}
    for k in sorted(thing1.keys()):
        if k not in thing2:
            d1[k] = thing1[k]
        elif thing1[k] != thing2[k]:
            d1[k], d2[k] = summarise_dict_differences(thing1[k], thing2[k])
    for k in sorted(thing2.keys()):
        if k not in thing1:
            d2[k] = thing2[k]
    return d1, d2


def assert_equal(thing1, thing2, *args):
    """Assert that `thing1`, `thing2` and every value in `args` are equal.

    When exactly two dicts are compared, the error message additionally shows
    only the differing parts of the two dicts.
    """
    if thing1 != thing2 and not args and isinstance(thing1, dict) and isinstance(thing2, dict):
        d1, d2 = summarise_dict_differences(thing1, thing2)
        raise AssertionError("not(%s == %s)\n  in particular not(%s == %s)" % (thing1, thing2, d1, d2))
    if thing1 != thing2 or any(thing1 != arg for arg in args):
        raise AssertionError("not(%s)" % " == ".join(str(arg) for arg in (thing1, thing2) + args))


def assert_greater_than(thing1, thing2):
    """Assert that `thing1` is strictly greater than `thing2`."""
    if thing1 <= thing2:
        raise AssertionError("%s <= %s" % (str(thing1), str(thing2)))


def assert_greater_than_or_equal(thing1, thing2):
    """Assert that `thing1` is greater than or equal to `thing2`."""
    if thing1 < thing2:
        raise AssertionError("%s < %s" % (str(thing1), str(thing2)))


def assert_raises(exc, fun, *args, **kwds):
    """Assert that calling `fun(*args, **kwds)` raises `exc`."""
    assert_raises_message(exc, None, fun, *args, **kwds)


def assert_raises_message(exc, message, fun, *args, **kwds):
    """Assert that `fun(*args, **kwds)` raises `exc` with `message` in its text.

    Args:
        exc: the expected exception class.
        message: [a substring of] the expected error text. Set to None if
                 checking the error text is not required.
        fun: the function to call.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.

    Raises AssertionError if a JSONRPCException is raised, if an exception
    other than `exc` is raised, if nothing is raised, or if `message` is not
    found in the error text.
    """
    try:
        fun(*args, **kwds)
    except JSONRPCException:
        raise AssertionError("Use assert_raises_rpc_error() to test RPC failures")
    except exc as e:
        if message is not None:
            # Not every exception type carries an `error` dict; fall back to
            # the exception's repr so the substring check works for ordinary
            # exceptions instead of failing with AttributeError.
            error = getattr(e, "error", None)
            if isinstance(error, dict) and "message" in error:
                error_message = error["message"]
            else:
                error_message = repr(e)
            if message not in error_message:
                raise AssertionError(
                    "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format(
                        message, error_message))
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    else:
        raise AssertionError("No exception raised")


def assert_raises_process_error(returncode: int, output: str, fun: Callable, *args, **kwds):
    """Execute a process and asserts the process return code and output.

    Calls function `fun` with arguments `args` and `kwds`. Catches a CalledProcessError
    and verifies that the return code and output are as expected. Throws AssertionError if
    no CalledProcessError was raised or if the return code and output are not as expected.

    Args:
        returncode: the process return code.
        output: [a substring of] the process output.
        fun: the function to call. This should execute a process.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    try:
        fun(*args, **kwds)
    except CalledProcessError as e:
        if returncode != e.returncode:
            raise AssertionError("Unexpected returncode %i" % e.returncode)
        if output not in e.output:
            raise AssertionError("Expected substring not found:" + e.output)
    else:
        raise AssertionError("No exception raised")


def assert_raises_rpc_error(code: Optional[int], message: Optional[str], fun: Callable, *args, **kwds):
    """Run an RPC and verify that a specific JSONRPC exception code and message is raised.

    Calls function `fun` with arguments `args` and `kwds`. Catches a JSONRPCException
    and verifies that the error code and message are as expected. Throws AssertionError if
    no JSONRPCException was raised or if the error code/message are not as expected.

    Args:
        code: the error code returned by the RPC call (defined in src/rpc/protocol.h).
            Set to None if checking the error code is not required.
        message: [a substring of] the error string returned by the RPC call.
            Set to None if checking the error string is not required.
        fun: the function to call. This should be the name of an RPC.
        args*: positional arguments for the function.
        kwds**: named arguments for the function.
    """
    assert try_rpc(code, message, fun, *args, **kwds), "No exception raised"


def try_rpc(code, message, fun, *args, **kwds):
    """Tries to run an rpc command.

    Test against error code and message if the rpc fails.
    Returns whether a JSONRPCException was raised."""
    try:
        fun(*args, **kwds)
    except JSONRPCException as e:
        # JSONRPCException was thrown as expected. Check the code and message values are correct.
        if (code is not None) and (code != e.error["code"]):
            raise AssertionError("Unexpected JSONRPC error code %i" % e.error["code"])
        if (message is not None) and (message not in e.error['message']):
            raise AssertionError(
                "Expected substring not found in error message:\nsubstring: '{}'\nerror message: '{}'.".format(
                    message, e.error['message']))
        return True
    except Exception as e:
        raise AssertionError("Unexpected exception raised: " + type(e).__name__)
    else:
        return False


def assert_is_hex_string(string):
    """Assert that `string` can be parsed by int(string, 16)."""
    try:
        int(string, 16)
    except Exception as e:
        raise AssertionError("Couldn't interpret %r as hexadecimal; raised: %s" % (string, e))


def assert_is_hash_string(string, length=64):
    """Assert that `string` is a lowercase hex string of the given length.

    A falsy `length` (0 or None) disables the length check.
    """
    if not isinstance(string, str):
        raise AssertionError("Expected a string, got type %r" % type(string))
    elif length and len(string) != length:
        raise AssertionError("String of length %d expected; got %d" % (length, len(string)))
    # fullmatch rather than match(...$): `$` also matches just before a
    # trailing newline, which would let "abc...\n" slip through.
    elif not re.fullmatch('[abcdef0-9]+', string):
        raise AssertionError("String %r contains invalid characters for a hash." % string)


def assert_array_result(object_array, to_match, expected, should_not_find=False):
    """
        Pass in array of JSON objects, a dictionary with key/value pairs
        to match against, and another dictionary with expected key/value
        pairs.
        If the should_not_find flag is true, to_match should not be found
        in object_array
        """
    if should_not_find:
        assert_equal(expected, {})
    num_matched = 0
    for item in object_array:
        all_match = True
        for key, value in to_match.items():
            if item[key] != value:
                all_match = False
        if not all_match:
            continue
        # When should_not_find is set, `expected` is empty (asserted above),
        # so this loop is a no-op and the item is simply counted as a match.
        for key, value in expected.items():
            if item[key] != value:
                raise AssertionError("%s : expected %s=%s" % (str(item), str(key), str(value)))
        num_matched = num_matched + 1
    if num_matched == 0 and not should_not_find:
        raise AssertionError("No objects matched %s" % (str(to_match)))
    if num_matched > 0 and should_not_find:
        raise AssertionError("Objects were found %s" % (str(to_match)))


# Utility functions
###################


def check_json_precision():
    """Make sure json library being used does not lose precision converting BTC values"""
    n = Decimal("20000000.00000003")
    satoshis = int(json.loads(json.dumps(float(n))) * 1.0e8)
    if satoshis != 2000000000000003:
        raise RuntimeError("JSON encode/decode loses precision")


def count_bytes(hex_string):
    """Return the number of bytes encoded by `hex_string`."""
    return len(bytearray.fromhex(hex_string))


def str_to_b64str(string):
    """Return the base64 encoding of `string` (UTF-8 encoded) as a str."""
    return b64encode(string.encode('utf-8')).decode('ascii')


def ceildiv(a, b):
    """
    Divide 2 ints and round up to next int rather than round down
    Implementation requires python integers, which have a // operator that does floor division.
    Other types like decimal.Decimal whose // operator truncates towards 0 will not work.
    """
    assert isinstance(a, int)
    assert isinstance(b, int)
    return -(-a // b)


def get_fee(tx_size, feerate_btc_kvb):
    """Calculate the fee in BTC given a feerate is BTC/kvB. Reflects CFeeRate::GetFee"""
    feerate_sat_kvb = int(feerate_btc_kvb * Decimal(1e8))  # Fee in sat/kvb as an int to avoid float precision errors
    target_fee_sat = ceildiv(feerate_sat_kvb * tx_size, 1000)  # Round calculated fee up to nearest sat
    return target_fee_sat / Decimal(1e8)  # Return result in BTC


def satoshi_round(amount):
    """Return `amount` as a Decimal rounded down to 8 decimal places."""
    return Decimal(amount).quantize(Decimal('0.00000001'), rounding=ROUND_DOWN)


def wait_until_helper_internal(predicate, *, attempts=float('inf'), timeout=float('inf'), lock=None, timeout_factor=1.0):
    """Sleep until the predicate resolves to be True.

    Warning: Note that this method is not recommended to be used in tests as it is
    not aware of the context of the test framework. Using the `wait_until()` members
    from `BitcoinTestFramework` or `P2PInterface` class ensures the timeout is
    properly scaled. Furthermore, `wait_until()` from `P2PInterface` class in
    `p2p.py` has a preset lock.
    """
    # With neither bound given, default to a 60 second timeout (before scaling).
    if attempts == float('inf') and timeout == float('inf'):
        timeout = 60
    timeout = timeout * timeout_factor
    attempt = 0
    time_end = time.time() + timeout

    while attempt < attempts and time.time() < time_end:
        if lock:
            with lock:
                if predicate():
                    return
        else:
            if predicate():
                return
        attempt += 1
        time.sleep(0.05)

    # Print the cause of the timeout
    try:
        predicate_source = "''''\n" + inspect.getsource(predicate) + "'''"
    except (OSError, TypeError):
        # Source is not always retrievable (e.g. builtins or callables without
        # a source file); do not let that mask the timeout failure below.
        predicate_source = repr(predicate)
    logger.error("wait_until() failed. Predicate: {}".format(predicate_source))
    if attempt >= attempts:
        raise AssertionError("Predicate {} not true after {} attempts".format(predicate_source, attempts))
    elif time.time() >= time_end:
        raise AssertionError("Predicate {} not true after {} seconds".format(predicate_source, timeout))
    raise RuntimeError('Unreachable')


def sha256sum_file(filename):
    """Return the raw SHA256 digest (bytes) of the file's contents."""
    h = hashlib.sha256()
    with open(filename, 'rb') as f:
        # Hash in 4096-byte chunks so the whole file is never held in memory.
        d = f.read(4096)
        while len(d) > 0:
            h.update(d)
            d = f.read(4096)
    return h.digest()


# RPC/P2P connection constants and functions
############################################

# The maximum number of nodes a single test can spawn
MAX_NODES = 12
# Don't assign rpc or p2p ports lower than this
PORT_MIN = int(os.getenv('TEST_RUNNER_PORT_MIN', default=11000))
# The number of ports to "reserve" for p2p and rpc, each
PORT_RANGE = 5000


class PortSeed:
    # Must be initialized with a unique integer for each process
    n = None


def get_rpc_proxy(url: str, node_number: int, *, timeout: Optional[int]=None, coveragedir: Optional[str]=None) -> coverage.AuthServiceProxyWrapper:
    """
    Args:
        url: URL of the RPC server to call
        node_number: the node number (or id) that this calls to

    Kwargs:
        timeout: HTTP timeout in seconds
        coveragedir: Directory

    Returns:
        AuthServiceProxy. convenience object for making RPC calls.

    """
    proxy_kwargs = {}
    if timeout is not None:
        proxy_kwargs['timeout'] = int(timeout)

    proxy = AuthServiceProxy(url, **proxy_kwargs)

    # A coverage log file is only used when a coverage directory was given.
    coverage_logfile = coverage.get_filename(coveragedir, node_number) if coveragedir else None

    return coverage.AuthServiceProxyWrapper(proxy, url, coverage_logfile)


def p2p_port(n):
    """Return the p2p port for node `n`, offset by the per-process PortSeed.n."""
    assert n <= MAX_NODES
    return PORT_MIN + n + (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)


def rpc_port(n):
    """Return the rpc port for node `n`; rpc ports sit PORT_RANGE above the p2p ports."""
    return PORT_MIN + PORT_RANGE + n + (MAX_NODES * PortSeed.n) % (PORT_RANGE - 1 - MAX_NODES)


def rpc_url(datadir, i, chain, rpchost):
    """Return the authenticated RPC URL for node `i`.

    Credentials come from get_auth_cookie(). The host defaults to 127.0.0.1
    and the port to rpc_port(i); `rpchost` overrides the host, or both host
    and port when given in "host:port" form.
    """
    rpc_u, rpc_p = get_auth_cookie(datadir, chain)
    host = '127.0.0.1'
    port = rpc_port(i)
    if rpchost:
        parts = rpchost.split(':')
        if len(parts) == 2:
            host, port = parts
        else:
            host = rpchost
    return "http://%s:%s@%s:%d" % (rpc_u, rpc_p, host, int(port))


# Node functions
################


def initialize_datadir(dirname, n, chain, disable_autoconnect=True):
    """Create the datadir for node `n` under `dirname` and return its path.

    Writes a bitcoin.conf via write_config() and creates the 'stderr' and
    'stdout' subdirectories.
    """
    datadir = get_datadir_path(dirname, n)
    if not os.path.isdir(datadir):
        os.makedirs(datadir)
    write_config(os.path.join(datadir, "bitcoin.conf"), n=n, chain=chain, disable_autoconnect=disable_autoconnect)
    os.makedirs(os.path.join(datadir, 'stderr'), exist_ok=True)
    os.makedirs(os.path.join(datadir, 'stdout'), exist_ok=True)
    return datadir


def write_config(config_path, *, n, chain, extra_config="", disable_autoconnect=True):
    """Write the test configuration file for node `n` to `config_path`.

    `extra_config` is appended verbatim at the end of the file. When
    `disable_autoconnect` is true, "connect=0" is written as well.
    """
    # Translate chain subdirectory name to config name
    if chain == 'testnet3':
        chain_name_conf_arg = 'testnet'
        chain_name_conf_section = 'test'
    else:
        chain_name_conf_arg = chain
        chain_name_conf_section = chain
    with open(config_path, 'w', encoding='utf8') as f:
        if chain_name_conf_arg:
            f.write("{}=1\n".format(chain_name_conf_arg))
        if chain_name_conf_section:
            f.write("[{}]\n".format(chain_name_conf_section))
        f.write("port=" + str(p2p_port(n)) + "\n")
        f.write("rpcport=" + str(rpc_port(n)) + "\n")
        # Disable server-side timeouts to avoid intermittent issues
        f.write("rpcservertimeout=99000\n")
        f.write("rpcdoccheck=1\n")
        f.write("fallbackfee=0.0002\n")
        f.write("server=1\n")
        f.write("keypool=1\n")
        f.write("discover=0\n")
        f.write("dnsseed=0\n")
        f.write("fixedseeds=0\n")
        f.write("listenonion=0\n")
        # Increase peertimeout to avoid disconnects while using mocktime.
        # peertimeout is measured in mock time, so setting it large enough to
        # cover any duration in mock time is sufficient. It can be overridden
        # in tests.
        f.write("peertimeout=999999999\n")
        f.write("printtoconsole=0\n")
        f.write("upnp=0\n")
        f.write("natpmp=0\n")
        f.write("shrinkdebugfile=0\n")
        f.write("deprecatedrpc=create_bdb\n")  # Required to run the tests
        # To improve SQLite wallet performance so that the tests don't timeout, use -unsafesqlitesync
        f.write("unsafesqlitesync=1\n")
        if disable_autoconnect:
            f.write("connect=0\n")
        f.write(extra_config)


def get_datadir_path(dirname, n):
    """Return the datadir path of node `n` under `dirname` as a pathlib.Path."""
    return pathlib.Path(dirname) / f"node{n}"


def get_temp_default_datadir(temp_dir: pathlib.Path) -> tuple[dict, pathlib.Path]:
    """Return os-specific environment variables that can be set to make the
    GetDefaultDataDir() function return a datadir path under the provided
    temp_dir, as well as the complete path it would return."""
    if platform.system() == "Windows":
        env = dict(APPDATA=str(temp_dir))
        datadir = temp_dir / "Bitcoin"
    else:
        env = dict(HOME=str(temp_dir))
        if platform.system() == "Darwin":
            datadir = temp_dir / "Library/Application Support/Bitcoin"
        else:
            datadir = temp_dir / ".bitcoin"
    return env, datadir


def append_config(datadir, options):
    """Append each entry of `options` as its own line to the datadir's bitcoin.conf."""
    with open(os.path.join(datadir, "bitcoin.conf"), 'a', encoding='utf8') as f:
        for option in options:
            f.write(option + "\n")


def get_auth_cookie(datadir, chain):
    """Return the (user, password) RPC credentials for a node.

    Reads rpcuser/rpcpassword from bitcoin.conf if present; credentials from
    the chain's .cookie file, when readable, take precedence.

    Raises ValueError if no complete set of credentials was found.
    """
    user = None
    password = None
    if os.path.isfile(os.path.join(datadir, "bitcoin.conf")):
        with open(os.path.join(datadir, "bitcoin.conf"), 'r', encoding='utf8') as f:
            for line in f:
                # Split on the first "=" only, so values that themselves
                # contain "=" are not truncated.
                if line.startswith("rpcuser="):
                    assert user is None  # Ensure that there is only one rpcuser line
                    user = line.split("=", 1)[1].strip("\n")
                if line.startswith("rpcpassword="):
                    assert password is None  # Ensure that there is only one rpcpassword line
                    password = line.split("=", 1)[1].strip("\n")
    try:
        with open(os.path.join(datadir, chain, ".cookie"), 'r', encoding="ascii") as f:
            userpass = f.read()
            # Split on the first ":" only, so the password part is kept whole.
            split_userpass = userpass.split(':', 1)
            user = split_userpass[0]
            password = split_userpass[1]
    except OSError:
        # No readable cookie file: keep whatever bitcoin.conf provided.
        pass
    if user is None or password is None:
        raise ValueError("No RPC credentials")
    return user, password


# If a cookie file exists in the given datadir, delete it.
def delete_cookie_file(datadir, chain):
    if os.path.isfile(os.path.join(datadir, chain, ".cookie")):
        logger.debug("Deleting leftover cookie file")
        os.remove(os.path.join(datadir, chain, ".cookie"))


def softfork_active(node, key):
    """Return whether a softfork is active."""
    return node.getdeploymentinfo()['deployments'][key]['active']


def set_node_times(nodes, t):
    """Call setmocktime(t) on every node in `nodes`."""
    for node in nodes:
        node.setmocktime(t)


def check_node_connections(*, node, num_in, num_out):
    """Assert the node's inbound and outbound connection counts from getnetworkinfo."""
    info = node.getnetworkinfo()
    assert_equal(info["connections_in"], num_in)
    assert_equal(info["connections_out"], num_out)


# Transaction/Block functions
#############################


# Create large OP_RETURN txouts that can be appended to a transaction
# to make it large (helper for constructing large transactions). The
# total serialized size of the txouts is about 66k vbytes.
def gen_return_txouts():
    """Return a list holding one large OP_RETURN txout.

    The serialized size of the returned txouts is asserted to be exactly
    67456 bytes.
    """
    # Imported here rather than at module level, as in the original code.
    from .messages import CTxOut
    from .script import CScript, OP_RETURN
    txouts = [CTxOut(nValue=0, scriptPubKey=CScript([OP_RETURN, b'\x01'*67437]))]
    assert_equal(sum([len(txout.serialize()) for txout in txouts]), 67456)
    return txouts


# Create a spend of each passed-in utxo, splicing in "txouts" to each raw
# transaction to make it large. See gen_return_txouts() above.
def create_lots_of_big_transactions(mini_wallet, node, fee, tx_batch_size, txouts, utxos=None):
    """Create and broadcast `tx_batch_size` large transactions; return their txids.

    Each transaction is a self-transfer created by `mini_wallet` with `txouts`
    appended to its outputs. When `utxos` is None the wallet picks the utxo to
    spend; otherwise one utxo is popped from `utxos` per transaction (the
    caller's list is consumed).
    """
    txids = []
    use_internal_utxos = utxos is None
    for _ in range(tx_batch_size):
        tx = mini_wallet.create_self_transfer(
            utxo_to_spend=None if use_internal_utxos else utxos.pop(),
            fee=fee,
        )["tx"]
        tx.vout.extend(txouts)
        # Check via testmempoolaccept that the base fee is unchanged after
        # the extra outputs were appended.
        res = node.testmempoolaccept([tx.serialize().hex()])[0]
        assert_equal(res['fees']['base'], fee)
        txids.append(node.sendrawtransaction(tx.serialize().hex()))
    return txids


def mine_large_block(test_framework, mini_wallet, node):
    """Fill the mempool with 14 large transactions and mine one block."""
    # generate a 66k transaction,
    # and 14 of them is close to the 1MB block limit
    txouts = gen_return_txouts()
    fee = 100 * node.getnetworkinfo()["relayfee"]
    create_lots_of_big_transactions(mini_wallet, node, fee, 14, txouts)
    test_framework.generate(node, 1)


def find_vout_for_address(node, txid, addr):
    """
    Locate the vout index of the given transaction sending to the
    given address. Raises runtime error exception if not found.
    """
    tx = node.getrawtransaction(txid, True)
    for i, vout in enumerate(tx["vout"]):
        script_pub_key = vout["scriptPubKey"]
        # Outputs without a well-defined address (e.g. OP_RETURN) carry no
        # "address" key; skip them instead of failing with KeyError.
        if "address" in script_pub_key and addr == script_pub_key["address"]:
            return i
    raise RuntimeError("Vout not found for address: txid=%s, addr=%s" % (txid, addr))