# p2p.py
#!/usr/bin/env python3
# Copyright (c) 2010 ArtForz -- public domain half-a-node
# Copyright (c) 2012 Jeff Garzik
# Copyright (c) 2010-2022 The Bitcoin Core developers
# Distributed under the MIT software license, see the accompanying
# file COPYING or http://www.opensource.org/licenses/mit-license.php.
"""Test objects for interacting with a bitcoind node over the p2p protocol.

The P2PInterface objects interact with the bitcoind nodes under test using the
node's p2p interface. They can be used to send messages to the node, and
callbacks can be registered that execute when messages are received from the
node. Messages are sent to/received from the node on an asyncio event loop.
State held inside the objects must be guarded by the p2p_lock to avoid data
races between the main testing thread and the event loop.

P2PConnection: A low-level connection object to a node's P2P interface
P2PInterface: A high-level interface object for communicating to a node over P2P
P2PDataStore: A p2p interface class that keeps a store of transactions and blocks
              and can respond correctly to getdata and getheaders messages
P2PTxInvStore: A p2p interface class that inherits from P2PInterface, and keeps
              a count of how many times each txid has been announced."""

import asyncio
from collections import defaultdict
from io import BytesIO
import logging
import platform
import struct
import sys
import threading

from test_framework.messages import (
    CBlockHeader,
    MAX_HEADERS_RESULTS,
    msg_addr,
    msg_addrv2,
    msg_block,
    MSG_BLOCK,
    msg_blocktxn,
    msg_cfcheckpt,
    msg_cfheaders,
    msg_cfilter,
    msg_cmpctblock,
    msg_feefilter,
    msg_filteradd,
    msg_filterclear,
    msg_filterload,
    msg_getaddr,
    msg_getblocks,
    msg_getblocktxn,
    msg_getcfcheckpt,
    msg_getcfheaders,
    msg_getcfilters,
    msg_getdata,
    msg_getheaders,
    msg_headers,
    msg_inv,
    msg_mempool,
    msg_merkleblock,
    msg_notfound,
    msg_ping,
    msg_pong,
    msg_sendaddrv2,
    msg_sendcmpct,
    msg_sendheaders,
    msg_sendtxrcncl,
    msg_tx,
    MSG_TX,
    MSG_TYPE_MASK,
    msg_verack,
    msg_version,
    MSG_WTX,
    msg_wtxidrelay,
    NODE_NETWORK,
    NODE_WITNESS,
    MAGIC_BYTES,
    sha256,
)
from test_framework.util import (
    MAX_NODES,
    p2p_port,
    wait_until_helper_internal,
)
from test_framework.v2_p2p import (
    EncryptedP2PState,
    MSGTYPE_TO_SHORTID,
    SHORTID,
)

logger = logging.getLogger("TestFramework.p2p")

# The minimum P2P version that this test framework supports
MIN_P2P_VERSION_SUPPORTED = 60001
# The P2P version that this test framework implements and sends in its `version` message
# Version 70016 supports wtxid relay
P2P_VERSION = 70016
# The services that this test framework offers in its `version` message
P2P_SERVICES = NODE_NETWORK | NODE_WITNESS
# The P2P user agent string that this test framework sends in its `version` message
P2P_SUBVERSION = "/python-p2p-tester:0.0.3/"
# Value for relay that this test framework sends in its `version` message
P2P_VERSION_RELAY = 1
# Delay after receiving a tx inv before requesting transactions from non-preferred peers, in seconds
NONPREF_PEER_TX_DELAY = 2
# Delay for requesting transactions via txids if we have wtxid-relaying peers, in seconds
TXID_RELAY_DELAY = 2
# Delay for requesting transactions if the peer has MAX_PEER_TX_REQUEST_IN_FLIGHT or more requests
OVERLOADED_PEER_TX_DELAY = 2
# How long to wait before downloading a transaction from an additional peer
GETDATA_TX_INTERVAL = 60

