p2p.py
  1  #!/usr/bin/env python3
  2  # Copyright (c) 2010 ArtForz -- public domain half-a-node
  3  # Copyright (c) 2012 Jeff Garzik
  4  # Copyright (c) 2010-2022 The Bitcoin Core developers
  5  # Distributed under the MIT software license, see the accompanying
  6  # file COPYING or http://www.opensource.org/licenses/mit-license.php.
  7  """Test objects for interacting with a bitcoind node over the p2p protocol.
  8  
  9  The P2PInterface objects interact with the bitcoind nodes under test using the
 10  node's p2p interface. They can be used to send messages to the node, and
 11  callbacks can be registered that execute when messages are received from the
 12  node. Messages are sent to/received from the node on an asyncio event loop.
 13  State held inside the objects must be guarded by the p2p_lock to avoid data
 14  races between the main testing thread and the event loop.
 15  
 16  P2PConnection: A low-level connection object to a node's P2P interface
 17  P2PInterface: A high-level interface object for communicating to a node over P2P
 18  P2PDataStore: A p2p interface class that keeps a store of transactions and blocks
 19                and can respond correctly to getdata and getheaders messages
 20  P2PTxInvStore: A p2p interface class that inherits from P2PDataStore, and keeps
 21                a count of how many times each txid has been announced."""
 22  
 23  import asyncio
 24  from collections import defaultdict
 25  from io import BytesIO
 26  import logging
 27  import platform
 28  import struct
 29  import sys
 30  import threading
 31  
 32  from test_framework.messages import (
 33      CBlockHeader,
 34      MAX_HEADERS_RESULTS,
 35      msg_addr,
 36      msg_addrv2,
 37      msg_block,
 38      MSG_BLOCK,
 39      msg_blocktxn,
 40      msg_cfcheckpt,
 41      msg_cfheaders,
 42      msg_cfilter,
 43      msg_cmpctblock,
 44      msg_feefilter,
 45      msg_filteradd,
 46      msg_filterclear,
 47      msg_filterload,
 48      msg_getaddr,
 49      msg_getblocks,
 50      msg_getblocktxn,
 51      msg_getcfcheckpt,
 52      msg_getcfheaders,
 53      msg_getcfilters,
 54      msg_getdata,
 55      msg_getheaders,
 56      msg_headers,
 57      msg_inv,
 58      msg_mempool,
 59      msg_merkleblock,
 60      msg_notfound,
 61      msg_ping,
 62      msg_pong,
 63      msg_sendaddrv2,
 64      msg_sendcmpct,
 65      msg_sendheaders,
 66      msg_sendtxrcncl,
 67      msg_tx,
 68      MSG_TX,
 69      MSG_TYPE_MASK,
 70      msg_verack,
 71      msg_version,
 72      MSG_WTX,
 73      msg_wtxidrelay,
 74      NODE_NETWORK,
 75      NODE_WITNESS,
 76      MAGIC_BYTES,
 77      sha256,
 78  )
 79  from test_framework.util import (
 80      MAX_NODES,
 81      p2p_port,
 82      wait_until_helper_internal,
 83  )
 84  from test_framework.v2_p2p import (
 85      EncryptedP2PState,
 86      MSGTYPE_TO_SHORTID,
 87      SHORTID,
 88  )
 89  
# Module-level logger used by the connection classes in this file
logger = logging.getLogger("TestFramework.p2p")

