tcp.rs
1 // Copyright (c) 2025 ADnet Contributors 2 // This file is part of the adnet-core library. 3 // Derived from snarkOS, originally by Provable Inc. 4 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at: 8 9 // http://www.apache.org/licenses/LICENSE-2.0 10 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 use std::{ 18 collections::HashSet, 19 fmt, io, 20 net::{IpAddr, SocketAddr}, 21 ops::Deref, 22 sync::{ 23 Arc, 24 atomic::{AtomicUsize, Ordering::*}, 25 }, 26 time::{Duration, Instant}, 27 }; 28 29 #[cfg(feature = "locktick")] 30 use locktick::parking_lot::Mutex; 31 use once_cell::sync::OnceCell; 32 #[cfg(not(feature = "locktick"))] 33 use parking_lot::Mutex; 34 use tokio::{ 35 io::split, 36 net::{TcpListener, TcpStream}, 37 sync::oneshot, 38 task::JoinHandle, 39 time::timeout, 40 }; 41 use tracing::*; 42 43 use crate::{ 44 BannedPeers, Config, KnownPeers, Stats, 45 connections::{Connection, ConnectionSide, Connections}, 46 protocols::{Protocol, Protocols}, 47 }; 48 49 // A sequential numeric identifier assigned to `Tcp`s that were not provided with a name. 50 static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0); 51 52 /// The central object responsible for handling connections. 53 #[derive(Clone)] 54 pub struct Tcp(Arc<InnerTcp>); 55 56 impl Deref for Tcp { 57 type Target = Arc<InnerTcp>; 58 59 fn deref(&self) -> &Self::Target { 60 &self.0 61 } 62 } 63 64 /// Error types for the `Tcp::connect` function. 
#[derive(thiserror::Error, Debug)]
pub enum ConnectError {
    /// No connection slot is free: active plus pending connections already reach the
    /// configured maximum.
    #[error("already reached the maximum number of {limit} connections")]
    MaximumConnectionsReached {
        /// The configured `max_connections` value.
        limit: u16,
    },
    /// A connection attempt to the same address is already in progress.
    #[error("already connecting to node at {address:?}")]
    AlreadyConnecting {
        /// The address that is already being connected to.
        address: SocketAddr,
    },
    /// The node is already connected to the address.
    #[error("already connected to node at {address:?}")]
    AlreadyConnected {
        /// The address that is already connected.
        address: SocketAddr,
    },
    /// The target address was identified as the node's own listening address.
    #[error("attempt to self-connect (at address {address:?})")]
    SelfConnect {
        /// The address that was identified as the node's own.
        address: SocketAddr,
    },
    /// An I/O error occurred while opening the socket or setting up the connection.
    #[error("I/O error: {0}")]
    IoError(std::io::Error),
}

impl From<std::io::Error> for ConnectError {
    fn from(inner: std::io::Error) -> Self {
        Self::IoError(inner)
    }
}

#[doc(hidden)]
pub struct InnerTcp {
    /// The tracing span; every log line of this node is emitted under it.
    span: Span,
    /// The node's configuration; `name` is always `Some` after `Tcp::new`.
    config: Config,
    /// The node's listening address; only set once the listener has been enabled.
    listening_addr: OnceCell<SocketAddr>,
    /// Contains objects used by the protocols implemented by the node.
    pub(crate) protocols: Protocols,
    /// Addresses of connections that are being set up but have not been finalized yet.
    connecting: Mutex<HashSet<SocketAddr>>,
    /// Contains objects related to the node's active connections.
    connections: Connections,
    /// Collects statistics related to the node's peers.
    known_peers: KnownPeers,
    /// Contains the set of currently banned peers.
    banned_peers: BannedPeers,
    /// Collects statistics related to the node itself.
    stats: Stats,
    /// Handles of the node's tasks (including the listening task, once enabled).
    pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>,
}