# Maps the wire message type (bytes) to the class used to deserialize it.
MESSAGEMAP = {
    b"addr": msg_addr,
    b"addrv2": msg_addrv2,
    b"block": msg_block,
    b"blocktxn": msg_blocktxn,
    b"cfcheckpt": msg_cfcheckpt,
    b"cfheaders": msg_cfheaders,
    b"cfilter": msg_cfilter,
    b"cmpctblock": msg_cmpctblock,
    b"feefilter": msg_feefilter,
    b"filteradd": msg_filteradd,
    b"filterclear": msg_filterclear,
    b"filterload": msg_filterload,
    b"getaddr": msg_getaddr,
    b"getblocks": msg_getblocks,
    b"getblocktxn": msg_getblocktxn,
    b"getcfcheckpt": msg_getcfcheckpt,
    b"getcfheaders": msg_getcfheaders,
    b"getcfilters": msg_getcfilters,
    b"getdata": msg_getdata,
    b"getheaders": msg_getheaders,
    b"headers": msg_headers,
    b"inv": msg_inv,
    b"mempool": msg_mempool,
    b"merkleblock": msg_merkleblock,
    b"notfound": msg_notfound,
    b"ping": msg_ping,
    b"pong": msg_pong,
    b"sendaddrv2": msg_sendaddrv2,
    b"sendcmpct": msg_sendcmpct,
    b"sendheaders": msg_sendheaders,
    b"sendtxrcncl": msg_sendtxrcncl,
    b"tx": msg_tx,
    b"verack": msg_verack,
    b"version": msg_version,
    b"wtxidrelay": msg_wtxidrelay,
}