# The minimum P2P version that this test framework supports
MIN_P2P_VERSION_SUPPORTED = 60001
# The P2P version that this test framework implements and sends in its `version` message
# Version 70016 supports wtxid relay
P2P_VERSION = 70016
# The services that this test framework offers in its `version` message
P2P_SERVICES = NODE_NETWORK | NODE_WITNESS
# The P2P user agent string that this test framework sends in its `version` message
P2P_SUBVERSION = "/python-p2p-tester:0.0.3/"
# Value for relay that this test framework sends in its `version` message
P2P_VERSION_RELAY = 1
# Delay after receiving a tx inv before requesting transactions from non-preferred peers, in seconds
NONPREF_PEER_TX_DELAY = 2
# Delay for requesting transactions via txids if we have wtxid-relaying peers, in seconds
TXID_RELAY_DELAY = 2
# Delay for requesting transactions if the peer has MAX_PEER_TX_REQUEST_IN_FLIGHT or more requests
OVERLOADED_PEER_TX_DELAY = 2
# How long to wait before downloading a transaction from an additional peer
# NOTE(review): presumably in seconds, like the delays above -- confirm
GETDATA_TX_INTERVAL = 60
111  
# Maps each P2P message type (as bytes) to the class that is instantiated to
# deserialize a received message of that type. P2PConnection._on_data() raises
# ValueError for any received message type that is missing from this map.
MESSAGEMAP = {
    b"addr": msg_addr,
    b"addrv2": msg_addrv2,
    b"block": msg_block,
    b"blocktxn": msg_blocktxn,
    b"cfcheckpt": msg_cfcheckpt,
    b"cfheaders": msg_cfheaders,
    b"cfilter": msg_cfilter,
    b"cmpctblock": msg_cmpctblock,
    b"feefilter": msg_feefilter,
    b"filteradd": msg_filteradd,
    b"filterclear": msg_filterclear,
    b"filterload": msg_filterload,
    b"getaddr": msg_getaddr,
    b"getblocks": msg_getblocks,
    b"getblocktxn": msg_getblocktxn,
    b"getcfcheckpt": msg_getcfcheckpt,
    b"getcfheaders": msg_getcfheaders,
    b"getcfilters": msg_getcfilters,
    b"getdata": msg_getdata,
    b"getheaders": msg_getheaders,
    b"headers": msg_headers,
    b"inv": msg_inv,
    b"mempool": msg_mempool,
    b"merkleblock": msg_merkleblock,
    b"notfound": msg_notfound,
    b"ping": msg_ping,
    b"pong": msg_pong,
    b"sendaddrv2": msg_sendaddrv2,
    b"sendcmpct": msg_sendcmpct,
    b"sendheaders": msg_sendheaders,
    b"sendtxrcncl": msg_sendtxrcncl,
    b"tx": msg_tx,
    b"verack": msg_verack,
    b"version": msg_version,
    b"wtxidrelay": msg_wtxidrelay,
}
149  
150  
class P2PConnection(asyncio.Protocol):
    """A low-level connection object to a node's P2P interface.

    This class is responsible for:

    - opening and closing the TCP connection to the node
    - reading bytes from and writing bytes to the socket
    - deserializing and serializing the P2P message header
    - logging messages as they are sent and received

    This class contains no logic for handling the P2P message payloads. It must be
    sub-classed and the on_message() callback overridden.

    The connection callbacks below also call on_open(), on_close() and send_version()
    and read the `p2p_connected_to_node` attribute. None of these are defined in this
    class, so they have to be supplied by the subclass / the code setting up the
    connection before the connection is opened."""

    def __init__(self):
        # The underlying transport of the connection.
        # Should only call methods on this from the NetworkThread, c.f. call_soon_threadsafe
        self._transport = None
        # This lock is acquired before sending messages over the socket. There's an implied lock order and
        # p2p_lock must not be acquired after _send_lock as it could result in deadlocks.
        self._send_lock = threading.Lock()
        self.v2_state = None  # EncryptedP2PState object needed for v2 p2p connections
        self.reconnect = False  # set if reconnection needs to happen

    @property
    def is_connected(self):
        """True while a transport is attached (between connection_made and connection_lost)."""
        return self._transport is not None

    @property
    def supports_v2_p2p(self):
        """True when a v2 (BIP324) state object is set for this connection."""
        return self.v2_state is not None

    def peer_connect_helper(self, dstaddr, dstport, net, timeout_factor):
        """Reset the per-connection state shared by outgoing and accepted connections.

        Must not be called while a connection is open."""
        assert not self.is_connected
        self.timeout_factor = timeout_factor
        self.dstaddr = dstaddr
        self.dstport = dstport
        # The initial message to send after the connection was made:
        self.on_connection_send_msg = None
        self.recvbuf = b""
        self.magic_bytes = MAGIC_BYTES[net]

    def peer_connect(self, dstaddr, dstport, *, net, timeout_factor, supports_v2_p2p):
        """Prepare a connection to dstaddr:dstport.

        Returns a callable; invoking it schedules the connection attempt on the
        network event loop. Nothing is connected until that callable is invoked."""
        self.peer_connect_helper(dstaddr, dstport, net, timeout_factor)
        if supports_v2_p2p:
            self.v2_state = EncryptedP2PState(initiating=True, net=net)

        loop = NetworkThread.network_event_loop
        logger.debug('Connecting to Bitcoin Node: %s:%d' % (self.dstaddr, self.dstport))
        coroutine = loop.create_connection(lambda: self, host=self.dstaddr, port=self.dstport)
        return lambda: loop.call_soon_threadsafe(loop.create_task, coroutine)

    def peer_accept_connection(self, connect_id, connect_cb=lambda: None, *, net, timeout_factor, supports_v2_p2p, reconnect):
        """Prepare to accept a connection initiated by the node.

        Returns a callable; invoking it calls NetworkThread.listen() for this object
        with `connect_cb` and `idx=connect_id`."""
        self.peer_connect_helper('0', 0, net, timeout_factor)
        self.reconnect = reconnect
        if supports_v2_p2p:
            self.v2_state = EncryptedP2PState(initiating=False, net=net)

        logger.debug('Listening for Bitcoin Node with id: {}'.format(connect_id))
        return lambda: NetworkThread.listen(self, connect_cb, idx=connect_id)

    def peer_disconnect(self):
        """Abort the transport (if any) from the network event loop thread."""
        # Connection could have already been closed by other end.
        NetworkThread.network_event_loop.call_soon_threadsafe(lambda: self._transport and self._transport.abort())

    # Connection and disconnection methods

    def connection_made(self, transport):
        """asyncio callback when a connection is opened."""
        assert not self._transport
        logger.debug("Connected & Listening: %s:%d" % (self.dstaddr, self.dstport))
        self._transport = transport
        # in an inbound connection to the TestNode with P2PConnection as the initiator, [TestNode <---- P2PConnection]
        # send the initial handshake immediately
        if self.supports_v2_p2p and self.v2_state.initiating and not self.v2_state.tried_v2_handshake:
            send_handshake_bytes = self.v2_state.initiate_v2_handshake()
            self.send_raw_message(send_handshake_bytes)
        # for v1 outbound connections, send version message immediately after opening
        # (for v2 outbound connections, send it after the initial v2 handshake)
        if self.p2p_connected_to_node and not self.supports_v2_p2p:
            self.send_version()
        self.on_open()

    def connection_lost(self, exc):
        """asyncio callback when a connection is closed."""
        # don't display warning if reconnection needs to be attempted using v1 P2P
        if exc and not self.reconnect:
            logger.warning("Connection lost to {}:{} due to {}".format(self.dstaddr, self.dstport, exc))
        else:
            logger.debug("Closed connection to: %s:%d" % (self.dstaddr, self.dstport))
        self._transport = None
        self.recvbuf = b""
        self.on_close()

    # v2 handshake method
    def _on_data_v2_handshake(self):
        """v2 handshake performed before P2P messages are exchanged (see BIP324). P2PConnection is the initiator
        (in inbound connections to TestNode) and the responder (in outbound connections from TestNode).
        Performed by:
            * initiator using `initiate_v2_handshake()`, `complete_handshake()` and `authenticate_handshake()`
            * responder using `respond_v2_handshake()`, `complete_handshake()` and `authenticate_handshake()`

        `initiate_v2_handshake()` is immediately done by the initiator when the connection is established in
        `connection_made()`. The rest of the initial v2 handshake functions are handled here.
        """
        if not self.v2_state.peer:
            if not self.v2_state.initiating and not self.v2_state.sent_garbage:
                # if the responder hasn't sent garbage yet, the responder is still reading ellswift bytes
                # reads ellswift bytes till the first mismatch from 12 bytes V1_PREFIX
                length, send_handshake_bytes = self.v2_state.respond_v2_handshake(BytesIO(self.recvbuf))
                self.recvbuf = self.recvbuf[length:]
                if send_handshake_bytes == -1:
                    # -1 makes this connection drop its v2 state; from here on it is handled as v1
                    self.v2_state = None
                    return
                elif send_handshake_bytes:
                    self.send_raw_message(send_handshake_bytes)
                elif send_handshake_bytes == b"":
                    return  # only after send_handshake_bytes are sent can `complete_handshake()` be done

            # `complete_handshake()` reads the remaining ellswift bytes from recvbuf
            # and sends response after deriving shared ECDH secret using received ellswift bytes
            length, response = self.v2_state.complete_handshake(BytesIO(self.recvbuf))
            self.recvbuf = self.recvbuf[length:]
            if response:
                self.send_raw_message(response)
            else:
                return  # only after response is sent can `authenticate_handshake()` be done

        # `self.v2_state.peer` is instantiated only after shared ECDH secret/BIP324 derived keys and ciphers
        # is derived in `complete_handshake()`.
        # so `authenticate_handshake()` which uses the BIP324 derived ciphers gets called after `complete_handshake()`.
        assert self.v2_state.peer
        length, is_mac_auth = self.v2_state.authenticate_handshake(self.recvbuf)
        if not is_mac_auth:
            raise ValueError("invalid v2 mac tag in handshake authentication")
        self.recvbuf = self.recvbuf[length:]
        if self.v2_state.tried_v2_handshake:
            # for v2 outbound connections, send version message immediately after v2 handshake
            if self.p2p_connected_to_node:
                self.send_version()
            # process post-v2-handshake data immediately, if available
            if len(self.recvbuf) > 0:
                self._on_data()

    # Socket read methods

    def data_received(self, t):
        """asyncio callback when data is read from the socket."""
        if len(t) > 0:
            self.recvbuf += t
            # Until the v2 handshake has been tried, incoming bytes belong to the handshake
            if self.supports_v2_p2p and not self.v2_state.tried_v2_handshake:
                self._on_data_v2_handshake()
            else:
                self._on_data()

    def _on_data(self):
        """Try to read P2P messages from the recv buffer.

        This method reads data from the buffer in a loop. It deserializes,
        parses and verifies the P2P header, then passes the P2P payload to
        the on_message callback for processing. It returns once the buffer
        does not hold a complete message.

        Raises ValueError on an invalid v2 mac tag, a v1 magic or checksum
        mismatch, or an unknown message type, and IndexError on a v2 message
        too short to hold a long-form message type. Any exception is logged
        (unless a reconnect is pending) and re-raised."""
        try:
            while True:
                if self.supports_v2_p2p:
                    # v2 P2P messages are read
                    msglen, msg = self.v2_state.v2_receive_packet(self.recvbuf)
                    if msglen == -1:
                        raise ValueError("invalid v2 mac tag " + repr(self.recvbuf))
                    elif msglen == 0:  # need to receive more bytes in recvbuf
                        return
                    self.recvbuf = self.recvbuf[msglen:]

                    if msg is None:
                        # Decoy messages are ignored, but keep reading: packets already
                        # buffered behind the decoy must not wait for the next
                        # data_received() call.
                        continue
                    assert msg  # application layer messages (which aren't decoy messages) are non-empty
                    shortid = msg[0]  # 1-byte short message type ID
                    if shortid == 0:
                        # next 12 bytes are interpreted as ASCII message type if shortid is b'\x00'
                        if len(msg) < 13:
                            raise IndexError("msg needs minimum required length of 13 bytes")
                        msgtype = msg[1:13].rstrip(b'\x00')
                        msg = msg[13:]  # msg is set to be payload
                    else:
                        # a 1-byte short message type ID
                        msgtype = SHORTID.get(shortid, f"unknown-{shortid}")
                        msg = msg[1:]
                else:
                    # v1 P2P messages are read
                    # v1 header layout: magic (4) | msgtype (12) | payload length (4) | checksum (4)
                    if len(self.recvbuf) < 4:
                        return
                    if self.recvbuf[:4] != self.magic_bytes:
                        raise ValueError("magic bytes mismatch: {} != {}".format(repr(self.magic_bytes), repr(self.recvbuf)))
                    if len(self.recvbuf) < 4 + 12 + 4 + 4:
                        return
                    msgtype = self.recvbuf[4:4+12].split(b"\x00", 1)[0]
                    msglen = struct.unpack("<i", self.recvbuf[4+12:4+12+4])[0]
                    checksum = self.recvbuf[4+12+4:4+12+4+4]
                    if len(self.recvbuf) < 4 + 12 + 4 + 4 + msglen:
                        return
                    msg = self.recvbuf[4+12+4+4:4+12+4+4+msglen]
                    # checksum is the first 4 bytes of the double-sha256 of the payload
                    th = sha256(msg)
                    h = sha256(th)
                    if checksum != h[:4]:
                        raise ValueError("got bad checksum " + repr(self.recvbuf))
                    self.recvbuf = self.recvbuf[4+12+4+4+msglen:]
                if msgtype not in MESSAGEMAP:
                    raise ValueError("Received unknown msgtype from %s:%d: '%s' %s" % (self.dstaddr, self.dstport, msgtype, repr(msg)))
                f = BytesIO(msg)
                t = MESSAGEMAP[msgtype]()
                t.deserialize(f)
                self._log_message("receive", t)
                self.on_message(t)
        except Exception as e:
            if not self.reconnect:
                # The exception is passed as a %-style argument; without a placeholder in
                # the format string the logging module fails to format the record.
                logger.exception('Error reading message: %r', e)
            raise

    def on_message(self, message):
        """Callback for processing a P2P payload. Must be overridden by derived class."""
        raise NotImplementedError

    # Socket write methods

    def send_message(self, message, is_decoy=False):
        """Send a P2P message over the socket.

        This method takes a P2P payload, builds the P2P header and adds
        the message to the send buffer to be sent over the socket.
        `is_decoy` is forwarded to build_message(). Building and queueing
        happen under _send_lock.

        Raises IOError if the connection is not open."""
        with self._send_lock:
            tmsg = self.build_message(message, is_decoy)
            self._log_message("send", message)
            return self.send_raw_message(tmsg)

    def send_raw_message(self, raw_message_bytes):
        """Queue already-serialized bytes to be written by the network event loop.

        Raises IOError if the connection is not open at call time. The bytes are
        silently dropped if the transport is gone or closing by the time the
        queued write runs."""
        if not self.is_connected:
            raise IOError('Not connected')

        def maybe_write():
            if not self._transport:
                return
            if self._transport.is_closing():
                return
            self._transport.write(raw_message_bytes)
        NetworkThread.network_event_loop.call_soon_threadsafe(maybe_write)

    # Class utility methods

    def build_message(self, message, is_decoy=False):
        """Build a serialized P2P message

        For v2 connections the message type is encoded as a 1-byte short ID when one
        exists, otherwise as b"\\x00" followed by the type padded to 12 bytes, and the
        result is encrypted by the v2 state (`is_decoy` is passed on as `ignore`).
        For v1 connections the header is magic | type padded to 12 bytes | payload
        length | first 4 bytes of the payload's double-sha256; `is_decoy` is unused."""
        msgtype = message.msgtype
        data = message.serialize()
        if self.supports_v2_p2p:
            if msgtype in SHORTID.values():
                tmsg = MSGTYPE_TO_SHORTID.get(msgtype).to_bytes(1, 'big')
            else:
                tmsg = b"\x00"
                tmsg += msgtype
                tmsg += b"\x00" * (12 - len(msgtype))
            tmsg += data
            return self.v2_state.v2_enc_packet(tmsg, ignore=is_decoy)
        else:
            tmsg = self.magic_bytes
            tmsg += msgtype
            tmsg += b"\x00" * (12 - len(msgtype))
            tmsg += struct.pack("<I", len(data))
            th = sha256(data)
            h = sha256(th)
            tmsg += h[:4]
            tmsg += data
            return tmsg

    def _log_message(self, direction, msg):
        """Logs a message being sent or received over the connection.

        `direction` must be "send" or "receive"; any other value raises ValueError.
        At most the first 500 characters of repr(msg) are logged."""
        if direction == "send":
            log_message = "Send message to "
        elif direction == "receive":
            log_message = "Received message from "
        else:
            raise ValueError("unknown message direction: %r" % (direction,))
        msg_repr = repr(msg)
        log_message += "%s:%d: %s" % (self.dstaddr, self.dstport, msg_repr[:500])
        # Decide on the marker from the repr alone, so that the address prefix cannot
        # make an untruncated message look truncated.
        if len(msg_repr) > 500:
            log_message += "... (msg truncated)"
        logger.debug(log_message)
