// tcp.rs
1 // Copyright (c) 2025 ADnet Contributors 2 // This file is part of the AlphaOS library. 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at: 7 8 // http://www.apache.org/licenses/LICENSE-2.0 9 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 use std::{ 17 collections::HashSet, 18 fmt, 19 io, 20 net::{IpAddr, SocketAddr}, 21 ops::Deref, 22 sync::{ 23 Arc, 24 atomic::{AtomicUsize, Ordering::*}, 25 }, 26 time::{Duration, Instant}, 27 }; 28 29 #[cfg(feature = "locktick")] 30 use locktick::parking_lot::Mutex; 31 use once_cell::sync::OnceCell; 32 #[cfg(not(feature = "locktick"))] 33 use parking_lot::Mutex; 34 use tokio::{ 35 io::split, 36 net::{TcpListener, TcpStream}, 37 sync::oneshot, 38 task::JoinHandle, 39 time::timeout, 40 }; 41 use tracing::*; 42 43 use crate::{ 44 BannedPeers, 45 Config, 46 KnownPeers, 47 Stats, 48 connections::{Connection, ConnectionSide, Connections}, 49 protocols::{Protocol, Protocols}, 50 }; 51 52 // A sequential numeric identifier assigned to `Tcp`s that were not provided with a name. 53 static SEQUENTIAL_NODE_ID: AtomicUsize = AtomicUsize::new(0); 54 55 /// The central object responsible for handling connections. 56 #[derive(Clone)] 57 pub struct Tcp(Arc<InnerTcp>); 58 59 impl Deref for Tcp { 60 type Target = Arc<InnerTcp>; 61 62 fn deref(&self) -> &Self::Target { 63 &self.0 64 } 65 } 66 67 /// Error types for the `Tcp::connect` function. 
68 #[allow(missing_docs)] 69 #[derive(thiserror::Error, Debug)] 70 pub enum ConnectError { 71 #[error("already reached the maximum number of {limit} connections")] 72 MaximumConnectionsReached { limit: u16 }, 73 #[error("already connecting to node at {address:?}")] 74 AlreadyConnecting { address: SocketAddr }, 75 #[error("already connected to node at {address:?}")] 76 AlreadyConnected { address: SocketAddr }, 77 #[error("attempt to self-connect (at address {address:?}")] 78 SelfConnect { address: SocketAddr }, 79 #[error("I/O error: {0}")] 80 IoError(std::io::Error), 81 } 82 83 impl From<std::io::Error> for ConnectError { 84 fn from(inner: std::io::Error) -> Self { 85 Self::IoError(inner) 86 } 87 } 88 89 #[doc(hidden)] 90 pub struct InnerTcp { 91 /// The tracing span. 92 span: Span, 93 /// The node's configuration. 94 config: Config, 95 /// The node's listening address. 96 listening_addr: OnceCell<SocketAddr>, 97 /// Contains objects used by the protocols implemented by the node. 98 pub(crate) protocols: Protocols, 99 /// A set of connections that have not been finalized yet. 100 connecting: Mutex<HashSet<SocketAddr>>, 101 /// Contains objects related to the node's active connections. 102 connections: Connections, 103 /// Collects statistics related to the node's peers. 104 known_peers: KnownPeers, 105 /// Contains the set of currently banned peers. 106 banned_peers: BannedPeers, 107 /// Collects statistics related to the node itself. 108 stats: Stats, 109 /// The node's tasks. 110 pub(crate) tasks: Mutex<Vec<JoinHandle<()>>>, 111 } 112 113 impl Tcp { 114 /// Creates a new [`Tcp`] using the given [`Config`]. 115 pub fn new(mut config: Config) -> Self { 116 // If there is no pre-configured name, assign a sequential numeric identifier. 117 if config.name.is_none() { 118 config.name = Some(SEQUENTIAL_NODE_ID.fetch_add(1, Relaxed).to_string()); 119 } 120 121 // Create a tracing span containing the node's name. 
122 let span = crate::helpers::create_span(config.name.as_deref().unwrap()); 123 124 // Initialize the Tcp stack. 125 let tcp = Tcp(Arc::new(InnerTcp { 126 span, 127 config, 128 listening_addr: Default::default(), 129 protocols: Default::default(), 130 connecting: Default::default(), 131 connections: Default::default(), 132 known_peers: Default::default(), 133 banned_peers: Default::default(), 134 stats: Stats::new(Instant::now()), 135 tasks: Default::default(), 136 })); 137 138 debug!(parent: tcp.span(), "The node is ready"); 139 140 tcp 141 } 142 143 /// Returns the name assigned. 144 #[inline] 145 pub fn name(&self) -> &str { 146 // safe; can be set as None in Config, but receives a default value on Tcp creation 147 self.config.name.as_deref().unwrap() 148 } 149 150 /// Returns a reference to the configuration. 151 #[inline] 152 pub fn config(&self) -> &Config { 153 &self.config 154 } 155 156 /// Returns the listening address; returns an error if Tcp was not configured 157 /// to listen for inbound connections. 158 pub fn listening_addr(&self) -> io::Result<SocketAddr> { 159 self.listening_addr.get().copied().ok_or_else(|| io::ErrorKind::AddrNotAvailable.into()) 160 } 161 162 /// Checks whether the provided address is connected. 163 pub fn is_connected(&self, addr: SocketAddr) -> bool { 164 self.connections.is_connected(addr) 165 } 166 167 /// Checks if Tcp is currently setting up a connection with the provided address. 168 pub fn is_connecting(&self, addr: SocketAddr) -> bool { 169 self.connecting.lock().contains(&addr) 170 } 171 172 /// Returns the number of active connections. 173 pub fn num_connected(&self) -> usize { 174 self.connections.num_connected() 175 } 176 177 /// Returns the number of connections that are currently being set up. 178 pub fn num_connecting(&self) -> usize { 179 self.connecting.lock().len() 180 } 181 182 /// Returns a list containing addresses of active connections. 
183 pub fn connected_addrs(&self) -> Vec<SocketAddr> { 184 self.connections.addrs() 185 } 186 187 /// Returns a list containing addresses of pending connections. 188 pub fn connecting_addrs(&self) -> Vec<SocketAddr> { 189 self.connecting.lock().iter().copied().collect() 190 } 191 192 /// Returns a reference to the collection of statistics of known peers. 193 #[inline] 194 pub fn known_peers(&self) -> &KnownPeers { 195 &self.known_peers 196 } 197 198 /// Returns a reference to the set of currently banned peers. 199 #[inline] 200 pub fn banned_peers(&self) -> &BannedPeers { 201 &self.banned_peers 202 } 203 204 /// Returns a reference to the statistics. 205 #[inline] 206 pub fn stats(&self) -> &Stats { 207 &self.stats 208 } 209 210 /// Returns the tracing [`Span`] associated with Tcp. 211 #[inline] 212 pub fn span(&self) -> &Span { 213 &self.span 214 } 215 216 /// Gracefully shuts down the stack. 217 pub async fn shut_down(&self) { 218 debug!(parent: self.span(), "Shutting down the TCP stack"); 219 220 // Retrieve all tasks. 221 let mut tasks = std::mem::take(&mut *self.tasks.lock()).into_iter(); 222 223 // Abort the listening task first. 224 if let Some(listening_task) = tasks.next() { 225 listening_task.abort(); // abort the listening task first 226 } 227 // Disconnect from all connected peers. 228 for addr in self.connected_addrs() { 229 self.disconnect(addr).await; 230 } 231 // Abort all remaining tasks. 232 for handle in tasks { 233 handle.abort(); 234 } 235 } 236 } 237 238 impl Tcp { 239 /// Connects to the provided `SocketAddr`. 240 pub async fn connect(&self, addr: SocketAddr) -> Result<(), ConnectError> { 241 if let Ok(listening_addr) = self.listening_addr() { 242 // TODO(nkls): maybe this first check can be dropped; though it might be best to keep just in case. 
243 if addr == listening_addr || self.is_self_connect(addr) { 244 error!(parent: self.span(), "Attempted to self-connect ({addr})"); 245 return Err(ConnectError::SelfConnect { address: addr }); 246 } 247 } 248 249 if !self.can_add_connection() { 250 error!(parent: self.span(), "Too many connections; refusing to connect to {addr}"); 251 return Err(ConnectError::MaximumConnectionsReached { limit: self.config.max_connections }); 252 } 253 254 if self.is_connected(addr) { 255 warn!(parent: self.span(), "Already connected to {addr}"); 256 return Err(ConnectError::AlreadyConnected { address: addr }); 257 } 258 259 if !self.connecting.lock().insert(addr) { 260 warn!(parent: self.span(), "Already connecting to {addr}"); 261 return Err(ConnectError::AlreadyConnecting { address: addr }); 262 } 263 264 let timeout_duration = Duration::from_millis(self.config().connection_timeout_ms.into()); 265 266 // Bind the tcp socket to the configured listener ip if it's set. 267 // Otherwise default to the system's default interface. 268 let res = if let Some(listen_ip) = self.config().listener_ip { 269 let sock = 270 if listen_ip.is_ipv4() { tokio::net::TcpSocket::new_v4()? } else { tokio::net::TcpSocket::new_v6()? 
}; 271 sock.bind(SocketAddr::new(listen_ip, 0))?; 272 timeout(timeout_duration, sock.connect(addr)).await 273 } else { 274 timeout(timeout_duration, TcpStream::connect(addr)).await 275 }; 276 277 let stream = match res { 278 Ok(Ok(stream)) => Ok(stream), 279 Ok(err) => { 280 self.connecting.lock().remove(&addr); 281 err 282 } 283 Err(err) => { 284 self.connecting.lock().remove(&addr); 285 error!("connection timeout error: {}", err); 286 Err(io::ErrorKind::TimedOut.into()) 287 } 288 }?; 289 290 let ret = self.adapt_stream(stream, addr, ConnectionSide::Initiator).await; 291 292 if let Err(ref e) = ret { 293 self.connecting.lock().remove(&addr); 294 self.known_peers().register_failure(addr.ip()); 295 error!(parent: self.span(), "Unable to initiate a connection with {addr}: {e}"); 296 } 297 298 ret.map_err(|err| err.into()) 299 } 300 301 /// Disconnects from the provided `SocketAddr`. 302 /// 303 /// Returns true if the we were connected to the given address. 304 pub async fn disconnect(&self, addr: SocketAddr) -> bool { 305 // claim the disconnect to avoid duplicate executions, or return early if already claimed 306 if let Some(conn) = self.connections.0.read().get(&addr) { 307 if conn.disconnecting.swap(true, Relaxed) { 308 // valid connection, but someone else is already disconnecting it 309 return false; 310 } 311 } else { 312 // not connected 313 return false; 314 }; 315 316 if let Some(handler) = self.protocols.disconnect.get() { 317 let (sender, receiver) = oneshot::channel(); 318 handler.trigger((addr, sender)); 319 let _ = receiver.await; // can't really fail 320 } 321 322 let conn = self.connections.remove(addr); 323 324 if let Some(ref conn) = conn { 325 debug!(parent: self.span(), "Disconnecting from {}", conn.addr()); 326 327 // Shut down the associated tasks of the peer. 
328 for task in conn.tasks.iter().rev() { 329 task.abort(); 330 } 331 332 debug!(parent: self.span(), "Disconnected from {}", conn.addr()); 333 } else { 334 warn!(parent: self.span(), "Failed to disconnect, was not connected to {addr}"); 335 } 336 337 conn.is_some() 338 } 339 } 340 341 impl Tcp { 342 /// Spawns a task that listens for incoming connections. 343 pub async fn enable_listener(&self) -> io::Result<SocketAddr> { 344 // Retrieve the listening IP address, which must be set. 345 let listener_ip = 346 self.config().listener_ip.expect("Tcp::enable_listener was called, but Config::listener_ip is not set"); 347 348 // Initialize the TCP listener. 349 let listener = self.create_listener(listener_ip).await?; 350 351 // Discover the port, if it was unspecified. 352 let port = listener.local_addr()?.port(); 353 354 // Set the listening IP address. 355 let listening_addr = (listener_ip, port).into(); 356 self.listening_addr.set(listening_addr).expect("The node's listener was started more than once"); 357 358 // Use a channel to know when the listening task is ready. 359 let (tx, rx) = oneshot::channel(); 360 361 let tcp = self.clone(); 362 let listening_task = tokio::spawn(async move { 363 trace!(parent: tcp.span(), "Spawned the listening task"); 364 tx.send(()).unwrap(); // safe; the channel was just opened 365 366 loop { 367 // Await for a new connection. 368 match listener.accept().await { 369 Ok((stream, addr)) => tcp.handle_connection(stream, addr), 370 Err(e) => error!(parent: tcp.span(), "Failed to accept a connection: {e}"), 371 } 372 } 373 }); 374 self.tasks.lock().push(listening_task); 375 let _ = rx.await; 376 debug!(parent: self.span(), "Listening on {listening_addr}"); 377 378 Ok(listening_addr) 379 } 380 381 /// Creates an instance of `TcpListener` based on the node's configuration. 
382 async fn create_listener(&self, listener_ip: IpAddr) -> io::Result<TcpListener> { 383 debug!("Creating a TCP listener on {listener_ip}..."); 384 let listener = if let Some(port) = self.config().desired_listening_port { 385 // Construct the desired listening IP address. 386 let desired_listening_addr = SocketAddr::new(listener_ip, port); 387 // If a desired listening port is set, try to bind to it. 388 match TcpListener::bind(desired_listening_addr).await { 389 Ok(listener) => listener, 390 Err(e) => { 391 if self.config().allow_random_port { 392 warn!( 393 parent: self.span(), 394 "Trying any listening port, as the desired port is unavailable: {e}" 395 ); 396 let random_available_addr = SocketAddr::new(listener_ip, 0); 397 TcpListener::bind(random_available_addr).await? 398 } else { 399 error!(parent: self.span(), "The desired listening port is unavailable: {e}"); 400 return Err(e); 401 } 402 } 403 } 404 } else if self.config().allow_random_port { 405 let random_available_addr = SocketAddr::new(listener_ip, 0); 406 TcpListener::bind(random_available_addr).await? 407 } else { 408 panic!("As 'listener_ip' is set, either 'desired_listening_port' or 'allow_random_port' must be set"); 409 }; 410 411 Ok(listener) 412 } 413 414 /// Handles a new inbound connection. 
415 fn handle_connection(&self, stream: TcpStream, addr: SocketAddr) { 416 debug!(parent: self.span(), "Received a connection from {addr}"); 417 418 if !self.can_add_connection() || self.is_self_connect(addr) { 419 debug!(parent: self.span(), "Rejecting the connection from {addr}"); 420 return; 421 } 422 423 self.connecting.lock().insert(addr); 424 425 let tcp = self.clone(); 426 tokio::spawn(async move { 427 if let Err(e) = tcp.adapt_stream(stream, addr, ConnectionSide::Responder).await { 428 tcp.connecting.lock().remove(&addr); 429 tcp.known_peers().register_failure(addr.ip()); 430 error!(parent: tcp.span(), "Failed to connect with {addr}: {e}"); 431 } 432 }); 433 } 434 435 /// Checks if the given IP address is the same as the listening address of this `Tcp`. 436 fn is_self_connect(&self, addr: SocketAddr) -> bool { 437 // SAFETY: if we're opening connections, this should never fail. 438 let listening_addr = self.listening_addr().unwrap(); 439 440 match listening_addr.ip().is_loopback() { 441 // If localhost, check the ports, this only works on outbound connections, since we 442 // don't know the ephemeral port a peer might be using if they initiate the connection. 443 true => listening_addr.port() == addr.port(), 444 // If it's not localhost, matching IPs indicate a self-connect in both directions. 445 false => listening_addr.ip() == addr.ip(), 446 } 447 } 448 449 /// Checks whether the `Tcp` can handle an additional connection. 450 fn can_add_connection(&self) -> bool { 451 // Retrieve the number of connected peers. 452 let num_connected = self.num_connected(); 453 // Retrieve the maximum number of connected peers. 
454 let limit = self.config.max_connections as usize; 455 456 if num_connected >= limit { 457 warn!(parent: self.span(), "Maximum number of active connections ({limit}) reached"); 458 false 459 } else if num_connected + self.num_connecting() >= limit { 460 warn!(parent: self.span(), "Maximum number of active & pending connections ({limit}) reached"); 461 false 462 } else { 463 true 464 } 465 } 466 467 /// Prepares the freshly acquired connection to handle the protocols the Tcp implements. 468 async fn adapt_stream(&self, stream: TcpStream, peer_addr: SocketAddr, own_side: ConnectionSide) -> io::Result<()> { 469 self.known_peers.add(peer_addr.ip()); 470 471 // Register the port seen by the peer. 472 if own_side == ConnectionSide::Initiator { 473 if let Ok(addr) = stream.local_addr() { 474 debug!( 475 parent: self.span(), "establishing connection with {}; the peer is connected on port {}", 476 peer_addr, addr.port() 477 ); 478 } else { 479 warn!(parent: self.span(), "couldn't determine the peer's port"); 480 } 481 } 482 483 let connection = Connection::new(peer_addr, stream, !own_side); 484 485 // Enact the enabled protocols. 486 let mut connection = self.enable_protocols(connection).await?; 487 488 // if Reading is enabled, we'll notify the related task when the connection is fully ready. 489 let conn_ready_tx = connection.readiness_notifier.take(); 490 491 self.connections.add(connection); 492 self.connecting.lock().remove(&peer_addr); 493 494 // Send the aforementioned notification so that reading from the socket can commence. 495 if let Some(tx) = conn_ready_tx { 496 let _ = tx.send(()); 497 } 498 499 // If enabled, enact OnConnect. 500 if let Some(handler) = self.protocols.on_connect.get() { 501 let (sender, receiver) = oneshot::channel(); 502 handler.trigger((peer_addr, sender)); 503 let _ = receiver.await; // can't really fail 504 } 505 506 Ok(()) 507 } 508 509 /// Enacts the enabled protocols on the provided connection. 
510 async fn enable_protocols(&self, conn: Connection) -> io::Result<Connection> { 511 /// A helper macro to enable a protocol on a connection. 512 macro_rules! enable_protocol { 513 ($handler_type: ident, $node:expr, $conn: expr) => { 514 if let Some(handler) = $node.protocols.$handler_type.get() { 515 let (conn_returner, conn_retriever) = oneshot::channel(); 516 517 handler.trigger(($conn, conn_returner)); 518 519 match conn_retriever.await { 520 Ok(Ok(conn)) => conn, 521 Err(_) => return Err(io::ErrorKind::BrokenPipe.into()), 522 Ok(e) => return e, 523 } 524 } else { 525 $conn 526 } 527 }; 528 } 529 530 let mut conn = enable_protocol!(handshake, self, conn); 531 532 // Split the stream after the handshake (if not done before). 533 if let Some(stream) = conn.stream.take() { 534 let (reader, writer) = split(stream); 535 conn.reader = Some(Box::new(reader)); 536 conn.writer = Some(Box::new(writer)); 537 } 538 539 let conn = enable_protocol!(reading, self, conn); 540 let conn = enable_protocol!(writing, self, conn); 541 542 Ok(conn) 543 } 544 } 545 546 impl fmt::Debug for Tcp { 547 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 548 write!(f, "The TCP stack config: {:?}", self.config) 549 } 550 } 551 552 #[cfg(test)] 553 mod tests { 554 use super::*; 555 556 use std::{ 557 net::{IpAddr, Ipv4Addr}, 558 str::FromStr, 559 }; 560 561 #[tokio::test] 562 async fn test_new() { 563 let tcp = Tcp::new(Config { 564 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 565 max_connections: 200, 566 ..Default::default() 567 }); 568 569 assert_eq!(tcp.config.max_connections, 200); 570 assert_eq!(tcp.config.listener_ip, Some(IpAddr::V4(Ipv4Addr::LOCALHOST))); 571 assert_eq!(tcp.enable_listener().await.unwrap().ip(), IpAddr::V4(Ipv4Addr::LOCALHOST)); 572 573 assert_eq!(tcp.num_connected(), 0); 574 assert_eq!(tcp.num_connecting(), 0); 575 } 576 577 #[tokio::test] 578 async fn test_connect() { 579 let tcp = Tcp::new(Config::default()); 580 let node_ip = 
tcp.enable_listener().await.unwrap(); 581 582 // Ensure self-connecting is not possible. 583 let result = tcp.connect(node_ip).await; 584 assert!(matches!(result, Err(ConnectError::SelfConnect { .. }))); 585 586 assert_eq!(tcp.num_connected(), 0); 587 assert_eq!(tcp.num_connecting(), 0); 588 assert!(!tcp.is_connected(node_ip)); 589 assert!(!tcp.is_connecting(node_ip)); 590 591 // Initialize the peer. 592 let peer = Tcp::new(Config { 593 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 594 desired_listening_port: Some(0), 595 max_connections: 1, 596 ..Default::default() 597 }); 598 let peer_ip = peer.enable_listener().await.unwrap(); 599 600 // Connect to the peer. 601 tcp.connect(peer_ip).await.unwrap(); 602 assert_eq!(tcp.num_connected(), 1); 603 assert_eq!(tcp.num_connecting(), 0); 604 assert!(tcp.is_connected(peer_ip)); 605 assert!(!tcp.is_connecting(peer_ip)); 606 } 607 608 #[tokio::test] 609 async fn test_disconnect() { 610 let tcp = Tcp::new(Config::default()); 611 let _node_ip = tcp.enable_listener().await.unwrap(); 612 613 // Initialize the peer. 614 let peer = Tcp::new(Config { 615 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 616 desired_listening_port: Some(0), 617 max_connections: 1, 618 ..Default::default() 619 }); 620 let peer_ip = peer.enable_listener().await.unwrap(); 621 622 // Connect to the peer. 623 tcp.connect(peer_ip).await.unwrap(); 624 assert_eq!(tcp.num_connected(), 1); 625 assert_eq!(tcp.num_connecting(), 0); 626 assert!(tcp.is_connected(peer_ip)); 627 assert!(!tcp.is_connecting(peer_ip)); 628 629 // Disconnect from the peer. 630 let has_disconnected = tcp.disconnect(peer_ip).await; 631 assert!(has_disconnected); 632 assert_eq!(tcp.num_connected(), 0); 633 assert_eq!(tcp.num_connecting(), 0); 634 assert!(!tcp.is_connected(peer_ip)); 635 assert!(!tcp.is_connecting(peer_ip)); 636 637 // Ensure disconnecting from the peer a second time is okay. 
638 let has_disconnected = tcp.disconnect(peer_ip).await; 639 assert!(!has_disconnected); 640 assert_eq!(tcp.num_connected(), 0); 641 assert_eq!(tcp.num_connecting(), 0); 642 assert!(!tcp.is_connected(peer_ip)); 643 assert!(!tcp.is_connecting(peer_ip)); 644 } 645 646 #[tokio::test] 647 async fn test_can_add_connection() { 648 let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() }); 649 650 // Initialize the peer. 651 let peer = Tcp::new(Config { 652 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 653 desired_listening_port: Some(0), 654 max_connections: 1, 655 ..Default::default() 656 }); 657 let peer_ip = peer.enable_listener().await.unwrap(); 658 659 assert!(tcp.can_add_connection()); 660 661 // Simulate an active connection. 662 let stream = TcpStream::connect(peer_ip).await.unwrap(); 663 tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Initiator)); 664 assert!(!tcp.can_add_connection()); 665 666 // Ensure that we cannot invoke connect() successfully in this case. 667 // Use a non-local IP, to ensure it is never qual to peer IP. 668 let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap(); 669 let result = tcp.connect(another_ip).await; 670 assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. }))); 671 672 // Remove the active connection. 673 tcp.connections.remove(peer_ip); 674 assert!(tcp.can_add_connection()); 675 676 // Simulate a pending connection. 677 tcp.connecting.lock().insert(peer_ip); 678 assert!(!tcp.can_add_connection()); 679 680 // Ensure that we cannot invoke connect() successfully in this case either. 681 let another_ip = SocketAddr::from_str("1.2.3.4:4242").unwrap(); 682 let result = tcp.connect(another_ip).await; 683 assert!(matches!(result, Err(ConnectError::MaximumConnectionsReached { .. }))); 684 685 // Remove the pending connection. 
686 tcp.connecting.lock().remove(&peer_ip); 687 assert!(tcp.can_add_connection()); 688 689 // Simulate an active and a pending connection (this case should never occur). 690 let stream = TcpStream::connect(peer_ip).await.unwrap(); 691 tcp.connections.add(Connection::new(peer_ip, stream, ConnectionSide::Responder)); 692 tcp.connecting.lock().insert(peer_ip); 693 assert!(!tcp.can_add_connection()); 694 695 // Remove the active and pending connection. 696 tcp.connections.remove(peer_ip); 697 tcp.connecting.lock().remove(&peer_ip); 698 assert!(tcp.can_add_connection()); 699 } 700 701 #[tokio::test] 702 async fn test_handle_connection() { 703 let tcp = Tcp::new(Config { 704 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 705 max_connections: 1, 706 ..Default::default() 707 }); 708 709 // Initialize peer 1. 710 let peer1 = Tcp::new(Config { 711 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 712 desired_listening_port: Some(0), 713 max_connections: 1, 714 ..Default::default() 715 }); 716 let peer1_ip = peer1.enable_listener().await.unwrap(); 717 718 // Simulate an active connection. 719 let stream = TcpStream::connect(peer1_ip).await.unwrap(); 720 tcp.connections.add(Connection::new(peer1_ip, stream, ConnectionSide::Responder)); 721 assert!(!tcp.can_add_connection()); 722 assert_eq!(tcp.num_connected(), 1); 723 assert_eq!(tcp.num_connecting(), 0); 724 assert!(tcp.is_connected(peer1_ip)); 725 assert!(!tcp.is_connecting(peer1_ip)); 726 727 // Initialize peer 2. 728 let peer2 = Tcp::new(Config { 729 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 730 desired_listening_port: Some(0), 731 max_connections: 1, 732 ..Default::default() 733 }); 734 let peer2_ip = peer2.enable_listener().await.unwrap(); 735 736 // Handle the connection. 
737 let stream = TcpStream::connect(peer2_ip).await.unwrap(); 738 tcp.handle_connection(stream, peer2_ip); 739 assert!(!tcp.can_add_connection()); 740 assert_eq!(tcp.num_connected(), 1); 741 assert_eq!(tcp.num_connecting(), 0); 742 assert!(tcp.is_connected(peer1_ip)); 743 assert!(!tcp.is_connected(peer2_ip)); 744 assert!(!tcp.is_connecting(peer1_ip)); 745 assert!(!tcp.is_connecting(peer2_ip)); 746 } 747 748 #[tokio::test] 749 async fn test_adapt_stream() { 750 let tcp = Tcp::new(Config { max_connections: 1, ..Default::default() }); 751 752 // Initialize the peer. 753 let peer = Tcp::new(Config { 754 listener_ip: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), 755 desired_listening_port: Some(0), 756 max_connections: 1, 757 ..Default::default() 758 }); 759 let peer_ip = peer.enable_listener().await.unwrap(); 760 761 // Simulate a pending connection. 762 tcp.connecting.lock().insert(peer_ip); 763 assert_eq!(tcp.num_connected(), 0); 764 assert_eq!(tcp.num_connecting(), 1); 765 assert!(!tcp.is_connected(peer_ip)); 766 assert!(tcp.is_connecting(peer_ip)); 767 768 // Simulate a new connection. 769 let stream = TcpStream::connect(peer_ip).await.unwrap(); 770 tcp.adapt_stream(stream, peer_ip, ConnectionSide::Responder).await.unwrap(); 771 assert_eq!(tcp.num_connected(), 1); 772 assert_eq!(tcp.num_connecting(), 0); 773 assert!(tcp.is_connected(peer_ip)); 774 assert!(!tcp.is_connecting(peer_ip)); 775 } 776 }