class P2PConnection(asyncio.Protocol):
    """A low-level connection object to a node's P2P interface.

    This class is responsible for:

    - opening and closing the TCP connection to the node
    - reading bytes from and writing bytes to the socket
    - deserializing and serializing the P2P message header
    - logging messages as they are sent and received

    This class contains no logic for handling the P2P message payloads. It must be
    sub-classed and the on_message() callback overridden."""

    def __init__(self):
        # The underlying transport of the connection.
        # Should only call methods on this from the NetworkThread, c.f. call_soon_threadsafe
        self._transport = None
        # This lock is acquired before sending messages over the socket. There's an implied lock order and
        # p2p_lock must not be acquired after _send_lock as it could result in deadlocks.
        self._send_lock = threading.Lock()
        self.v2_state = None  # EncryptedP2PState object needed for v2 p2p connections
        self.reconnect = False  # set if reconnection needs to happen

    @property
    def is_connected(self):
        """True while an asyncio transport is attached to this connection."""
        return self._transport is not None

    @property
    def supports_v2_p2p(self):
        """True when a BIP324 (v2) state object is attached to this connection."""
        return self.v2_state is not None

    def peer_connect_helper(self, dstaddr, dstport, net, timeout_factor):
        """Reset per-connection state shared by outbound and inbound setup."""
        assert not self.is_connected
        self.timeout_factor = timeout_factor
        self.dstaddr = dstaddr
        self.dstport = dstport
        # The initial message to send after the connection was made:
        self.on_connection_send_msg = None
        self.recvbuf = b""
        self.magic_bytes = MAGIC_BYTES[net]

    def peer_connect(self, dstaddr, dstport, *, net, timeout_factor, supports_v2_p2p):
        """Prepare a connection from this object to the node at dstaddr:dstport.

        Returns a callable which, when invoked, schedules the connection
        coroutine on the network event loop."""
        self.peer_connect_helper(dstaddr, dstport, net, timeout_factor)
        if supports_v2_p2p:
            self.v2_state = EncryptedP2PState(initiating=True, net=net)

        loop = NetworkThread.network_event_loop
        logger.debug('Connecting to Bitcoin Node: %s:%d' % (self.dstaddr, self.dstport))
        coroutine = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport)
        return lambda: loop.call_soon_threadsafe(loop.create_task, coroutine)

    def peer_accept_connection(self, connect_id, connect_cb=lambda: None, *, net, timeout_factor, supports_v2_p2p, reconnect):
        """Prepare to accept a connection from the node.

        Returns a callable which, when invoked, starts (or reuses) a listening
        server via NetworkThread.listen()."""
        self.peer_connect_helper('0', 0, net, timeout_factor)
        self.reconnect = reconnect
        if supports_v2_p2p:
            self.v2_state = EncryptedP2PState(initiating=False, net=net)

        logger.debug('Listening for Bitcoin Node with id: {}'.format(connect_id))
        return lambda: NetworkThread.listen(self, connect_cb, idx=connect_id)

    def peer_disconnect(self):
        # Connection could have already been closed by other end.
        NetworkThread.network_event_loop.call_soon_threadsafe(lambda: self._transport and self._transport.abort())

    # Connection and disconnection methods

    def connection_made(self, transport):
        """asyncio callback when a connection is opened."""
        assert not self._transport
        logger.debug("Connected & Listening: %s:%d" % (self.dstaddr, self.dstport))
        self._transport = transport
        # in an inbound connection to the TestNode with P2PConnection as the initiator, [TestNode <---- P2PConnection]
        # send the initial handshake immediately
        if self.supports_v2_p2p and self.v2_state.initiating and not self.v2_state.tried_v2_handshake:
            send_handshake_bytes = self.v2_state.initiate_v2_handshake()
            self.send_raw_message(send_handshake_bytes)
        # for v1 outbound connections, send version message immediately after opening
        # (for v2 outbound connections, send it after the initial v2 handshake)
        if self.p2p_connected_to_node and not self.supports_v2_p2p:
            self.send_version()
        self.on_open()

    def connection_lost(self, exc):
        """asyncio callback when a connection is closed."""
        # don't display warning if reconnection needs to be attempted using v1 P2P
        if exc and not self.reconnect:
            logger.warning("Connection lost to {}:{} due to {}".format(self.dstaddr, self.dstport, exc))
        else:
            logger.debug("Closed connection to: %s:%d" % (self.dstaddr, self.dstport))
        self._transport = None
        self.recvbuf = b""
        self.on_close()

    # v2 handshake method
    def _on_data_v2_handshake(self):
        """v2 handshake performed before P2P messages are exchanged (see BIP324). P2PConnection is the initiator
        (in inbound connections to TestNode) and the responder (in outbound connections from TestNode).
        Performed by:
            * initiator using `initiate_v2_handshake()`, `complete_handshake()` and `authenticate_handshake()`
            * responder using `respond_v2_handshake()`, `complete_handshake()` and `authenticate_handshake()`

        `initiate_v2_handshake()` is immediately done by the initiator when the connection is established in
        `connection_made()`. The rest of the initial v2 handshake functions are handled here.
        """
        if not self.v2_state.peer:
            if not self.v2_state.initiating and not self.v2_state.sent_garbage:
                # if the responder hasn't sent garbage yet, the responder is still reading ellswift bytes
                # reads ellswift bytes till the first mismatch from 12 bytes V1_PREFIX
                length, send_handshake_bytes = self.v2_state.respond_v2_handshake(BytesIO(self.recvbuf))
                self.recvbuf = self.recvbuf[length:]
                if send_handshake_bytes == -1:
                    # the peer is speaking v1: drop the v2 state and fall back
                    self.v2_state = None
                    return
                elif send_handshake_bytes:
                    self.send_raw_message(send_handshake_bytes)
                elif send_handshake_bytes == b"":
                    return  # only after send_handshake_bytes are sent can `complete_handshake()` be done

            # `complete_handshake()` reads the remaining ellswift bytes from recvbuf
            # and sends response after deriving shared ECDH secret using received ellswift bytes
            length, response = self.v2_state.complete_handshake(BytesIO(self.recvbuf))
            self.recvbuf = self.recvbuf[length:]
            if response:
                self.send_raw_message(response)
            else:
                return  # only after response is sent can `authenticate_handshake()` be done

        # `self.v2_state.peer` is instantiated only after shared ECDH secret/BIP324 derived keys and ciphers
        # is derived in `complete_handshake()`.
        # so `authenticate_handshake()` which uses the BIP324 derived ciphers gets called after `complete_handshake()`.
        assert self.v2_state.peer
        length, is_mac_auth = self.v2_state.authenticate_handshake(self.recvbuf)
        if not is_mac_auth:
            raise ValueError("invalid v2 mac tag in handshake authentication")
        self.recvbuf = self.recvbuf[length:]
        if self.v2_state.tried_v2_handshake:
            # for v2 outbound connections, send version message immediately after v2 handshake
            if self.p2p_connected_to_node:
                self.send_version()
            # process post-v2-handshake data immediately, if available
            if len(self.recvbuf) > 0:
                self._on_data()

    # Socket read methods

    def data_received(self, t):
        """asyncio callback when data is read from the socket."""
        if len(t) > 0:
            self.recvbuf += t
            if self.supports_v2_p2p and not self.v2_state.tried_v2_handshake:
                self._on_data_v2_handshake()
            else:
                self._on_data()

    def _on_data(self):
        """Try to read P2P messages from the recv buffer.

        This method reads data from the buffer in a loop. It deserializes,
        parses and verifies the P2P header, then passes the P2P payload to
        the on_message callback for processing. It returns once the buffer
        no longer holds a complete message.

        Raises ValueError on a bad v2 MAC tag, v1 magic/checksum mismatch or
        an unknown message type (the error is logged unless a reconnect is
        pending)."""
        try:
            while True:
                if self.supports_v2_p2p:
                    # v2 P2P messages are read
                    msglen, msg = self.v2_state.v2_receive_packet(self.recvbuf)
                    if msglen == -1:
                        raise ValueError("invalid v2 mac tag " + repr(self.recvbuf))
                    elif msglen == 0:  # need to receive more bytes in recvbuf
                        return
                    self.recvbuf = self.recvbuf[msglen:]

                    if msg is None:  # ignore decoy messages
                        # Keep going: further complete packets may already be buffered
                        # and would otherwise be stranded until more data arrives.
                        continue
                    assert msg  # application layer messages (which aren't decoy messages) are non-empty
                    shortid = msg[0]  # 1-byte short message type ID
                    if shortid == 0:
                        # next 12 bytes are interpreted as ASCII message type if shortid is b'\x00'
                        if len(msg) < 13:
                            raise IndexError("msg needs minimum required length of 13 bytes")
                        msgtype = msg[1:13].rstrip(b'\x00')
                        msg = msg[13:]  # msg is set to be payload
                    else:
                        # a 1-byte short message type ID
                        msgtype = SHORTID.get(shortid, f"unknown-{shortid}")
                        msg = msg[1:]
                else:
                    # v1 P2P messages are read
                    # header layout: magic (4) | msgtype (12) | payload length (4) | checksum (4)
                    if len(self.recvbuf) < 4:
                        return
                    if self.recvbuf[:4] != self.magic_bytes:
                        raise ValueError("magic bytes mismatch: {} != {}".format(repr(self.magic_bytes), repr(self.recvbuf)))
                    if len(self.recvbuf) < 4 + 12 + 4 + 4:
                        return
                    msgtype = self.recvbuf[4:4+12].split(b"\x00", 1)[0]
                    msglen = struct.unpack("<i", self.recvbuf[4+12:4+12+4])[0]
                    checksum = self.recvbuf[4+12+4:4+12+4+4]
                    if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen:
                        return
                    msg = self.recvbuf[4+12+4+4:4+12+4+4+msglen]
                    th = sha256(msg)
                    h = sha256(th)
                    if checksum != h[:4]:
                        raise ValueError("got bad checksum " + repr(self.recvbuf))
                    self.recvbuf = self.recvbuf[4+12+4+4+msglen:]
                if msgtype not in MESSAGEMAP:
                    raise ValueError("Received unknown msgtype from %s:%d: '%s' %s" % (self.dstaddr, self.dstport, msgtype, repr(msg)))
                f = BytesIO(msg)
                t = MESSAGEMAP[msgtype]()
                t.deserialize(f)
                self._log_message("receive", t)
                self.on_message(t)
        except Exception as e:
            if not self.reconnect:
                # The format string needs a placeholder for the extra argument;
                # otherwise logging itself fails while formatting the record.
                logger.exception('Error reading message: %s', repr(e))
            raise

    def on_message(self, message):
        """Callback for processing a P2P payload. Must be overridden by derived class."""
        raise NotImplementedError

    # Socket write methods

    def send_message(self, message, is_decoy=False):
        """Send a P2P message over the socket.

        This method takes a P2P payload, builds the P2P header and adds
        the message to the send buffer to be sent over the socket."""
        with self._send_lock:
            tmsg = self.build_message(message, is_decoy)
            self._log_message("send", message)
            return self.send_raw_message(tmsg)

    def send_raw_message(self, raw_message_bytes):
        """Queue already-serialized bytes for writing on the network thread.

        Raises IOError if the connection is not open. The write is silently
        dropped if the transport is gone or closing by the time it runs."""
        if not self.is_connected:
            raise IOError('Not connected')

        def maybe_write():
            if not self._transport:
                return
            if self._transport.is_closing():
                return
            self._transport.write(raw_message_bytes)
        NetworkThread.network_event_loop.call_soon_threadsafe(maybe_write)

    # Class utility methods

    def build_message(self, message, is_decoy=False):
        """Build a serialized P2P message.

        For v2 connections the message type is encoded as a 1-byte short ID
        when one exists (otherwise b'\\x00' + 12-byte padded type) and the
        result is encrypted. For v1 connections the classic header
        (magic, padded type, length, double-sha256 checksum) is prepended."""
        msgtype = message.msgtype
        data = message.serialize()
        if self.supports_v2_p2p:
            if msgtype in SHORTID.values():
                tmsg = MSGTYPE_TO_SHORTID.get(msgtype).to_bytes(1, 'big')
            else:
                tmsg = b"\x00"
                tmsg += msgtype
                tmsg += b"\x00" * (12 - len(msgtype))
            tmsg += data
            return self.v2_state.v2_enc_packet(tmsg, ignore=is_decoy)
        else:
            tmsg = self.magic_bytes
            tmsg += msgtype
            tmsg += b"\x00" * (12 - len(msgtype))
            tmsg += struct.pack("<I", len(data))
            th = sha256(data)
            h = sha256(th)
            tmsg += h[:4]
            tmsg += data
            return tmsg

    def _log_message(self, direction, msg):
        """Logs a message being sent or received over the connection.

        direction must be "send" or "receive". The repr of the message is
        capped at 500 characters; a marker is appended only when the repr
        was actually cut."""
        if direction == "send":
            log_message = "Send message to "
        elif direction == "receive":
            log_message = "Received message from "
        else:
            raise ValueError("unknown log direction: {}".format(direction))
        msg_repr = repr(msg)
        log_message += "%s:%d: %s" % (self.dstaddr, self.dstport, msg_repr[:500])
        # Compare against the repr itself, not the whole line (which includes the prefix).
        if len(msg_repr) > 500:
            log_message += "... (msg truncated)"
        logger.debug(log_message)