431  
432  
class P2PInterface(P2PConnection):
    """A high-level P2P interface class for communicating with a Bitcoin node.

    This class provides high-level callbacks for processing P2P message
    payloads, as well as convenience methods for interacting with the
    node over P2P.

    Individual testcases should subclass this and override the on_* methods
    if they want to alter message handling behaviour."""
    def __init__(self, support_addrv2=False, wtxidrelay=True):
        super().__init__()

        # Track number of messages of each type received.
        # Should be read-only in a test.
        self.message_count = defaultdict(int)

        # Track the most recent message of each type.
        # To wait for a message to be received, pop that message from
        # this and use self.wait_until.
        self.last_message = {}

        # A count of the number of ping messages we've sent to the node
        self.ping_counter = 1

        # The network services received from the peer
        self.nServices = 0

        # Whether to send a sendaddrv2 message in reply to the peer's version message
        self.support_addrv2 = support_addrv2

        # If the peer supports wtxid-relay
        self.wtxidrelay = wtxidrelay

    def peer_connect_send_version(self, services):
        """Build the version message advertising `services` and store it for sending later."""
        # Send a version msg
        vt = msg_version()
        vt.nVersion = P2P_VERSION
        vt.strSubVer = P2P_SUBVERSION
        vt.relay = P2P_VERSION_RELAY
        vt.nServices = services
        vt.addrTo.ip = self.dstaddr
        vt.addrTo.port = self.dstport
        vt.addrFrom.ip = "0.0.0.0"
        vt.addrFrom.port = 0
        self.on_connection_send_msg = vt  # Will be sent in connection_made callback

    def peer_connect(self, *, services=P2P_SERVICES, send_version, **kwargs):
        """Prepare an outgoing connection; queue a version message only if `send_version` is set.

        Remaining keyword arguments are passed to the parent class. Returns the
        parent's connection-creating callable."""
        create_conn = super().peer_connect(**kwargs)

        if send_version:
            self.peer_connect_send_version(services)

        return create_conn

    def peer_accept_connection(self, *args, services=P2P_SERVICES, **kwargs):
        """Prepare to accept a connection and always queue a version message.

        Remaining arguments are passed to the parent class. Returns the parent's
        listen-starting callable."""
        create_conn = super().peer_accept_connection(*args, **kwargs)
        self.peer_connect_send_version(services)

        return create_conn

    # Message receiving methods

    def on_message(self, message):
        """Receive message and dispatch message to appropriate callback.

        We keep a count of how many of each message type has been received
        and the most recent message of each type. The bookkeeping and the
        `on_<msgtype>` callback both run while p2p_lock is held. Exceptions
        from the callback are printed and re-raised."""
        with p2p_lock:
            try:
                msgtype = message.msgtype.decode('ascii')
                self.message_count[msgtype] += 1
                self.last_message[msgtype] = message
                getattr(self, 'on_' + msgtype)(message)
            except Exception:
                print("ERROR delivering %s (%s)" % (repr(message), sys.exc_info()[0]))
                raise

    # Callback methods. Can be overridden by subclasses in individual test
    # cases to provide custom message handling behaviour.

    def on_open(self):
        pass

    def on_close(self):
        pass

    def on_addr(self, message): pass
    def on_addrv2(self, message): pass
    def on_block(self, message): pass
    def on_blocktxn(self, message): pass
    def on_cfcheckpt(self, message): pass
    def on_cfheaders(self, message): pass
    def on_cfilter(self, message): pass
    def on_cmpctblock(self, message): pass
    def on_feefilter(self, message): pass
    def on_filteradd(self, message): pass
    def on_filterclear(self, message): pass
    def on_filterload(self, message): pass
    def on_getaddr(self, message): pass
    def on_getblocks(self, message): pass
    def on_getblocktxn(self, message): pass
    def on_getdata(self, message): pass
    def on_getheaders(self, message): pass
    def on_headers(self, message): pass
    def on_mempool(self, message): pass
    def on_merkleblock(self, message): pass
    def on_notfound(self, message): pass
    def on_pong(self, message): pass
    def on_sendaddrv2(self, message): pass
    def on_sendcmpct(self, message): pass
    def on_sendheaders(self, message): pass
    def on_sendtxrcncl(self, message): pass
    def on_tx(self, message): pass
    def on_wtxidrelay(self, message): pass

    def on_inv(self, message):
        """Request (via getdata) every announced object whose inv type is non-zero."""
        want = msg_getdata()
        for i in message.inv:
            if i.type != 0:
                want.inv.append(i)
        if want.inv:
            self.send_message(want)

    def on_ping(self, message):
        """Answer a ping with a pong carrying the same nonce."""
        self.send_message(msg_pong(message.nonce))

    def on_verack(self, message):
        pass

    def on_version(self, message):
        """Complete our side of the version handshake and record the peer's services/relay."""
        assert message.nVersion >= MIN_P2P_VERSION_SUPPORTED, "Version {} received. Test framework only supports versions greater than {}".format(message.nVersion, MIN_P2P_VERSION_SUPPORTED)
        # for inbound connections, reply to version with own version message
        # (could be due to v1 reconnect after a failed v2 handshake)
        if not self.p2p_connected_to_node:
            self.send_version()
            self.reconnect = False
        if message.nVersion >= 70016 and self.wtxidrelay:
            self.send_message(msg_wtxidrelay())
        if self.support_addrv2:
            self.send_message(msg_sendaddrv2())
        self.send_message(msg_verack())
        self.nServices = message.nServices
        self.relay = message.relay
        if self.p2p_connected_to_node:
            self.send_message(msg_getaddr())

    # Connection helper methods

    def wait_until(self, test_function_in, *, timeout=60, check_connected=True):
        """Wait (via wait_until_helper_internal, with p2p_lock) until `test_function_in` is truthy.

        With `check_connected` set, every evaluation first asserts that the
        connection is still open."""
        def test_function():
            if check_connected:
                assert self.is_connected
            return test_function_in()

        wait_until_helper_internal(test_function, timeout=timeout, lock=p2p_lock, timeout_factor=self.timeout_factor)

    def wait_for_connect(self, timeout=60):
        """Wait until the connection is open."""
        test_function = lambda: self.is_connected
        self.wait_until(test_function, timeout=timeout, check_connected=False)

    def wait_for_disconnect(self, timeout=60):
        """Wait until the connection is closed."""
        test_function = lambda: not self.is_connected
        self.wait_until(test_function, timeout=timeout, check_connected=False)

    def wait_for_reconnect(self, timeout=60):
        """Wait until connected without v2 state and a version message has been received."""
        def test_function():
            return self.is_connected and self.last_message.get('version') and not self.supports_v2_p2p
        self.wait_until(test_function, timeout=timeout, check_connected=False)

    # Message receiving helper methods

    def wait_for_tx(self, txid, timeout=60):
        """Wait until the most recent tx message hashes to `txid`."""
        def test_function():
            if not self.last_message.get('tx'):
                return False
            return self.last_message['tx'].tx.rehash() == txid

        self.wait_until(test_function, timeout=timeout)

    def wait_for_block(self, blockhash, timeout=60):
        """Wait until the most recent block message hashes to `blockhash`."""
        def test_function():
            return self.last_message.get("block") and self.last_message["block"].block.rehash() == blockhash

        self.wait_until(test_function, timeout=timeout)

    def wait_for_header(self, blockhash, timeout=60):
        """Wait until the first header of the most recent headers message matches `blockhash` (hex string)."""
        def test_function():
            last_headers = self.last_message.get('headers')
            # A headers message without any header cannot match; keep waiting rather
            # than failing with IndexError on headers[0].
            if not last_headers or not last_headers.headers:
                return False
            return last_headers.headers[0].rehash() == int(blockhash, 16)

        self.wait_until(test_function, timeout=timeout)

    def wait_for_merkleblock(self, blockhash, timeout=60):
        """Wait until the header of the most recent merkleblock message matches `blockhash` (hex string)."""
        def test_function():
            last_filtered_block = self.last_message.get('merkleblock')
            if not last_filtered_block:
                return False
            return last_filtered_block.merkleblock.header.rehash() == int(blockhash, 16)

        self.wait_until(test_function, timeout=timeout)

    def wait_for_getdata(self, hash_list, timeout=60):
        """Waits for a getdata message.

        The object hashes in the inventory vector must match the provided hash_list."""
        def test_function():
            last_data = self.last_message.get("getdata")
            if not last_data:
                return False
            return [x.hash for x in last_data.inv] == hash_list

        self.wait_until(test_function, timeout=timeout)

    def wait_for_getheaders(self, timeout=60):
        """Waits for a getheaders message.

        Receiving any getheaders message will satisfy the predicate. The last_message["getheaders"]
        value must be explicitly cleared before calling this method, or this will return
        immediately with success. TODO: change this method to take a hash value and only
        return true if the correct block header has been requested."""
        def test_function():
            return self.last_message.get("getheaders")

        self.wait_until(test_function, timeout=timeout)

    def wait_for_inv(self, expected_inv, timeout=60):
        """Waits for an INV message and checks that the first inv object in the message was as expected."""
        if len(expected_inv) > 1:
            raise NotImplementedError("wait_for_inv() will only verify the first inv object")

        def test_function():
            last_inv = self.last_message.get("inv")
            # An inv message without entries cannot match; keep waiting rather than
            # failing with IndexError on inv[0].
            if not last_inv or not last_inv.inv:
                return False
            first = last_inv.inv[0]
            return first.type == expected_inv[0].type and first.hash == expected_inv[0].hash

        self.wait_until(test_function, timeout=timeout)

    def wait_for_verack(self, timeout=60):
        """Wait until a verack message has been received."""
        def test_function():
            return "verack" in self.last_message

        self.wait_until(test_function, timeout=timeout)

    # Message sending helper functions

    def send_version(self):
        """Send the queued version message, if there is one, and clear it."""
        if self.on_connection_send_msg:
            self.send_message(self.on_connection_send_msg)
            self.on_connection_send_msg = None  # Never used again

    def send_and_ping(self, message, timeout=60):
        """Send `message`, then wait for a ping round-trip (see sync_with_ping)."""
        self.send_message(message)
        self.sync_with_ping(timeout=timeout)

    def sync_with_ping(self, timeout=60):
        """Ensure ProcessMessages and SendMessages is called on this connection"""
        # Sending two pings back-to-back, requires that the node calls
        # `ProcessMessage` twice, and thus ensures `SendMessages` must have
        # been called at least once
        self.send_message(msg_ping(nonce=0))
        self.send_message(msg_ping(nonce=self.ping_counter))

        def test_function():
            return self.last_message.get("pong") and self.last_message["pong"].nonce == self.ping_counter

        self.wait_until(test_function, timeout=timeout)
        # Only advance the counter once the matching pong arrived
        self.ping_counter += 1
