// connect.rs
1 //! Provides an abstract network connection interface and multiple 2 //! implementations 3 4 use std::collections::BTreeMap; 5 use std::fmt::Debug; 6 use std::net::SocketAddr; 7 use std::pin::Pin; 8 use std::sync::Arc; 9 10 use anyhow::format_err; 11 use async_trait::async_trait; 12 use fedimint_core::util::SafeUrl; 13 use fedimint_core::PeerId; 14 use futures::Stream; 15 use tokio::io::{ReadHalf, WriteHalf}; 16 use tokio::net::{TcpListener, TcpStream}; 17 use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient; 18 use tokio_rustls::rustls::RootCertStore; 19 use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream}; 20 21 use crate::net::framed::{AnyFramedTransport, BidiFramed, FramedTransport}; 22 23 /// Shared [`Connector`] trait object 24 pub type SharedAnyConnector<M> = Arc<dyn Connector<M> + Send + Sync + Unpin + 'static>; 25 26 /// Owned [`Connector`] trait object 27 pub type AnyConnector<M> = Box<dyn Connector<M> + Send + Sync + Unpin + 'static>; 28 29 /// Result of a connection opening future 30 pub type ConnectResult<M> = Result<(PeerId, AnyFramedTransport<M>), anyhow::Error>; 31 32 /// Owned trait object type for incoming connection listeners 33 pub type ConnectionListener<M> = 34 Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>; 35 36 /// Allows to connect to peers and to listen for incoming connections 37 /// 38 /// Connections are message based ([`FramedTransport`]) and should be 39 /// authenticated and encrypted for production deployments. 
40 #[async_trait] 41 pub trait Connector<M> { 42 /// Connect to a `destination` 43 async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M>; 44 45 /// Listen for incoming connections on `bind_addr` 46 async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error>; 47 48 /// Transform this concrete `Connector` into an owned trait object version 49 /// of itself 50 fn into_dyn(self) -> AnyConnector<M> 51 where 52 Self: Sized + Send + Sync + Unpin + 'static, 53 { 54 Box::new(self) 55 } 56 } 57 58 /// TCP connector with encryption and authentication 59 #[derive(Debug)] 60 pub struct TlsTcpConnector { 61 our_certificate: rustls::Certificate, 62 our_private_key: rustls::PrivateKey, 63 peer_certs: Arc<PeerCertStore>, 64 /// Copy of the certs from `peer_certs`, but in a format that `tokio_rustls` 65 /// understands 66 cert_store: RootCertStore, 67 peer_names: BTreeMap<PeerId, String>, 68 } 69 70 #[derive(Debug, Clone)] 71 pub struct TlsConfig { 72 pub our_private_key: rustls::PrivateKey, 73 pub peer_certs: BTreeMap<PeerId, rustls::Certificate>, 74 pub peer_names: BTreeMap<PeerId, String>, 75 } 76 77 #[derive(Debug, Clone)] 78 pub struct PeerCertStore { 79 peer_certificates: Vec<(PeerId, rustls::Certificate)>, 80 } 81 82 impl TlsTcpConnector { 83 pub fn new(cfg: TlsConfig, our_id: PeerId) -> TlsTcpConnector { 84 let mut cert_store = RootCertStore::empty(); 85 for (_, cert) in cfg.peer_certs.iter() { 86 cert_store 87 .add(cert) 88 .expect("Could not add peer certificate"); 89 } 90 91 TlsTcpConnector { 92 our_certificate: cfg.peer_certs.get(&our_id).expect("exists").clone(), 93 our_private_key: cfg.our_private_key, 94 peer_certs: Arc::new(PeerCertStore::new(cfg.peer_certs)), 95 cert_store, 96 peer_names: cfg.peer_names, 97 } 98 } 99 } 100 101 impl PeerCertStore { 102 fn new(certs: impl IntoIterator<Item = (PeerId, rustls::Certificate)>) -> PeerCertStore { 103 PeerCertStore { 104 peer_certificates: 
certs.into_iter().collect(), 105 } 106 } 107 108 fn get_peer_by_cert(&self, cert: &rustls::Certificate) -> Option<PeerId> { 109 self.peer_certificates 110 .iter() 111 .find_map(|(peer, peer_cert)| if peer_cert == cert { Some(*peer) } else { None }) 112 } 113 114 fn authenticate_peer( 115 &self, 116 received: Option<&[rustls::Certificate]>, 117 ) -> Result<PeerId, anyhow::Error> { 118 let cert_chain = 119 received.ok_or_else(|| anyhow::anyhow!("Peer did not authenticate itself"))?; 120 121 if cert_chain.len() != 1 { 122 return Err(anyhow::anyhow!( 123 "Received certificate chain of len={}, expected=1", 124 cert_chain.len() 125 )); 126 } 127 128 let received_cert = cert_chain.first().expect("Checked above"); 129 130 self.get_peer_by_cert(received_cert) 131 .ok_or_else(|| anyhow::anyhow!("Unknown certificate")) 132 } 133 134 async fn accept_connection<M>( 135 &self, 136 listener: &mut TcpListener, 137 acceptor: &TlsAcceptor, 138 ) -> Result<(PeerId, AnyFramedTransport<M>), anyhow::Error> 139 where 140 M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static, 141 { 142 let (connection, _) = listener.accept().await?; 143 let tls_conn = acceptor.accept(connection).await?; 144 145 let (_, tls_session) = tls_conn.get_ref(); 146 let auth_peer = self.authenticate_peer(tls_session.peer_certificates())?; 147 148 let framed = 149 BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new( 150 tls_conn, 151 ) 152 .into_dyn(); 153 Ok((auth_peer, framed)) 154 } 155 } 156 157 #[async_trait] 158 impl<M> Connector<M> for TlsTcpConnector 159 where 160 M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static, 161 { 162 async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M> { 163 let cfg = rustls::ClientConfig::builder() 164 .with_safe_defaults() 165 .with_root_certificates(self.cert_store.clone()) 166 .with_client_auth_cert( 167 vec![self.our_certificate.clone()], 168 
self.our_private_key.clone(), 169 ) 170 .expect("Failed to create TLS config"); 171 172 let fake_domain = 173 rustls::ServerName::try_from(dns_sanitize(&self.peer_names[&peer]).as_str()) 174 .expect("Always a valid DNS name"); 175 176 let connector = TlsConnector::from(Arc::new(cfg)); 177 let tls_conn = connector 178 .connect( 179 fake_domain, 180 TcpStream::connect(parse_host_port(destination)?).await?, 181 ) 182 .await?; 183 184 let (_, tls_session) = tls_conn.get_ref(); 185 let auth_peer = self 186 .peer_certs 187 .authenticate_peer(tls_session.peer_certificates())?; 188 189 if auth_peer != peer { 190 return Err(anyhow::anyhow!("Connected to unexpected peer")); 191 } 192 193 let framed = 194 BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new( 195 tls_conn, 196 ) 197 .into_dyn(); 198 199 Ok((peer, framed)) 200 } 201 202 async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error> { 203 let verifier = AllowAnyAuthenticatedClient::new(self.cert_store.clone()); 204 let config = rustls::ServerConfig::builder() 205 .with_safe_defaults() 206 .with_client_cert_verifier(Arc::from(verifier)) 207 .with_single_cert( 208 vec![self.our_certificate.clone()], 209 self.our_private_key.clone(), 210 ) 211 .unwrap(); 212 let listener = TcpListener::bind(bind_addr).await?; 213 let peer_certs = self.peer_certs.clone(); 214 215 let stream = futures::stream::unfold(listener, move |mut listener| { 216 let acceptor = TlsAcceptor::from(Arc::new(config.clone())); 217 let peer_certs = peer_certs.clone(); 218 219 Box::pin(async move { 220 let res = peer_certs.accept_connection(&mut listener, &acceptor).await; 221 Some((res, listener)) 222 }) 223 }); 224 Ok(Box::pin(stream)) 225 } 226 } 227 228 /// Sanitizes name as valid domain name 229 pub fn dns_sanitize(name: &str) -> String { 230 let sanitized = name.replace(|c: char| !c.is_ascii_alphanumeric(), "_"); 231 format!("peer{sanitized}") 232 } 233 234 /// Parses the host and 
port from a url 235 pub fn parse_host_port(url: SafeUrl) -> anyhow::Result<String> { 236 let host = url 237 .host_str() 238 .ok_or_else(|| format_err!("Missing host in {url}"))?; 239 let port = url 240 .port() 241 .ok_or_else(|| format_err!("Missing port in {url}"))?; 242 243 Ok(format!("{host}:{port}")) 244 } 245 246 /// Fake network stack used in tests 247 #[allow(unused_imports)] 248 pub mod mock { 249 use std::collections::HashMap; 250 use std::fmt::Debug; 251 use std::future::Future; 252 use std::net::SocketAddr; 253 use std::pin::Pin; 254 use std::sync::atomic::{AtomicBool, Ordering}; 255 use std::sync::Arc; 256 use std::time::Duration; 257 258 use anyhow::{anyhow, Error}; 259 use fedimint_core::runtime::spawn; 260 use fedimint_core::task::sleep; 261 use fedimint_core::util::SafeUrl; 262 use fedimint_core::{task, PeerId}; 263 use futures::{pin_mut, FutureExt, SinkExt, Stream, StreamExt}; 264 use rand::Rng; 265 use tokio::io::{ 266 AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf, 267 }; 268 use tokio::sync::mpsc::Sender; 269 use tokio::sync::Mutex; 270 use tokio_util::sync::CancellationToken; 271 use tracing::error; 272 273 use crate::net::connect::{parse_host_port, ConnectResult, Connector}; 274 use crate::net::framed::{BidiFramed, FramedTransport}; 275 276 struct UnreliableDuplexStream { 277 inner: DuplexStream, 278 broken: CancellationToken, 279 read_generator: Option<UnreliabilityGenerator>, 280 write_generator: Option<UnreliabilityGenerator>, 281 flush_generator: Option<UnreliabilityGenerator>, 282 shutdown_generator: Option<UnreliabilityGenerator>, 283 } 284 285 impl UnreliableDuplexStream { 286 fn new(inner: DuplexStream, reliability: StreamReliability) -> UnreliableDuplexStream { 287 match reliability { 288 StreamReliability::FullyReliable => Self { 289 inner, 290 broken: CancellationToken::new(), 291 read_generator: None, 292 write_generator: None, 293 flush_generator: None, 294 shutdown_generator: None, 295 }, 
296 StreamReliability::RandomlyUnreliable { 297 read_failure_rate, 298 write_failure_rate, 299 flush_failure_rate, 300 shutdown_failure_rate, 301 read_latency, 302 write_latency, 303 flush_latency, 304 shutdown_latency, 305 } => Self { 306 inner, 307 broken: CancellationToken::new(), 308 read_generator: Some(UnreliabilityGenerator::new( 309 read_latency, 310 read_failure_rate, 311 )), 312 write_generator: Some(UnreliabilityGenerator::new( 313 write_latency, 314 write_failure_rate, 315 )), 316 flush_generator: Some(UnreliabilityGenerator::new( 317 flush_latency, 318 flush_failure_rate, 319 )), 320 shutdown_generator: Some(UnreliabilityGenerator::new( 321 shutdown_latency, 322 shutdown_failure_rate, 323 )), 324 }, 325 } 326 } 327 328 fn poll_broken(&self, cx: &mut std::task::Context<'_>) -> bool { 329 let await_cancellation = self.broken.cancelled(); 330 pin_mut!(await_cancellation); 331 await_cancellation.poll(cx).is_ready() 332 } 333 } 334 335 impl Debug for UnreliableDuplexStream { 336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 337 f.debug_struct("UnreliableDuplexStream").finish() 338 } 339 } 340 341 struct UnreliabilityGenerator { 342 latency: LatencyInterval, 343 failure_rate: FailureRate, 344 sleep_future: Option<Pin<Box<tokio::time::Sleep>>>, 345 successes: u64, 346 } 347 348 impl UnreliabilityGenerator { 349 fn new(latency: LatencyInterval, failure_rate: FailureRate) -> UnreliabilityGenerator { 350 Self { 351 latency, 352 failure_rate, 353 sleep_future: None, 354 successes: 0, 355 } 356 } 357 358 pub fn generate( 359 &mut self, 360 cx: &mut std::task::Context<'_>, 361 ) -> std::task::Poll<std::io::Result<()>> { 362 let sleep = self.sleep_future.get_or_insert_with(|| { 363 Box::pin( 364 // nosemgrep: ban-tokio-sleep 365 tokio::time::sleep(self.latency.random()), 366 ) 367 }); 368 match sleep.poll_unpin(cx) { 369 std::task::Poll::Ready(()) => { 370 self.sleep_future = None; 371 } 372 std::task::Poll::Pending => return 
std::task::Poll::Pending, 373 } 374 if self.failure_rate.random_fail() { 375 tracing::debug!( 376 "Returning random error on unreliable stream after {} successes", 377 self.successes 378 ); 379 std::task::Poll::Ready(Err(std::io::Error::new( 380 std::io::ErrorKind::Other, 381 "Randomly failed", 382 ))) 383 } else { 384 self.successes += 1; 385 std::task::Poll::Ready(Ok(())) 386 } 387 } 388 } 389 390 impl AsyncRead for UnreliableDuplexStream { 391 fn poll_read( 392 mut self: Pin<&mut Self>, 393 cx: &mut std::task::Context<'_>, 394 buf: &mut tokio::io::ReadBuf<'_>, 395 ) -> std::task::Poll<std::io::Result<()>> { 396 if self.poll_broken(cx) { 397 return std::task::Poll::Ready(Err(std::io::Error::new( 398 std::io::ErrorKind::Other, 399 "Stream is broken", 400 ))); 401 } 402 403 match self.read_generator.as_mut().map(|g| g.generate(cx)) { 404 Some(std::task::Poll::Ready(Err(e))) => { 405 self.broken.cancel(); 406 std::task::Poll::Ready(Err(e)) 407 } 408 Some(std::task::Poll::Pending) => std::task::Poll::Pending, 409 Some(std::task::Poll::Ready(Ok(()))) | None => { 410 Pin::new(&mut self.inner).poll_read(cx, buf) 411 } 412 } 413 } 414 } 415 416 impl AsyncWrite for UnreliableDuplexStream { 417 fn poll_write( 418 mut self: Pin<&mut Self>, 419 cx: &mut std::task::Context<'_>, 420 buf: &[u8], 421 ) -> std::task::Poll<Result<usize, std::io::Error>> { 422 if self.poll_broken(cx) { 423 return std::task::Poll::Ready(Err(std::io::Error::new( 424 std::io::ErrorKind::Other, 425 "Stream is broken", 426 ))); 427 } 428 429 match self.write_generator.as_mut().map(|g| g.generate(cx)) { 430 Some(std::task::Poll::Ready(Err(e))) => { 431 self.broken.cancel(); 432 std::task::Poll::Ready(Err(e)) 433 } 434 Some(std::task::Poll::Pending) => std::task::Poll::Pending, 435 Some(std::task::Poll::Ready(Ok(()))) | None => { 436 Pin::new(&mut self.inner).poll_write(cx, buf) 437 } 438 } 439 } 440 441 fn poll_flush( 442 mut self: Pin<&mut Self>, 443 cx: &mut std::task::Context<'_>, 444 ) -> 
std::task::Poll<Result<(), std::io::Error>> { 445 if self.poll_broken(cx) { 446 return std::task::Poll::Ready(Err(std::io::Error::new( 447 std::io::ErrorKind::Other, 448 "Stream is broken", 449 ))); 450 } 451 452 match self.flush_generator.as_mut().map(|g| g.generate(cx)) { 453 Some(std::task::Poll::Ready(Err(e))) => { 454 self.broken.cancel(); 455 std::task::Poll::Ready(Err(e)) 456 } 457 Some(std::task::Poll::Pending) => std::task::Poll::Pending, 458 Some(std::task::Poll::Ready(Ok(()))) | None => { 459 Pin::new(&mut self.inner).poll_flush(cx) 460 } 461 } 462 } 463 464 fn poll_shutdown( 465 mut self: Pin<&mut Self>, 466 cx: &mut std::task::Context<'_>, 467 ) -> std::task::Poll<Result<(), std::io::Error>> { 468 if self.poll_broken(cx) { 469 return std::task::Poll::Ready(Err(std::io::Error::new( 470 std::io::ErrorKind::Other, 471 "Stream is broken", 472 ))); 473 } 474 475 match self.shutdown_generator.as_mut().map(|g| g.generate(cx)) { 476 Some(std::task::Poll::Ready(Err(e))) => { 477 self.broken.cancel(); 478 std::task::Poll::Ready(Err(e)) 479 } 480 Some(std::task::Poll::Pending) => std::task::Poll::Pending, 481 Some(std::task::Poll::Ready(Ok(()))) | None => { 482 Pin::new(&mut self.inner).poll_shutdown(cx) 483 } 484 } 485 } 486 } 487 488 pub struct MockNetwork { 489 clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>, 490 } 491 492 pub struct MockConnector { 493 id: PeerId, 494 clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>, 495 reliability: StreamReliability, 496 } 497 498 impl MockNetwork { 499 #[allow(clippy::new_without_default)] 500 pub fn new() -> MockNetwork { 501 MockNetwork { 502 clients: Arc::new(Default::default()), 503 } 504 } 505 506 pub fn connector(&self, id: PeerId, reliability: StreamReliability) -> MockConnector { 507 MockConnector { 508 id, 509 clients: self.clients.clone(), 510 reliability, 511 } 512 } 513 } 514 515 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 516 pub struct LatencyInterval { 517 
min_millis: u64, 518 max_millis: u64, 519 } 520 521 impl LatencyInterval { 522 const ZERO: LatencyInterval = LatencyInterval { 523 min_millis: 0, 524 max_millis: 0, 525 }; 526 527 pub fn new(min: Duration, max: Duration) -> LatencyInterval { 528 assert!(min <= max); 529 LatencyInterval { 530 min_millis: min 531 .as_millis() 532 .try_into() 533 .expect("min duration as millis to fit in a u64"), 534 max_millis: max 535 .as_millis() 536 .try_into() 537 .expect("max duration as millis to fit in a u64"), 538 } 539 } 540 541 pub fn random(&self) -> Duration { 542 let mut rng = rand::thread_rng(); 543 Duration::from_millis(rng.gen_range(self.min_millis..=self.max_millis)) 544 } 545 } 546 547 #[derive(Debug, Copy, Clone)] 548 pub struct FailureRate(f64); 549 impl FailureRate { 550 const MAX: FailureRate = FailureRate(1.0); 551 pub fn new(failure_rate: f64) -> Self { 552 assert!((0.0..=1.0).contains(&failure_rate)); 553 Self(failure_rate) 554 } 555 556 pub fn random_fail(&self) -> bool { 557 let mut rng = rand::thread_rng(); 558 rng.gen_range(0.0..1.0) < self.0 559 } 560 } 561 562 #[derive(Debug, Copy, Clone)] 563 pub enum StreamReliability { 564 FullyReliable, 565 RandomlyUnreliable { 566 read_failure_rate: FailureRate, 567 write_failure_rate: FailureRate, 568 flush_failure_rate: FailureRate, 569 shutdown_failure_rate: FailureRate, 570 read_latency: LatencyInterval, 571 write_latency: LatencyInterval, 572 flush_latency: LatencyInterval, 573 shutdown_latency: LatencyInterval, 574 }, 575 } 576 577 impl StreamReliability { 578 pub const MILDLY_UNRELIABLE: StreamReliability = { 579 let failure_rate = FailureRate(0.0); 580 let latency = LatencyInterval { 581 min_millis: 1, 582 max_millis: 10, 583 }; 584 Self::RandomlyUnreliable { 585 read_failure_rate: failure_rate, 586 write_failure_rate: failure_rate, 587 flush_failure_rate: failure_rate, 588 shutdown_failure_rate: failure_rate, 589 read_latency: latency, 590 write_latency: latency, 591 flush_latency: latency, 592 
shutdown_latency: latency, 593 } 594 }; 595 596 pub const INTEGRATION_TEST: StreamReliability = { 597 // Based on empirical testing: creates errors without causing tests to take 598 // additional time compared to StreamReliability::FullyReliable 599 // If an order of magnitude higher, tests may take unreasonable amounts of time. 600 // If an order of magnitude lower, a test may run without any error actually 601 // happening 602 let failure_rate_base = 0.0; 603 let latency = LatencyInterval { 604 min_millis: 1, 605 max_millis: 10, 606 }; 607 Self::RandomlyUnreliable { 608 // Try to make read_failure_rate = write_failure_rate + flush_failure_rate 609 read_failure_rate: FailureRate(failure_rate_base * 2.0), 610 write_failure_rate: FailureRate(failure_rate_base), 611 flush_failure_rate: FailureRate(failure_rate_base), 612 shutdown_failure_rate: FailureRate(failure_rate_base), 613 read_latency: latency, 614 write_latency: latency, 615 flush_latency: latency, 616 shutdown_latency: latency, 617 } 618 }; 619 620 pub const BROKEN: StreamReliability = { 621 Self::RandomlyUnreliable { 622 read_failure_rate: FailureRate::MAX, 623 write_failure_rate: FailureRate::MAX, 624 flush_failure_rate: FailureRate::MAX, 625 shutdown_failure_rate: FailureRate::MAX, 626 read_latency: LatencyInterval::ZERO, 627 write_latency: LatencyInterval::ZERO, 628 flush_latency: LatencyInterval::ZERO, 629 shutdown_latency: LatencyInterval::ZERO, 630 } 631 }; 632 } 633 634 #[async_trait::async_trait] 635 impl<M> Connector<M> for MockConnector 636 where 637 M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static, 638 { 639 async fn connect_framed(&self, destination: SafeUrl, _peer: PeerId) -> ConnectResult<M> { 640 let mut clients_lock = self.clients.try_lock().map_err(|e| { 641 anyhow!("Mock network mutex busy or poisoned, the network stack will re-try anyway: {e:?}") 642 })?; 643 if let Some(client) = clients_lock.get_mut(&parse_host_port(destination)?) 
{ 644 let (stream_our, stream_theirs) = tokio::io::duplex(43_689); 645 let mut stream_our = UnreliableDuplexStream::new(stream_our, self.reliability); 646 let stream_theirs = UnreliableDuplexStream::new(stream_theirs, self.reliability); 647 client.send(stream_theirs).await?; 648 let peer = do_handshake(self.id, &mut stream_our).await?; 649 let framed = BidiFramed::< 650 M, 651 WriteHalf<UnreliableDuplexStream>, 652 ReadHalf<UnreliableDuplexStream>, 653 >::new(stream_our) 654 .into_dyn(); 655 Ok((peer, framed)) 656 } else { 657 return Err(anyhow::anyhow!("can't connect")); 658 } 659 } 660 661 async fn listen( 662 &self, 663 bind_addr: SocketAddr, 664 ) -> Result<Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>, Error> 665 { 666 let (send, receive) = tokio::sync::mpsc::channel(16); 667 668 if self 669 .clients 670 .lock() 671 .await 672 .insert(bind_addr.to_string(), send) 673 .is_some() 674 { 675 return Err(anyhow::anyhow!("Address already bound")); 676 } 677 678 let our_id = self.id; 679 let stream = futures::stream::unfold(receive, move |mut receive| { 680 Box::pin(async move { 681 let mut connection = receive.recv().await.unwrap(); 682 let peer = match do_handshake(our_id, &mut connection).await { 683 Ok(peer) => peer, 684 Err(e) => { 685 tracing::debug!("Error during handshake: {e:?}"); 686 return Some((Err(e), receive)); 687 } 688 }; 689 let framed = 690 BidiFramed::<M, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new( 691 connection, 692 ) 693 .into_dyn(); 694 695 Some((Ok((peer, framed)), receive)) 696 }) 697 }); 698 Ok(Box::pin(stream)) 699 } 700 } 701 702 async fn do_handshake<S>(our_id: PeerId, stream: &mut S) -> Result<PeerId, anyhow::Error> 703 where 704 S: AsyncRead + AsyncWrite + Unpin, 705 { 706 // Send our id 707 let our_id = our_id.to_usize() as u16; 708 stream.write_all(&our_id.to_be_bytes()[..]).await?; 709 710 // Receive peer id 711 let mut peer_id = [0u8; 2]; 712 stream.read_exact(&mut peer_id[..]).await?; 713 
Ok(PeerId::from(u16::from_be_bytes(peer_id))) 714 } 715 716 #[tokio::test] 717 async fn test_mock_network() { 718 let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap(); 719 let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap(); 720 let peer_a = PeerId::from(1); 721 let peer_b = PeerId::from(2); 722 723 let net = MockNetwork::new(); 724 let conn_a = net.connector(peer_a, StreamReliability::FullyReliable); 725 let conn_b = net.connector(peer_b, StreamReliability::FullyReliable); 726 727 let mut listener = Connector::<u64>::listen(&conn_a, bind_addr).await.unwrap(); 728 let conn_a_fut = spawn("listener next await", async move { 729 listener.next().await.unwrap().unwrap() 730 }); 731 732 let (auth_peer_b, mut conn_b) = Connector::<u64>::connect_framed(&conn_b, url, peer_a) 733 .await 734 .unwrap(); 735 let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap(); 736 737 assert_eq!(auth_peer_a, peer_b); 738 assert_eq!(auth_peer_b, peer_a); 739 740 conn_a.send(42).await.unwrap(); 741 conn_b.send(21).await.unwrap(); 742 743 assert_eq!(conn_a.next().await.unwrap().unwrap(), 21); 744 assert_eq!(conn_b.next().await.unwrap().unwrap(), 42); 745 } 746 747 #[tokio::test] 748 async fn test_unreliable_components() { 749 assert!(!FailureRate::new(0f64).random_fail()); 750 assert!(FailureRate::new(1f64).random_fail()); 751 752 let good_interval = (0..=3).contains( 753 &LatencyInterval::new(Duration::from_millis(0), Duration::from_millis(3)) 754 .random() 755 .as_millis(), 756 ); 757 assert!(good_interval); 758 759 let (a, b) = tokio::io::duplex(43_689); 760 let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable); 761 let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable); 762 assert!(a_stream.write(&[1, 2, 3]).await.is_ok()); 763 assert!(a_stream.flush().await.is_ok()); 764 assert_eq!(b_stream.read_u8().await.unwrap(), 1); 765 assert_eq!(b_stream.read_u8().await.unwrap(), 2); 766 
assert_eq!(b_stream.read_u8().await.unwrap(), 3); 767 768 let (a, b) = tokio::io::duplex(43_689); 769 let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable); 770 let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::BROKEN); 771 assert!(a_stream.write(&[1, 2, 3]).await.is_ok()); 772 assert!(a_stream.flush().await.is_ok()); 773 assert!(b_stream.read_u8().await.is_err()); 774 775 let (a, b) = tokio::io::duplex(43_689); 776 let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::BROKEN); 777 let mut _b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable); 778 assert!(a_stream.write(&[1, 2, 3]).await.is_err()); 779 // a read on _b_stream would block... 780 } 781 782 #[allow(dead_code)] 783 async fn timeout<F, T>(f: F) -> Option<T> 784 where 785 F: Future<Output = T>, 786 { 787 tokio::time::timeout(Duration::from_secs(1), f).await.ok() 788 } 789 790 #[tokio::test] 791 async fn test_large_messages() { 792 let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap(); 793 let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap(); 794 let peer_a = PeerId::from(1); 795 let peer_b = PeerId::from(2); 796 797 let net = MockNetwork::new(); 798 let conn_a = net.connector(peer_a, StreamReliability::FullyReliable); 799 let conn_b = net.connector(peer_b, StreamReliability::FullyReliable); 800 801 let mut listener = Connector::<Vec<u8>>::listen(&conn_a, bind_addr) 802 .await 803 .unwrap(); 804 let conn_a_fut = spawn("listener next await", async move { 805 listener.next().await.unwrap().unwrap() 806 }); 807 808 let (auth_peer_b, mut conn_b) = Connector::<Vec<u8>>::connect_framed(&conn_b, url, peer_a) 809 .await 810 .unwrap(); 811 let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap(); 812 813 assert_eq!(auth_peer_a, peer_b); 814 assert_eq!(auth_peer_b, peer_a); 815 816 let send_future = async move { 817 conn_a.send(vec![42; 16000]).await.unwrap(); 818 } 819 .boxed(); 820 let receive_future = async move { 
821 assert_eq!( 822 timeout(conn_b.next()).await.unwrap().unwrap().unwrap(), 823 vec![42; 16000] 824 ); 825 } 826 .boxed(); 827 828 tokio::join!(send_future, receive_future); 829 } 830 } 831 832 #[cfg(test)] 833 mod tests { 834 use std::net::SocketAddr; 835 836 use fedimint_core::runtime::spawn; 837 use fedimint_core::util::SafeUrl; 838 use fedimint_core::PeerId; 839 use futures::{SinkExt, StreamExt}; 840 841 use crate::config::gen_cert_and_key; 842 use crate::net::connect::{ConnectionListener, Connector, TlsConfig}; 843 use crate::net::framed::AnyFramedTransport; 844 use crate::TlsTcpConnector; 845 846 fn gen_connector_config(count: usize) -> Vec<TlsConfig> { 847 let peer_keys = (0..count) 848 .map(|id| { 849 let peer = PeerId::from(id as u16); 850 gen_cert_and_key(&format!("peer-{}", peer.to_usize())).unwrap() 851 }) 852 .collect::<Vec<_>>(); 853 854 peer_keys 855 .iter() 856 .map(|(_cert, key)| TlsConfig { 857 our_private_key: key.clone(), 858 peer_certs: peer_keys 859 .iter() 860 .enumerate() 861 .map(|(peer, (cert, _))| (PeerId::from(peer as u16), cert.clone())) 862 .collect(), 863 peer_names: peer_keys 864 .iter() 865 .enumerate() 866 .map(|(peer, (_, _))| (PeerId::from(peer as u16), format!("peer-{peer}"))) 867 .collect(), 868 }) 869 .collect() 870 } 871 872 #[tokio::test] 873 async fn connect_success() { 874 // FIXME: don't actually bind here, probably requires yet another Box<dyn Trait> 875 // layer :( 876 let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap(); 877 let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap(); 878 let connectors = gen_connector_config(5) 879 .into_iter() 880 .enumerate() 881 .map(|(id, cfg)| TlsTcpConnector::new(cfg, PeerId::from(id as u16))) 882 .collect::<Vec<_>>(); 883 884 let mut server: ConnectionListener<u64> = connectors[0].listen(bind_addr).await.unwrap(); 885 886 let server_task = spawn("server next await", async move { 887 let (peer, mut conn) = server.next().await.unwrap().unwrap(); 888 
assert_eq!(peer.to_usize(), 2); 889 let received = conn.next().await.unwrap().unwrap(); 890 assert_eq!(received, 42); 891 conn.send(21).await.unwrap(); 892 assert!(conn.next().await.unwrap().is_err()); 893 }); 894 895 let (peer_of_a, mut client_a): (_, AnyFramedTransport<u64>) = connectors[2] 896 .connect_framed(url.clone(), PeerId::from(0)) 897 .await 898 .unwrap(); 899 assert_eq!(peer_of_a.to_usize(), 0); 900 client_a.send(42).await.unwrap(); 901 let received = client_a.next().await.unwrap().unwrap(); 902 assert_eq!(received, 21); 903 drop(client_a); 904 905 server_task.await.unwrap(); 906 } 907 908 #[tokio::test] 909 async fn connect_reject() { 910 let bind_addr: SocketAddr = "127.0.0.1:7001".parse().unwrap(); 911 let url: SafeUrl = "wss://127.0.0.1:7001".parse().unwrap(); 912 let cfg = gen_connector_config(5); 913 914 let honest = TlsTcpConnector::new(cfg[0].clone(), PeerId::from(0)); 915 916 let mut malicious_wrong_key_cfg = cfg[1].clone(); 917 malicious_wrong_key_cfg.our_private_key = cfg[2].our_private_key.clone(); 918 let malicious_wrong_key = TlsTcpConnector::new(malicious_wrong_key_cfg, PeerId::from(1)); 919 920 // Honest server, malicious client with wrong private key 921 { 922 let mut server: ConnectionListener<u64> = honest.listen(bind_addr).await.unwrap(); 923 924 let server_task = spawn("server next await", async move { 925 let conn_res = server.next().await.unwrap(); 926 assert_eq!( 927 conn_res.err().unwrap().to_string().as_str(), 928 "invalid peer certificate: BadSignature" 929 ); 930 }); 931 932 let err_anytime = async { 933 let (_peer, mut conn): (_, AnyFramedTransport<u64>) = malicious_wrong_key 934 .connect_framed(url.clone(), PeerId::from(0)) 935 .await?; 936 937 conn.send(42).await?; 938 conn.flush().await?; 939 conn.next().await.unwrap()?; 940 941 Result::<_, anyhow::Error>::Ok(()) 942 }; 943 944 let conn_res = err_anytime.await; 945 assert_eq!( 946 conn_res.err().unwrap().to_string().as_str(), 947 "received fatal alert: DecryptError" 948 
); 949 950 server_task.await.unwrap(); 951 } 952 953 // Malicious server with wrong key, honest client 954 { 955 let mut server: ConnectionListener<u64> = 956 malicious_wrong_key.listen(bind_addr).await.unwrap(); 957 958 let server_task = spawn("server next await", async move { 959 let conn_res = server.next().await.unwrap(); 960 assert_eq!( 961 conn_res.err().unwrap().to_string().as_str(), 962 "received fatal alert: DecryptError" 963 ); 964 }); 965 966 let err_anytime = async { 967 let (_peer, mut conn): (_, AnyFramedTransport<u64>) = 968 honest.connect_framed(url.clone(), PeerId::from(1)).await?; 969 970 conn.send(42).await?; 971 conn.flush().await?; 972 conn.next().await.unwrap()?; 973 974 Result::<_, anyhow::Error>::Ok(()) 975 }; 976 977 let conn_res = err_anytime.await; 978 assert_eq!( 979 conn_res.err().unwrap().to_string().as_str(), 980 "invalid peer certificate: BadSignature" 981 ); 982 983 server_task.await.unwrap(); 984 } 985 986 // Server with wrong certificate, honest client 987 { 988 let mut server: ConnectionListener<u64> = 989 TlsTcpConnector::new(cfg[2].clone(), PeerId::from(2)) 990 .listen(bind_addr) 991 .await 992 .unwrap(); 993 994 let server_task = spawn("server next await", async move { 995 let conn_res = server.next().await.unwrap(); 996 assert_eq!( 997 conn_res.err().unwrap().to_string().as_str(), 998 "received fatal alert: BadCertificate" 999 ); 1000 }); 1001 1002 let err_anytime = async { 1003 let (_peer, mut conn): (_, AnyFramedTransport<u64>) = 1004 honest.connect_framed(url.clone(), PeerId::from(0)).await?; 1005 1006 conn.send(42).await?; 1007 conn.flush().await?; 1008 conn.next().await.unwrap()?; 1009 1010 Result::<_, anyhow::Error>::Ok(()) 1011 }; 1012 1013 let conn_res = err_anytime.await; 1014 assert_eq!( 1015 conn_res.err().unwrap().to_string().as_str(), 1016 "invalid peer certificate: NotValidForName" 1017 ); 1018 1019 server_task.await.unwrap(); 1020 } 1021 } 1022 }