class P2PInterface(P2PConnection):
    """A high-level P2P interface class for communicating with a Bitcoin node.

    This class provides high-level callbacks for processing P2P message
    payloads, as well as convenience methods for interacting with the
    node over P2P.

    Individual testcases should subclass this and override the on_* methods
    if they want to alter message handling behaviour."""
    def __init__(self, support_addrv2=False, wtxidrelay=True):
        super().__init__()

        # Track number of messages of each type received.
        # Should be read-only in a test.
        self.message_count = defaultdict(int)

        # Track the most recent message of each type.
        # To wait for a message to be received, pop that message from
        # this and use self.wait_until.
        self.last_message = {}

        # A count of the number of ping messages we've sent to the node
        self.ping_counter = 1

        # The network services received from the peer
        self.nServices = 0

        self.support_addrv2 = support_addrv2

        # If the peer supports wtxid-relay
        self.wtxidrelay = wtxidrelay

    def peer_connect_send_version(self, services):
        """Build our version message and stash it to be sent later by send_version()."""
        # Send a version msg
        vt = msg_version()
        vt.nVersion = P2P_VERSION
        vt.strSubVer = P2P_SUBVERSION
        vt.relay = P2P_VERSION_RELAY
        vt.nServices = services
        vt.addrTo.ip = self.dstaddr
        vt.addrTo.port = self.dstport
        vt.addrFrom.ip = "0.0.0.0"
        vt.addrFrom.port = 0
        self.on_connection_send_msg = vt  # Will be sent in connection_made callback

    def peer_connect(self, *, services=P2P_SERVICES, send_version, **kwargs):
        create_conn = super().peer_connect(**kwargs)

        if send_version:
            self.peer_connect_send_version(services)

        return create_conn

    def peer_accept_connection(self, *args, services=P2P_SERVICES, **kwargs):
        create_conn = super().peer_accept_connection(*args, **kwargs)
        self.peer_connect_send_version(services)

        return create_conn

    # Message receiving methods

    def on_message(self, message):
        """Receive message and dispatch message to appropriate callback.

        We keep a count of how many of each message type has been received
        and the most recent message of each type."""
        with p2p_lock:
            try:
                msgtype = message.msgtype.decode('ascii')
                self.message_count[msgtype] += 1
                self.last_message[msgtype] = message
                getattr(self, 'on_' + msgtype)(message)
            except Exception:
                print("ERROR delivering %s (%s)" % (repr(message), sys.exc_info()[0]))
                raise

    # Callback methods. Can be overridden by subclasses in individual test
    # cases to provide custom message handling behaviour.

    def on_open(self):
        pass

    def on_close(self):
        pass

    def on_addr(self, message): pass
    def on_addrv2(self, message): pass
    def on_block(self, message): pass
    def on_blocktxn(self, message): pass
    def on_cfcheckpt(self, message): pass
    def on_cfheaders(self, message): pass
    def on_cfilter(self, message): pass
    def on_cmpctblock(self, message): pass
    def on_feefilter(self, message): pass
    def on_filteradd(self, message): pass
    def on_filterclear(self, message): pass
    def on_filterload(self, message): pass
    def on_getaddr(self, message): pass
    def on_getblocks(self, message): pass
    def on_getblocktxn(self, message): pass
    def on_getdata(self, message): pass
    def on_getheaders(self, message): pass
    def on_headers(self, message): pass
    def on_mempool(self, message): pass
    def on_merkleblock(self, message): pass
    def on_notfound(self, message): pass
    def on_pong(self, message): pass
    def on_sendaddrv2(self, message): pass
    def on_sendcmpct(self, message): pass
    def on_sendheaders(self, message): pass
    def on_sendtxrcncl(self, message): pass
    def on_tx(self, message): pass
    def on_wtxidrelay(self, message): pass

    def on_inv(self, message):
        """Request (via getdata) every announced item whose inv type is non-zero."""
        want = msg_getdata()
        for i in message.inv:
            if i.type != 0:
                want.inv.append(i)
        if len(want.inv):
            self.send_message(want)

    def on_ping(self, message):
        self.send_message(msg_pong(message.nonce))

    def on_verack(self, message):
        pass

    def on_version(self, message):
        assert message.nVersion >= MIN_P2P_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than or equal to {}".format(message.nVersion, MIN_P2P_VERSION_SUPPORTED)
        # for inbound connections, reply to version with own version message
        # (could be due to v1 reconnect after a failed v2 handshake)
        if not self.p2p_connected_to_node:
            self.send_version()
            self.reconnect = False
        if message.nVersion >= 70016 and self.wtxidrelay:
            self.send_message(msg_wtxidrelay())
        if self.support_addrv2:
            self.send_message(msg_sendaddrv2())
        self.send_message(msg_verack())
        self.nServices = message.nServices
        self.relay = message.relay
        if self.p2p_connected_to_node:
            self.send_message(msg_getaddr())

    # Connection helper methods

    def wait_until(self, test_function_in, *, timeout=60, check_connected=True):
        """Poll test_function_in under p2p_lock until it is truthy or the timeout expires.

        If check_connected is set, the connection must remain open while waiting."""
        def test_function():
            if check_connected:
                assert self.is_connected
            return test_function_in()

        wait_until_helper_internal(test_function, timeout=timeout, lock=p2p_lock, timeout_factor=self.timeout_factor)

    def wait_for_connect(self, timeout=60):
        test_function = lambda: self.is_connected
        self.wait_until(test_function, timeout=timeout, check_connected=False)

    def wait_for_disconnect(self, timeout=60):
        test_function = lambda: not self.is_connected
        self.wait_until(test_function, timeout=timeout, check_connected=False)

    def wait_for_reconnect(self, timeout=60):
        def test_function():
            return self.is_connected and self.last_message.get('version') and not self.supports_v2_p2p
        self.wait_until(test_function, timeout=timeout, check_connected=False)

    # Message receiving helper methods

    def wait_for_tx(self, txid, timeout=60):
        def test_function():
            if not self.last_message.get('tx'):
                return False
            return self.last_message['tx'].tx.rehash() == txid

        self.wait_until(test_function, timeout=timeout)

    def wait_for_block(self, blockhash, timeout=60):
        def test_function():
            return self.last_message.get("block") and self.last_message["block"].block.rehash() == blockhash

        self.wait_until(test_function, timeout=timeout)

    def wait_for_header(self, blockhash, timeout=60):
        def test_function():
            last_headers = self.last_message.get('headers')
            if not last_headers:
                return False
            return last_headers.headers[0].rehash() == int(blockhash, 16)

        self.wait_until(test_function, timeout=timeout)

    def wait_for_merkleblock(self, blockhash, timeout=60):
        def test_function():
            last_filtered_block = self.last_message.get('merkleblock')
            if not last_filtered_block:
                return False
            return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16)

        self.wait_until(test_function, timeout=timeout)

    def wait_for_getdata(self, hash_list, timeout=60):
        """Waits for a getdata message.

        The object hashes in the inventory vector must match the provided hash_list."""
        def test_function():
            last_data = self.last_message.get("getdata")
            if not last_data:
                return False
            return [x.hash for x in last_data.inv] == hash_list

        self.wait_until(test_function, timeout=timeout)

    def wait_for_getheaders(self, timeout=60):
        """Waits for a getheaders message.

        Receiving any getheaders message will satisfy the predicate. The last_message["getheaders"]
        value must be explicitly cleared before calling this method, or this will return
        immediately with success. TODO: change this method to take a hash value and only
        return true if the correct block header has been requested."""
        def test_function():
            return self.last_message.get("getheaders")

        self.wait_until(test_function, timeout=timeout)

    def wait_for_inv(self, expected_inv, timeout=60):
        """Waits for an INV message and checks that the first inv object in the message was as expected."""
        if len(expected_inv) > 1:
            raise NotImplementedError("wait_for_inv() will only verify the first inv object")

        def test_function():
            return self.last_message.get("inv") and \
                self.last_message["inv"].inv[0].type == expected_inv[0].type and \
                self.last_message["inv"].inv[0].hash == expected_inv[0].hash

        self.wait_until(test_function, timeout=timeout)

    def wait_for_verack(self, timeout=60):
        def test_function():
            return "verack" in self.last_message

        self.wait_until(test_function, timeout=timeout)

    # Message sending helper functions

    def send_version(self):
        """Send the stashed version message, at most once."""
        if self.on_connection_send_msg:
            self.send_message(self.on_connection_send_msg)
            self.on_connection_send_msg = None  # Never used again

    def send_and_ping(self, message, timeout=60):
        self.send_message(message)
        self.sync_with_ping(timeout=timeout)

    def sync_with_ping(self, timeout=60):
        """Ensure ProcessMessages and SendMessages is called on this connection"""
        # Sending two pings back-to-back, requires that the node calls
        # `ProcessMessage` twice, and thus ensures `SendMessages` must have
        # been called at least once
        self.send_message(msg_ping(nonce=0))
        self.send_message(msg_ping(nonce=self.ping_counter))

        def test_function():
            return self.last_message.get("pong") and self.last_message["pong"].nonce == self.ping_counter

        self.wait_until(test_function, timeout=timeout)
        self.ping_counter += 1