701  
702  
# One lock for synchronizing all data access between the network event loop (see
# NetworkThread below) and the thread running the test logic. For simplicity,
# P2PConnection acquires this lock whenever delivering a message to a P2PInterface.
# This lock should be acquired in the thread running the test logic to synchronize
# access to any data shared with the P2PInterface or P2PConnection.
# It is a plain threading.Lock and therefore not reentrant: code that already
# holds it must not try to acquire it again.
p2p_lock = threading.Lock()
709  
710  
class NetworkThread(threading.Thread):
    """Thread that owns and runs the single asyncio event loop used for P2P traffic.

    The loop, the listening servers and the per-endpoint protocol objects are
    all stored as class attributes, so at most one instance may exist at a time."""

    # The shared event loop; None whenever no NetworkThread has been set up.
    network_event_loop = None

    def __init__(self):
        super().__init__(name="NetworkThread")
        # There is only one event loop and no more than one thread must be created
        assert not self.network_event_loop

        # Both dicts are keyed by (addr, port).
        NetworkThread.listeners = {}
        NetworkThread.protos = {}
        if platform.system() == 'Windows':
            asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        NetworkThread.network_event_loop = asyncio.new_event_loop()

    def run(self):
        """Start the network thread."""
        self.network_event_loop.run_forever()

    def close(self, timeout=10):
        """Close the connections and network event loop."""
        loop = self.network_event_loop
        # stop() has to be scheduled from inside the loop's own thread.
        loop.call_soon_threadsafe(loop.stop)

        def loop_stopped():
            return not loop.is_running()

        wait_until_helper_internal(loop_stopped, timeout=timeout)
        loop.close()
        self.join(timeout)
        # Safe to remove event loop.
        NetworkThread.network_event_loop = None

    @classmethod
    def listen(cls, p2p, callback, port=None, addr=None, idx=1):
        """ Ensure a listening server is running on the given port, and run the
        protocol specified by `p2p` on the next connection to it. Once ready
        for connections, call `callback`.

        When `port` is None it is derived from `idx` via p2p_port(); when
        `addr` is None the loopback address is used."""

        if port is None:
            assert 0 < idx <= MAX_NODES
            port = p2p_port(MAX_NODES - idx)
        addr = '127.0.0.1' if addr is None else addr

        def exception_handler(loop, context):
            if p2p.reconnect:
                # Loop exceptions are dropped while `p2p.reconnect` is set.
                return
            loop.default_exception_handler(context)

        event_loop = cls.network_event_loop
        event_loop.set_exception_handler(exception_handler)
        server_setup = cls.create_listen_server(addr, port, callback, p2p)
        event_loop.call_soon_threadsafe(event_loop.create_task, server_setup)

    @classmethod
    async def create_listen_server(cls, addr, port, callback, proto):
        """Register `proto` for (addr, port), creating the listening server on
        first use, then invoke `callback(addr, port)`."""
        endpoint = (addr, port)

        def peer_protocol():
            """Protocol factory handed to the listening server.

            It returns whatever protocol object is currently stored in
            cls.protos for this endpoint, which lets successive connections
            get different behaviors. Unless `proto.reconnect` is set, the
            stored entry is reset to None as it is handed out.
            NOTE: the server uses the factory built by the call that created
            the listener, so `proto` here is that call's `proto`."""
            handler = cls.protos.get(endpoint)
            # clear the entry only when reconnection doesn't need to happen/already happened
            if not proto.reconnect:
                cls.protos[endpoint] = None
            return handler

        if endpoint not in cls.listeners:
            # A listener for a given (addr, port) is only created once.
            # Different behaviors for different connections come from
            # registering a different `proto` before each connection.
            server = await cls.network_event_loop.create_server(peer_protocol, addr, port)
            logger.debug("Listening server on %s:%d should be started" % (addr, port))
            cls.listeners[endpoint] = server

        cls.protos[endpoint] = proto
        callback(addr, port)