impl Tcp {
    /// Creates a new [`Tcp`] using the given [`Config`].
    ///
    /// If `config.name` is `None`, the node is named with a process-wide sequential number.
    pub fn new(mut config: Config) -> Self {
        // If there is no pre-configured name, assign a sequential numeric identifier.
        let name = config
            .name
            .get_or_insert_with(|| SEQUENTIAL_NODE_ID.fetch_add(1, Relaxed).to_string());

        // Create a tracing span containing the node's name.
        let span = crate::helpers::create_span(name.as_str());

        // Initialize the Tcp stack.
        let tcp = Tcp(Arc::new(InnerTcp {
            span,
            config,
            listening_addr: Default::default(),
            protocols: Default::default(),
            connecting: Default::default(),
            connections: Default::default(),
            known_peers: Default::default(),
            banned_peers: Default::default(),
            stats: Stats::new(Instant::now()),
            tasks: Default::default(),
        }));

        debug!(parent: tcp.span(), "The node is ready");

        tcp
    }

    /// Returns the name assigned.
    #[inline]
    pub fn name(&self) -> &str {
        // The name may be left unset in `Config`, but `Tcp::new` always fills one in.
        self.config
            .name
            .as_deref()
            .expect("Tcp::new always assigns a name")
    }

    /// Returns a reference to the configuration.
    #[inline]
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Returns the listening address.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::AddrNotAvailable`] error if the listener has not been
    /// enabled, i.e. the node does not accept inbound connections.
    pub fn listening_addr(&self) -> io::Result<SocketAddr> {
        self.listening_addr
            .get()
            .copied()
            .ok_or_else(|| io::ErrorKind::AddrNotAvailable.into())
    }

    /// Checks whether the provided address is connected.
    pub fn is_connected(&self, addr: SocketAddr) -> bool {
        self.connections.is_connected(addr)
    }

    /// Checks if Tcp is currently setting up a connection with the provided address.
    pub fn is_connecting(&self, addr: SocketAddr) -> bool {
        self.connecting.lock().contains(&addr)
    }

    /// Returns the number of active connections.
    pub fn num_connected(&self) -> usize {
        self.connections.num_connected()
    }

    /// Returns the number of connections that are currently being set up.
    pub fn num_connecting(&self) -> usize {
        self.connecting.lock().len()
    }

    /// Returns a list containing addresses of active connections.
pub fn connected_addrs(&self) -> Vec<SocketAddr> {
        self.connections.addrs()
    }

    /// Returns a list containing addresses of pending connections.
    pub fn connecting_addrs(&self) -> Vec<SocketAddr> {
        self.connecting.lock().iter().copied().collect()
    }

    /// Returns a reference to the collection of statistics of known peers.
    #[inline]
    pub fn known_peers(&self) -> &KnownPeers {
        &self.known_peers
    }

    /// Returns a reference to the set of currently banned peers.
    #[inline]
    pub fn banned_peers(&self) -> &BannedPeers {
        &self.banned_peers
    }

    /// Returns a reference to the statistics.
    #[inline]
    pub fn stats(&self) -> &Stats {
        &self.stats
    }

    /// Returns the tracing [`Span`] associated with Tcp.
    #[inline]
    pub fn span(&self) -> &Span {
        &self.span
    }

    /// Gracefully shuts down the stack.
    ///
    /// Aborts the first registered task, disconnects from every connected peer, and finally
    /// aborts all remaining tasks.
    pub async fn shut_down(&self) {
        debug!(parent: self.span(), "Shutting down the TCP stack");

        // Retrieve all tasks; the lock guard is a temporary of this statement, so no lock is
        // held across the `.await`s below.
        let mut tasks = std::mem::take(&mut *self.tasks.lock()).into_iter();

        // Abort the listening task first.
        // NOTE(review): this assumes the listening task is the first entry in `tasks`, while
        // `enable_listener` *appends* it; if any task was pushed before the listener was
        // enabled (e.g. by protocol setup elsewhere in the crate), that task is aborted here
        // instead — confirm the registration order.
        if let Some(listening_task) = tasks.next() {
            listening_task.abort();
        }
        // Disconnect from all connected peers.
        for addr in self.connected_addrs() {
            self.disconnect(addr).await;
        }
        // Abort all remaining tasks.
        for handle in tasks {
            handle.abort();
        }
    }
}

impl Tcp {
    /// Connects to the provided `SocketAddr`.
    ///
    /// While the attempt is in progress `addr` is tracked as a pending connection (see
    /// [`Tcp::is_connecting`]); it is removed from the pending set whether the attempt
    /// succeeds or fails.
    ///
    /// # Errors
    ///
    /// - [`ConnectError::SelfConnect`] if `addr` is recognised as this node's own listener;
    /// - [`ConnectError::MaximumConnectionsReached`] if no connection slot is available;
    /// - [`ConnectError::AlreadyConnected`] / [`ConnectError::AlreadyConnecting`] if `addr`
    ///   is already connected or being connected to;
    /// - [`ConnectError::IoError`] if the socket cannot be created, bound or connected, if the
    ///   configured connection timeout elapses ([`io::ErrorKind::TimedOut`]), or if an enabled
    ///   protocol fails to set up the connection.
    pub async fn connect(&self, addr: SocketAddr) -> Result<(), ConnectError> {
        if let Ok(listening_addr) = self.listening_addr() {
            // TODO(nkls): maybe this first check can be dropped; though it might be best to keep just in case.
            if addr == listening_addr || self.is_self_connect(addr) {
                error!(parent: self.span(), "Attempted to self-connect ({addr})");
                return Err(ConnectError::SelfConnect { address: addr });
            }
        }

        if !self.can_add_connection() {
            error!(parent: self.span(), "Too many connections; refusing to connect to {addr}");
            return Err(ConnectError::MaximumConnectionsReached {
                limit: self.config.max_connections,
            });
        }

        if self.is_connected(addr) {
            warn!(parent: self.span(), "Already connected to {addr}");
            return Err(ConnectError::AlreadyConnected { address: addr });
        }

        if !self.connecting.lock().insert(addr) {
            warn!(parent: self.span(), "Already connecting to {addr}");
            return Err(ConnectError::AlreadyConnecting { address: addr });
        }

        // From here on `addr` is registered as pending: every failure path must remove it
        // from `self.connecting` again, otherwise it would permanently occupy a connection
        // slot and make every later attempt to `addr` fail with `AlreadyConnecting`.

        let timeout_duration = Duration::from_millis(self.config().connection_timeout_ms.into());

        // Bind the tcp socket to the configured listener ip if it's set.
        // Otherwise default to the system's default interface.
        let res = if let Some(listen_ip) = self.config().listener_ip {
            let socket = if listen_ip.is_ipv4() {
                tokio::net::TcpSocket::new_v4()
            } else {
                tokio::net::TcpSocket::new_v6()
            };
            let bound = socket.and_then(|sock| {
                sock.bind(SocketAddr::new(listen_ip, 0))?;
                Ok(sock)
            });
            // Socket creation/bind errors take the same `Ok(Err(_))` shape as a failed
            // connect, so they reach the cleanup below instead of returning early.
            match bound {
                Ok(sock) => timeout(timeout_duration, sock.connect(addr)).await,
                Err(err) => Ok(Err(err)),
            }
        } else {
            timeout(timeout_duration, TcpStream::connect(addr)).await
        };

        let stream = match res {
            Ok(Ok(stream)) => stream,
            Ok(Err(err)) => {
                self.connecting.lock().remove(&addr);
                return Err(err.into());
            }
            Err(err) => {
                self.connecting.lock().remove(&addr);
                error!(parent: self.span(), "connection timeout error: {}", err);
                return Err(ConnectError::IoError(io::ErrorKind::TimedOut.into()));
            }
        };

        let ret = self
            .adapt_stream(stream, addr, ConnectionSide::Initiator)
            .await;

        if let Err(ref e) = ret {
            self.connecting.lock().remove(&addr);
            self.known_peers().register_failure(addr.ip());
            error!(parent: self.span(), "Unable to initiate a connection with {addr}: {e}");
        }

        ret.map_err(ConnectError::from)
    }

    /// Disconnects from the provided `SocketAddr`.
    ///
    /// Returns `true` if we were connected to the given address and this call performed the
    /// disconnect. Returns `false` if we were not connected, or if another call is already
    /// disconnecting from that address.
    pub async fn disconnect(&self, addr: SocketAddr) -> bool {
        // claim the disconnect to avoid duplicate executions, or return early if already claimed
        if let Some(conn) = self.connections.0.read().get(&addr) {
            if conn.disconnecting.swap(true, Relaxed) {
                // valid connection, but someone else is already disconnecting it
                return false;
            }
        } else {
            // not connected
            return false;
        };

        // If enabled, let the Disconnect protocol run to completion before tearing down.
        if let Some(handler) = self.protocols.disconnect.get() {
            let (sender, receiver) = oneshot::channel();
            handler.trigger((addr, sender));
            let _ = receiver.await; // can't really fail
        }

        let conn = self.connections.remove(addr);

        if let Some(ref conn) = conn {
            debug!(parent: self.span(), "Disconnecting from {}", conn.addr());

            // Shut down the associated tasks of the peer.
            for task in conn.tasks.iter().rev() {
                task.abort();
            }

            debug!(parent: self.span(), "Disconnected from {}", conn.addr());
        } else {
            warn!(parent: self.span(), "Failed to disconnect, was not connected to {addr}");
        }

        conn.is_some()
    }
}

impl Tcp {
    /// Spawns a task that listens for incoming connections.
    ///
    /// Returns the address actually being listened on; its port is read back from the bound
    /// listener, so it is meaningful even when the OS picked the port.
    ///
    /// # Panics
    ///
    /// Panics if `Config::listener_ip` is not set, if the listener is enabled more than once,
    /// or if neither `desired_listening_port` nor `allow_random_port` is configured.
    pub async fn enable_listener(&self) -> io::Result<SocketAddr> {
        // Retrieve the listening IP address, which must be set.
        let listener_ip = self
            .config()
            .listener_ip
            .expect("Tcp::enable_listener was called, but Config::listener_ip is not set");

        // Initialize the TCP listener.
        let listener = self.create_listener(listener_ip).await?;

        // Discover the port, if it was unspecified.
        let port = listener.local_addr()?.port();

        // Set the listening IP address.
        let listening_addr = (listener_ip, port).into();
        self.listening_addr
            .set(listening_addr)
            .expect("The node's listener was started more than once");

        // Use a channel to know when the listening task is ready.
        let (tx, rx) = oneshot::channel();

        let tcp = self.clone();
        let listening_task = tokio::spawn(async move {
            trace!(parent: tcp.span(), "Spawned the listening task");
            // The receiver is only gone if the caller stopped awaiting `enable_listener`; the
            // listener must keep running regardless, so a failed send is not fatal.
            let _ = tx.send(());

            loop {
                // Await for a new connection.
                match listener.accept().await {
                    Ok((stream, addr)) => tcp.handle_connection(stream, addr),
                    Err(e) => error!(parent: tcp.span(), "Failed to accept a connection: {e}"),
                }
            }
        });
        self.tasks.lock().push(listening_task);
        let _ = rx.await;
        debug!(parent: self.span(), "Listening on {listening_addr}");

        Ok(listening_addr)
    }

    /// Creates an instance of `TcpListener` based on the node's configuration.
393 async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> { 394 debug!("Creating a TCP listener on {listener_ip}..."); 395 let listener = if let Some(port) = self.config().desired_listening_port { 396 // Construct the desired listening IP address. 397 let desired_listening_addr = SocketAddr::new(listener_ip, port); 398 // If a desired listening port is set, try to bind to it. 399 match TcpListener::bind(desired_listening_addr).await { 400 Ok(listener) => listener, 401 Err(e) => { 402 if self.config().allow_random_port { 403 warn!( 404 parent: self.span(), 405 "Trying any listening port, as the desired port is unavailable: {e}" 406 ); 407 let random_available_addr = SocketAddr::new(listener_ip, 0); 408 TcpListener::bind(random_available_addr).await? 409 } else { 410 error!(parent: self.span(), "The desired listening port is unavailable: {e}"); 411 return Err(e); 412 } 413 } 414 } 415 } else if self.config().allow_random_port { 416 let random_available_addr = SocketAddr::new(listener_ip, 0); 417 TcpListener::bind(random_available_addr).await? 418 } else { 419 panic!( 420 "As 'listener_ip' is set, either 'desired_listening_port' or 'allow_random_port' must be set" 421 ); 422 }; 423 424 Ok(listener) 425 } 426 427 /// Handles a new inbound connection. 
428 fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) { 429 debug!(parent: self.span(), "Received a connection from {addr}"); 430 431 if !self.can_add_connection() || self.is_self_connect(addr) { 432 debug!(parent: self.span(), "Rejecting the connection from {addr}"); 433 return; 434 } 435 436 self.connecting.lock().insert(addr); 437 438 let tcp = self.clone(); 439 tokio::spawn(async move { 440 if let Err(e) = tcp 441 .adapt_stream(stream, addr, ConnectionSide::Responder) 442 .await 443 { 444 tcp.connecting.lock().remove(&addr); 445 tcp.known_peers().register_failure(addr.ip()); 446 error!(parent: tcp.span(), "Failed to connect with {addr}: {e}"); 447 } 448 }); 449 } 450 451 /// Checks if the given IP address is the same as the listening address of this `Tcp`. 452 fn is_self_connect(&self, addr: SocketAddr) -> bool { 453 // SAFETY: if we're opening connections, this should never fail. 454 let listening_addr = self.listening_addr().unwrap(); 455 456 match listening_addr.ip().is_loopback() { 457 // If localhost, check the ports, this only works on outbound connections, since we 458 // don't know the ephemeral port a peer might be using if they initiate the connection. 459 true => listening_addr.port() == addr.port(), 460 // If it's not localhost, matching IPs indicate a self-connect in both directions. 461 false => listening_addr.ip() == addr.ip(), 462 } 463 } 464 465 /// Checks whether the `Tcp` can handle an additional connection. 466 fn can_add_connection(&self) -> bool { 467 // Retrieve the number of connected peers. 468 let num_connected = self.num_connected(); 469 // Retrieve the maximum number of connected peers. 
470 let limit = self.config.max_connections as usize; 471 472 if num_connected >= limit { 473 warn!(parent: self.span(), "Maximum number of active connections ({limit}) reached"); 474 false 475 } else if num_connected + self.num_connecting() >= limit { 476 warn!(parent: self.span(), "Maximum number of active & pending connections ({limit}) reached"); 477 false 478 } else { 479 true 480 } 481 } 482 483 /// Prepares the freshly acquired connection to handle the protocols the Tcp implements. 484 async fn adapt_stream( 485 &self, 486 stream: TcpStream, 487 peer_addr: SocketAddr, 488 own_side: ConnectionSide, 489 ) -> io::Result<()> { 490 self.known_peers.add(peer_addr.ip()); 491 492 // Register the port seen by the peer. 493 if own_side == ConnectionSide::Initiator { 494 if let Ok(addr) = stream.local_addr() { 495 debug!( 496 parent: self.span(), "establishing connection with {}; the peer is connected on port {}", 497 peer_addr, addr.port() 498 ); 499 } else { 500 warn!(parent: self.span(), "couldn't determine the peer's port"); 501 } 502 } 503 504 let connection = Connection::new(peer_addr, stream, !own_side); 505 506 // Enact the enabled protocols. 507 let mut connection = self.enable_protocols(connection).await?; 508 509 // if Reading is enabled, we'll notify the related task when the connection is fully ready. 510 let conn_ready_tx = connection.readiness_notifier.take(); 511 512 self.connections.add(connection); 513 self.connecting.lock().remove(&peer_addr); 514 515 // Send the aforementioned notification so that reading from the socket can commence. 516 if let Some(tx) = conn_ready_tx { 517 let _ = tx.send(()); 518 } 519 520 // If enabled, enact OnConnect. 521 if let Some(handler) = self.protocols.on_connect.get() { 522 let (sender, receiver) = oneshot::channel(); 523 handler.trigger((peer_addr, sender)); 524 let _ = receiver.await; // can't really fail 525 } 526 527 Ok(()) 528 } 529 530 /// Enacts the enabled protocols on the provided connection. 
531 async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> { 532 /// A helper macro to enable a protocol on a connection. 533 macro_rules! enable_protocol { 534 ($handler_type: ident, $node:expr, $conn: expr) => { 535 if let Some(handler) = $node.protocols.$handler_type.get() { 536 let (conn_returner, conn_retriever) = oneshot::channel(); 537 538 handler.trigger(($conn, conn_returner)); 539 540 match conn_retriever.await { 541 Ok(Ok(conn)) => conn, 542 Err(_) => return Err(io::ErrorKind::BrokenPipe.into()), 543 Ok(e) => return e, 544 } 545 } else { 546 $conn 547 } 548 }; 549 } 550 551 let mut conn = enable_protocol!(handshake, self, conn); 552 553 // Split the stream after the handshake (if not done before). 554 if let Some(stream) = conn.stream.take() { 555 let (reader, writer) = split(stream); 556 conn.reader = Some(Box::new(reader)); 557 conn.writer = Some(Box::new(writer)); 558 } 559 560 let conn = enable_protocol!(reading, self, conn); 561 let conn = enable_protocol!(writing, self, conn); 562 563 Ok(conn) 564 } 565 } 566 567 impl fmt::Debug for Tcp { 568 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 569 write!(f, "The TCP stack config: {:?}", self.config) 570 } 571 } 572 573 #[cfg(test)] 574 mod tests { 575 use super::*; 576 577 use std::{ 578 net::{IpAddr, Ipv4Addr}, 579 str::FromStr, 580 }; 581 582 #[tokio::test] 583 async fn test_new() { 584 let tcp = Tcp::new(Config { 585 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 586 max_connections: 200, 587 ..Default::default() 588 }); 589 590 assert_eq!(tcp.config.max_connections, 200); 591 assert_eq!( 592 tcp.config.listener_ip, 593 Some(IpAddr::V4(Ipv4Addr::LOCALHOST)) 594 ); 595 assert_eq!( 596 tcp.enable_listener().await.unwrap().ip(), 597 IpAddr::V4(Ipv4Addr::LOCALHOST) 598 ); 599 600 assert_eq!(tcp.num_connected(), 0); 601 assert_eq!(tcp.num_connecting(), 0); 602 } 603 604 #[tokio::test] 605 async fn test_connect() { 606 let tcp = Tcp::new(Config::default()); 607 let 
node_ip = tcp.enable_listener().await.unwrap(); 608 609 // Ensure self-connecting is not possible. 610 let result = tcp.connect(node_ip).await; 611 assert!(matches!(result, Err(ConnectError::SelfConnect { .. }))); 612 613 assert_eq!(tcp.num_connected(), 0); 614 assert_eq!(tcp.num_connecting(), 0); 615 assert!(!tcp.is_connected(node_ip)); 616 assert!(!tcp.is_connecting(node_ip)); 617 618 // Initialize the peer. 619 let peer = Tcp::new(Config { 620 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 621 desired_listening_port: Some(0), 622 max_connections: 1, 623 ..Default::default() 624 }); 625 let peer_ip = peer.enable_listener().await.unwrap(); 626 627 // Connect to the peer. 628 tcp.connect(peer_ip).await.unwrap(); 629 assert_eq!(tcp.num_connected(), 1); 630 assert_eq!(tcp.num_connecting(), 0); 631 assert!(tcp.is_connected(peer_ip)); 632 assert!(!tcp.is_connecting(peer_ip)); 633 } 634 635 #[tokio::test] 636 async fn test_disconnect() { 637 let tcp = Tcp::new(Config::default()); 638 let _node_ip = tcp.enable_listener().await.unwrap(); 639 640 // Initialize the peer. 641 let peer = Tcp::new(Config { 642 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 643 desired_listening_port: Some(0), 644 max_connections: 1, 645 ..Default::default() 646 }); 647 let peer_ip = peer.enable_listener().await.unwrap(); 648 649 // Connect to the peer. 650 tcp.connect(peer_ip).await.unwrap(); 651 assert_eq!(tcp.num_connected(), 1); 652 assert_eq!(tcp.num_connecting(), 0); 653 assert!(tcp.is_connected(peer_ip)); 654 assert!(!tcp.is_connecting(peer_ip)); 655 656 // Disconnect from the peer. 657 let has_disconnected = tcp.disconnect(peer_ip).await; 658 assert!(has_disconnected); 659 assert_eq!(tcp.num_connected(), 0); 660 assert_eq!(tcp.num_connecting(), 0); 661 assert!(!tcp.is_connected(peer_ip)); 662 assert!(!tcp.is_connecting(peer_ip)); 663 664 // Ensure disconnecting from the peer a second time is okay. 
665 let has_disconnected = tcp.disconnect(peer_ip).await; 666 assert!(!has_disconnected); 667 assert_eq!(tcp.num_connected(), 0); 668 assert_eq!(tcp.num_connecting(), 0); 669 assert!(!tcp.is_connected(peer_ip)); 670 assert!(!tcp.is_connecting(peer_ip)); 671 } 672 673 #[tokio::test] 674 async fn test_can_add_connection() { 675 let tcp = Tcp::new(Config { 676 max_connections: 1, 677 ..Default::default() 678 }); 679 680 // Initialize the peer. 681 let peer = Tcp::new(Config { 682 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 683 desired_listening_port: Some(0), 684 max_connections: 1, 685 ..Default::default() 686 }); 687 let peer_ip = peer.enable_listener().await.unwrap(); 688 689 assert!(tcp.can_add_connection()); 690 691 // Simulate an active connection. 692 let stream = TcpStream::connect(peer_ip).await.unwrap(); 693 tcp.connections 694 .add(Connection::new(peer_ip, stream, ConnectionSide::Initiator)); 695 assert!(!tcp.can_add_connection()); 696 697 // Ensure that we cannot invoke connect() successfully in this case. 698 // Use a non-local IP, to ensure it is never qual to peer IP. 699 let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap(); 700 let result = tcp.connect(another_ip).await; 701 assert!(matches!( 702 result, 703 Err(ConnectError::MaximumConnectionsReached { .. }) 704 )); 705 706 // Remove the active connection. 707 tcp.connections.remove(peer_ip); 708 assert!(tcp.can_add_connection()); 709 710 // Simulate a pending connection. 711 tcp.connecting.lock().insert(peer_ip); 712 assert!(!tcp.can_add_connection()); 713 714 // Ensure that we cannot invoke connect() successfully in this case either. 715 let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap(); 716 let result = tcp.connect(another_ip).await; 717 assert!(matches!( 718 result, 719 Err(ConnectError::MaximumConnectionsReached { .. }) 720 )); 721 722 // Remove the pending connection. 
723 tcp.connecting.lock().remove(&peer_ip); 724 assert!(tcp.can_add_connection()); 725 726 // Simulate an active and a pending connection (this case should never occur). 727 let stream = TcpStream::connect(peer_ip).await.unwrap(); 728 tcp.connections 729 .add(Connection::new(peer_ip, stream, ConnectionSide::Responder)); 730 tcp.connecting.lock().insert(peer_ip); 731 assert!(!tcp.can_add_connection()); 732 733 // Remove the active and pending connection. 734 tcp.connections.remove(peer_ip); 735 tcp.connecting.lock().remove(&peer_ip); 736 assert!(tcp.can_add_connection()); 737 } 738 739 #[tokio::test] 740 async fn test_handle_connection() { 741 let tcp = Tcp::new(Config { 742 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 743 max_connections: 1, 744 ..Default::default() 745 }); 746 747 // Initialize peer 1. 748 let peer1 = Tcp::new(Config { 749 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 750 desired_listening_port: Some(0), 751 max_connections: 1, 752 ..Default::default() 753 }); 754 let peer1_ip = peer1.enable_listener().await.unwrap(); 755 756 // Simulate an active connection. 757 let stream = TcpStream::connect(peer1_ip).await.unwrap(); 758 tcp.connections 759 .add(Connection::new(peer1_ip, stream, ConnectionSide::Responder)); 760 assert!(!tcp.can_add_connection()); 761 assert_eq!(tcp.num_connected(), 1); 762 assert_eq!(tcp.num_connecting(), 0); 763 assert!(tcp.is_connected(peer1_ip)); 764 assert!(!tcp.is_connecting(peer1_ip)); 765 766 // Initialize peer 2. 767 let peer2 = Tcp::new(Config { 768 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 769 desired_listening_port: Some(0), 770 max_connections: 1, 771 ..Default::default() 772 }); 773 let peer2_ip = peer2.enable_listener().await.unwrap(); 774 775 // Handle the connection. 
776 let stream = TcpStream::connect(peer2_ip).await.unwrap(); 777 tcp.handle_connection(stream, peer2_ip); 778 assert!(!tcp.can_add_connection()); 779 assert_eq!(tcp.num_connected(), 1); 780 assert_eq!(tcp.num_connecting(), 0); 781 assert!(tcp.is_connected(peer1_ip)); 782 assert!(!tcp.is_connected(peer2_ip)); 783 assert!(!tcp.is_connecting(peer1_ip)); 784 assert!(!tcp.is_connecting(peer2_ip)); 785 } 786 787 #[tokio::test] 788 async fn test_adapt_stream() { 789 let tcp = Tcp::new(Config { 790 max_connections: 1, 791 ..Default::default() 792 }); 793 794 // Initialize the peer. 795 let peer = Tcp::new(Config { 796 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 797 desired_listening_port: Some(0), 798 max_connections: 1, 799 ..Default::default() 800 }); 801 let peer_ip = peer.enable_listener().await.unwrap(); 802 803 // Simulate a pending connection. 804 tcp.connecting.lock().insert(peer_ip); 805 assert_eq!(tcp.num_connected(), 0); 806 assert_eq!(tcp.num_connecting(), 1); 807 assert!(!tcp.is_connected(peer_ip)); 808 assert!(tcp.is_connecting(peer_ip)); 809 810 // Simulate a new connection. 811 let stream = TcpStream::connect(peer_ip).await.unwrap(); 812 tcp.adapt_stream(stream, peer_ip, ConnectionSide::Responder) 813 .await 814 .unwrap(); 815 assert_eq!(tcp.num_connected(), 1); 816 assert_eq!(tcp.num_connecting(), 0); 817 assert!(tcp.is_connected(peer_ip)); 818 assert!(!tcp.is_connecting(peer_ip)); 819 } 820 }