# One lock for synchronizing all data access between the network event loop (see
# NetworkThread below) and the thread running the test logic. For simplicity,
# P2PConnection acquires this lock whenever delivering a message to a P2PInterface.
# This lock should be acquired in the thread running the test logic to synchronize
# access to any data shared with the P2PInterface or P2PConnection.
p2p_lock = threading.Lock()


class NetworkThread(threading.Thread):
    """Thread that runs the single shared asyncio event loop for all P2P connections."""
    network_event_loop = None

    def __init__(self):
        super().__init__(name="NetworkThread")
        # There is only one event loop and no more than one thread must be created
        assert not self.network_event_loop

        NetworkThread.listeners = {}
        NetworkThread.protos = {}
        if platform.system() == 'Windows':
            asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        NetworkThread.network_event_loop = asyncio.new_event_loop()

    def run(self):
        """Start the network thread."""
        self.network_event_loop.run_forever()

    def close(self, timeout=10):
        """Close the connections and network event loop."""
        self.network_event_loop.call_soon_threadsafe(self.network_event_loop.stop)
        wait_until_helper_internal(lambda: not self.network_event_loop.is_running(), timeout=timeout)
        self.network_event_loop.close()
        self.join(timeout)
        # Safe to remove event loop.
        NetworkThread.network_event_loop = None

    @classmethod
    def listen(cls, p2p, callback, port=None, addr=None, idx=1):
        """ Ensure a listening server is running on the given port, and run the
        protocol specified by `p2p` on the next connection to it. Once ready
        for connections, call `callback`."""

        if port is None:
            assert 0 < idx <= MAX_NODES
            port = p2p_port(MAX_NODES - idx)
        if addr is None:
            addr = '127.0.0.1'

        def exception_handler(loop, context):
            # suppress loop errors while a v1 reconnect is expected
            if not p2p.reconnect:
                loop.default_exception_handler(context)

        cls.network_event_loop.set_exception_handler(exception_handler)
        coroutine = cls.create_listen_server(addr, port, callback, p2p)
        cls.network_event_loop.call_soon_threadsafe(cls.network_event_loop.create_task, coroutine)

    @classmethod
    async def create_listen_server(cls, addr, port, callback, proto):
        def peer_protocol():
            """Returns a function that does the protocol handling for a new
            connection. To allow different connections to have different
            behaviors, the protocol function is first put in the cls.protos
            dict. When the connection is made, the function removes the
            protocol function from that dict, and returns it so the event loop
            can start executing it."""
            response = cls.protos.get((addr, port))
            # remove protocol function from dict only when reconnection doesn't need to happen/already happened
            if not proto.reconnect:
                cls.protos[(addr, port)] = None
            return response

        if (addr, port) not in cls.listeners:
            # When creating a listener on a given (addr, port) we only need to
            # do it once. If we want different behaviors for different
            # connections, we can accomplish this by providing different
            # `proto` functions

            listener = await cls.network_event_loop.create_server(peer_protocol, addr, port)
            logger.debug("Listening server on %s:%d should be started" % (addr, port))
            cls.listeners[(addr, port)] = listener

        cls.protos[(addr, port)] = proto
        callback(addr, port)