785  
786  
class P2PDataStore(P2PInterface):
    """A P2P data store class.

    Keeps a block and transaction store and responds correctly to getdata and getheaders requests."""

    def __init__(self):
        super().__init__()
        # store of blocks. key is block hash, value is a CBlock object
        self.block_store = {}
        # hash of the block most recently added to block_store; treated as the tip
        self.last_block_hash = ''
        # store of txs. key is txid, value is a CTransaction object
        self.tx_store = {}
        # hash of every inv requested through getdata, in arrival order
        self.getdata_requests = []

    def on_getdata(self, message):
        """Check for the tx/block in our stores and if found, reply with a tx/block message.

        Every requested hash is recorded in getdata_requests, whether or not
        it can be served from the stores."""
        for inv in message.inv:
            self.getdata_requests.append(inv.hash)
            inv_type = inv.type & MSG_TYPE_MASK
            if inv_type == MSG_TX and inv.hash in self.tx_store:
                self.send_message(msg_tx(self.tx_store[inv.hash]))
            elif inv_type == MSG_BLOCK and inv.hash in self.block_store:
                self.send_message(msg_block(self.block_store[inv.hash]))
            else:
                logger.debug('getdata message type {} received.'.format(hex(inv.type)))

    def on_getheaders(self, message):
        """Search back through our block store for the locator, and reply with a headers message if found."""

        locator, hash_stop = message.locator, message.hashstop

        # Assume that the most recent block added is the tip
        if not self.block_store:
            return

        headers_list = [self.block_store[self.last_block_hash]]
        while headers_list[-1].sha256 not in locator.vHave:
            # Walk back through the block store, adding headers to headers_list
            # as we go.
            prev_block_hash = headers_list[-1].hashPrevBlock
            if prev_block_hash in self.block_store:
                prev_block_header = CBlockHeader(self.block_store[prev_block_hash])
                headers_list.append(prev_block_header)
                if prev_block_header.sha256 == hash_stop:
                    # if this is the hashstop header, stop here
                    break
            else:
                logger.debug('block hash {} not found in block store'.format(hex(prev_block_hash)))
                break

        # headers_list runs from the tip backwards. The reversed slice flips it
        # to oldest-first and keeps at most MAX_HEADERS_RESULTS entries, taken
        # from the old end of the walk.
        headers_list = headers_list[:-MAX_HEADERS_RESULTS - 1:-1]
        self.send_message(msg_headers(headers_list))

    def send_blocks_and_test(self, blocks, node, *, success=True, force_send=False, reject_reason=None, expect_disconnect=False, timeout=60, is_decoy=False):
        """Send blocks to test node and test whether the tip advances.

         - add all blocks to our block_store
         - send a headers message for the final block
         - the on_getheaders handler will ensure that any getheaders are responded to
         - if force_send is False: wait for getdata for each of the blocks. The on_getdata handler will
           ensure that any getdata messages are responded to. Otherwise send the full block unsolicited.
         - if is_decoy is True: force_send is implied and is_decoy is passed on to send_message()
         - if expect_disconnect is True: wait for the disconnect instead of syncing with a ping
         - if success is True: assert that the node's tip advances to the most recent block
         - if success is False: assert that the node's tip doesn't advance
         - if reject_reason is set: assert that the correct reject message is logged
         - timeout is applied to each individual wait, not to the call as a whole"""

        with p2p_lock:
            for block in blocks:
                self.block_store[block.sha256] = block
                self.last_block_hash = block.sha256

        reject_reason = [reject_reason] if reject_reason else []
        with node.assert_debug_log(expected_msgs=reject_reason):
            if is_decoy:  # since decoy messages are ignored by the recipient - no need to wait for response
                force_send = True
            if force_send:
                for b in blocks:
                    self.send_message(msg_block(block=b), is_decoy)
            else:
                self.send_message(msg_headers([CBlockHeader(block) for block in blocks]))
                self.wait_until(
                    lambda: blocks[-1].sha256 in self.getdata_requests,
                    timeout=timeout,
                    check_connected=success,
                )

            if expect_disconnect:
                self.wait_for_disconnect(timeout=timeout)
            else:
                self.sync_with_ping(timeout=timeout)

            if success:
                self.wait_until(lambda: node.getbestblockhash() == blocks[-1].hash, timeout=timeout)
            else:
                assert node.getbestblockhash() != blocks[-1].hash

    def send_txs_and_test(self, txs, node, *, success=True, expect_disconnect=False, reject_reason=None):
        """Send txs to test node and test whether they're accepted to the mempool.

         - add all txs to our tx_store
         - send tx messages for all txs
         - if success is True/False: assert that the txs are/are not accepted to the mempool
         - if expect_disconnect is True: Skip the sync with ping
         - if reject_reason is set: assert that the correct reject message is logged."""

        with p2p_lock:
            for tx in txs:
                self.tx_store[tx.sha256] = tx

        reject_reason = [reject_reason] if reject_reason else []
        with node.assert_debug_log(expected_msgs=reject_reason):
            for tx in txs:
                self.send_message(msg_tx(tx))

            if expect_disconnect:
                self.wait_for_disconnect()
            else:
                self.sync_with_ping()

            raw_mempool = node.getrawmempool()
            if success:
                # Check that all txs are now in the mempool
                for tx in txs:
                    assert tx.hash in raw_mempool, "{} not found in mempool".format(tx.hash)
            else:
                # Check that none of the txs are now in the mempool
                for tx in txs:
                    assert tx.hash not in raw_mempool, "{} tx found in mempool".format(tx.hash)