class P2PDataStore(P2PInterface):
    """A P2P data store class.

    Keeps a block and transaction store and responds correctly to getdata and getheaders requests."""

    def __init__(self):
        super().__init__()
        # store of blocks. key is block hash, value is a CBlock object
        self.block_store = {}
        self.last_block_hash = ''
        # store of txs. key is txid, value is a CTransaction object
        self.tx_store = {}
        self.getdata_requests = []

    def on_getdata(self, message):
        """Check for the tx/block in our stores and if found, reply with the tx or block message."""
        for inv in message.inv:
            self.getdata_requests.append(inv.hash)
            if (inv.type & MSG_TYPE_MASK) == MSG_TX and inv.hash in self.tx_store.keys():
                self.send_message(msg_tx(self.tx_store[inv.hash]))
            elif (inv.type & MSG_TYPE_MASK) == MSG_BLOCK and inv.hash in self.block_store.keys():
                self.send_message(msg_block(self.block_store[inv.hash]))
            else:
                logger.debug('getdata message type {} received.'.format(hex(inv.type)))

    def on_getheaders(self, message):
        """Search back through our block store for the locator, and reply with a headers message if found."""

        locator, hash_stop = message.locator, message.hashstop

        # Assume that the most recent block added is the tip
        if not self.block_store:
            return

        headers_list = [self.block_store[self.last_block_hash]]
        while headers_list[-1].sha256 not in locator.vHave:
            # Walk back through the block store, adding headers to headers_list
            # as we go.
            prev_block_hash = headers_list[-1].hashPrevBlock
            if prev_block_hash in self.block_store:
                prev_block_header = CBlockHeader(self.block_store[prev_block_hash])
                headers_list.append(prev_block_header)
                if prev_block_header.sha256 == hash_stop:
                    # if this is the hashstop header, stop here
                    break
            else:
                logger.debug('block hash {} not found in block store'.format(hex(prev_block_hash)))
                break

        # headers_list runs tip-first; reverse it into ascending order and keep
        # at most MAX_HEADERS_RESULTS entries (the oldest ones).
        headers_list = headers_list[:-MAX_HEADERS_RESULTS - 1:-1]
        self.send_message(msg_headers(headers_list))

    def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False, reject_reason=None, expect_disconnect=False, timeout=60, is_decoy=False):
        """Send blocks to test node and test whether the tip advances.

         - add all blocks to our block_store
         - send a headers message for the final block
         - the on_getheaders handler will ensure that any getheaders are responded to
         - if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will
           ensure that any getdata messages are responded to. Otherwise send the full block unsolicited.
         - if success is True: assert that the node's tip advances to the most recent block
         - if success is False: assert that the node's tip doesn't advance
         - if reject_reason is set: assert that the correct reject message is logged"""

        with p2p_lock:
            for block in blocks:
                self.block_store[block.sha256] = block
                self.last_block_hash = block.sha256

        reject_reason = [reject_reason] if reject_reason else []
        with node.assert_debug_log(expected_msgs=reject_reason):
            if is_decoy:  # since decoy messages are ignored by the recipient - no need to wait for response
                force_send = True
            if force_send:
                for b in blocks:
                    self.send_message(msg_block(block=b), is_decoy)
            else:
                self.send_message(msg_headers([CBlockHeader(block) for block in blocks]))
                self.wait_until(
                    lambda: blocks[-1].sha256 in self.getdata_requests,
                    timeout=timeout,
                    check_connected=success,
                )

            if expect_disconnect:
                self.wait_for_disconnect(timeout=timeout)
            else:
                self.sync_with_ping(timeout=timeout)

            if success:
                self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout)
            else:
                assert node.getbestblockhash() != blocks[-1].hash

    def send_txs_and_test(self, txs, node, *, success=True, expect_disconnect=False, reject_reason=None):
        """Send txs to test node and test whether they're accepted to the mempool.

         - add all txs to our tx_store
         - send tx messages for all txs
         - if success is True/False: assert that the txs are/are not accepted to the mempool
         - if expect_disconnect is True: Skip the sync with ping
         - if reject_reason is set: assert that the correct reject message is logged."""

        with p2p_lock:
            for tx in txs:
                self.tx_store[tx.sha256] = tx

        reject_reason = [reject_reason] if reject_reason else []
        with node.assert_debug_log(expected_msgs=reject_reason):
            for tx in txs:
                self.send_message(msg_tx(tx))

            if expect_disconnect:
                self.wait_for_disconnect()
            else:
                self.sync_with_ping()

            raw_mempool = node.getrawmempool()
            if success:
                # Check that all txs are now in the mempool
                for tx in txs:
                    assert tx.hash in raw_mempool, "{} not found in mempool".format(tx.hash)
            else:
                # Check that none of the txs are now in the mempool
                for tx in txs:
                    assert tx.hash not in raw_mempool, "{} tx found in mempool".format(tx.hash)


class P2PTxInvStore(P2PInterface):
    """A P2PInterface which stores a count of how many times each txid has been announced."""
    def __init__(self):
        super().__init__()
        self.tx_invs_received = defaultdict(int)

    def on_inv(self, message):
        super().on_inv(message)  # Send getdata in response.
        # Store how many times invs have been received for each tx.
        for i in message.inv:
            if (i.type == MSG_TX) or (i.type == MSG_WTX):
                # save txid
                self.tx_invs_received[i.hash] += 1

    def get_invs(self):
        with p2p_lock:
            return list(self.tx_invs_received.keys())

    def wait_for_broadcast(self, txns, timeout=60):
        """Waits for the txns (list of txids) to complete initial broadcast.
        The mempool should mark unbroadcast=False for these transactions.
        """
        # Wait until invs have been received (and getdatas sent) for each txid.
941 self.wait_until(lambda: set(self.tx_invs_received.keys()) == set([int(tx, 16) for tx in txns]), timeout=timeout) 942 # Flush messages and wait for the getdatas to be processed 943 self.sync_with_ping()