917  
class P2PTxInvStore(P2PInterface):
    """A P2PInterface which stores a count of how many times each txid has been announced."""
    def __init__(self):
        super().__init__()
        # key is the hash carried by a tx inv, value is how many times it was announced
        self.tx_invs_received = defaultdict(int)

    def on_inv(self, message):
        """Count every MSG_TX / MSG_WTX announcement contained in `message`."""
        super().on_inv(message) # Send getdata in response.
        # Store how many times invs have been received for each tx.
        for i in message.inv:
            if i.type in (MSG_TX, MSG_WTX):
                # save txid
                self.tx_invs_received[i.hash] += 1

    def get_invs(self):
        """Return a list of all tx hashes announced so far (read under p2p_lock)."""
        with p2p_lock:
            return list(self.tx_invs_received)

    def wait_for_broadcast(self, txns, timeout=60):
        """Waits for the txns (list of txids) to complete initial broadcast.
        The mempool should mark unbroadcast=False for these transactions.

        The txids are hex strings. The wait only finishes once the set of
        announced hashes is exactly the given set.
        """
        # Convert once up front instead of on every poll. This also keeps the
        # comparison correct if `txns` is a one-shot iterable.
        expected = {int(tx, 16) for tx in txns}
        # Wait until invs have been received (and getdatas sent) for each txid.
        self.wait_until(lambda: set(self.tx_invs_received) == expected, timeout=timeout)
        # Flush messages and wait for the getdatas to be processed
        self.sync_with_ping()