/ fedimint-server / src / net / connect.rs
connect.rs
   1  //! Provides an abstract network connection interface and multiple
   2  //! implementations
   3  
   4  use std::collections::BTreeMap;
   5  use std::fmt::Debug;
   6  use std::net::SocketAddr;
   7  use std::pin::Pin;
   8  use std::sync::Arc;
   9  
  10  use anyhow::format_err;
  11  use async_trait::async_trait;
  12  use fedimint_core::util::SafeUrl;
  13  use fedimint_core::PeerId;
  14  use futures::Stream;
  15  use tokio::io::{ReadHalf, WriteHalf};
  16  use tokio::net::{TcpListener, TcpStream};
  17  use tokio_rustls::rustls::server::AllowAnyAuthenticatedClient;
  18  use tokio_rustls::rustls::RootCertStore;
  19  use tokio_rustls::{rustls, TlsAcceptor, TlsConnector, TlsStream};
  20  
  21  use crate::net::framed::{AnyFramedTransport, BidiFramed, FramedTransport};
  22  
  23  /// Shared [`Connector`] trait object
  24  pub type SharedAnyConnector<M> = Arc<dyn Connector<M> + Send + Sync + Unpin + 'static>;
  25  
  26  /// Owned [`Connector`] trait object
  27  pub type AnyConnector<M> = Box<dyn Connector<M> + Send + Sync + Unpin + 'static>;
  28  
  29  /// Result of a connection opening future
  30  pub type ConnectResult<M> = Result<(PeerId, AnyFramedTransport<M>), anyhow::Error>;
  31  
  32  /// Owned trait object type for incoming connection listeners
  33  pub type ConnectionListener<M> =
  34      Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>;
  35  
  36  /// Allows to connect to peers and to listen for incoming connections
  37  ///
  38  /// Connections are message based ([`FramedTransport`]) and should be
  39  /// authenticated and encrypted for production deployments.
  40  #[async_trait]
  41  pub trait Connector<M> {
  42      /// Connect to a `destination`
  43      async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M>;
  44  
  45      /// Listen for incoming connections on `bind_addr`
  46      async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error>;
  47  
  48      /// Transform this concrete `Connector` into an owned trait object version
  49      /// of itself
  50      fn into_dyn(self) -> AnyConnector<M>
  51      where
  52          Self: Sized + Send + Sync + Unpin + 'static,
  53      {
  54          Box::new(self)
  55      }
  56  }
  57  
  58  /// TCP connector with encryption and authentication
  59  #[derive(Debug)]
  60  pub struct TlsTcpConnector {
  61      our_certificate: rustls::Certificate,
  62      our_private_key: rustls::PrivateKey,
  63      peer_certs: Arc<PeerCertStore>,
  64      /// Copy of the certs from `peer_certs`, but in a format that `tokio_rustls`
  65      /// understands
  66      cert_store: RootCertStore,
  67      peer_names: BTreeMap<PeerId, String>,
  68  }
  69  
  70  #[derive(Debug, Clone)]
  71  pub struct TlsConfig {
  72      pub our_private_key: rustls::PrivateKey,
  73      pub peer_certs: BTreeMap<PeerId, rustls::Certificate>,
  74      pub peer_names: BTreeMap<PeerId, String>,
  75  }
  76  
/// Table of known peer certificates, used to find out which peer presented a given TLS
/// certificate.
#[derive(Debug, Clone)]
pub struct PeerCertStore {
    /// `(peer, certificate)` pairs; lookups compare certificates for exact equality.
    peer_certificates: Vec<(PeerId, rustls::Certificate)>,
}
  81  
  82  impl TlsTcpConnector {
  83      pub fn new(cfg: TlsConfig, our_id: PeerId) -> TlsTcpConnector {
  84          let mut cert_store = RootCertStore::empty();
  85          for (_, cert) in cfg.peer_certs.iter() {
  86              cert_store
  87                  .add(cert)
  88                  .expect("Could not add peer certificate");
  89          }
  90  
  91          TlsTcpConnector {
  92              our_certificate: cfg.peer_certs.get(&our_id).expect("exists").clone(),
  93              our_private_key: cfg.our_private_key,
  94              peer_certs: Arc::new(PeerCertStore::new(cfg.peer_certs)),
  95              cert_store,
  96              peer_names: cfg.peer_names,
  97          }
  98      }
  99  }
 100  
 101  impl PeerCertStore {
 102      fn new(certs: impl IntoIterator<Item = (PeerId, rustls::Certificate)>) -> PeerCertStore {
 103          PeerCertStore {
 104              peer_certificates: certs.into_iter().collect(),
 105          }
 106      }
 107  
 108      fn get_peer_by_cert(&self, cert: &rustls::Certificate) -> Option<PeerId> {
 109          self.peer_certificates
 110              .iter()
 111              .find_map(|(peer, peer_cert)| if peer_cert == cert { Some(*peer) } else { None })
 112      }
 113  
 114      fn authenticate_peer(
 115          &self,
 116          received: Option<&[rustls::Certificate]>,
 117      ) -> Result<PeerId, anyhow::Error> {
 118          let cert_chain =
 119              received.ok_or_else(|| anyhow::anyhow!("Peer did not authenticate itself"))?;
 120  
 121          if cert_chain.len() != 1 {
 122              return Err(anyhow::anyhow!(
 123                  "Received certificate chain of len={}, expected=1",
 124                  cert_chain.len()
 125              ));
 126          }
 127  
 128          let received_cert = cert_chain.first().expect("Checked above");
 129  
 130          self.get_peer_by_cert(received_cert)
 131              .ok_or_else(|| anyhow::anyhow!("Unknown certificate"))
 132      }
 133  
 134      async fn accept_connection<M>(
 135          &self,
 136          listener: &mut TcpListener,
 137          acceptor: &TlsAcceptor,
 138      ) -> Result<(PeerId, AnyFramedTransport<M>), anyhow::Error>
 139      where
 140          M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
 141      {
 142          let (connection, _) = listener.accept().await?;
 143          let tls_conn = acceptor.accept(connection).await?;
 144  
 145          let (_, tls_session) = tls_conn.get_ref();
 146          let auth_peer = self.authenticate_peer(tls_session.peer_certificates())?;
 147  
 148          let framed =
 149              BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new(
 150                  tls_conn,
 151              )
 152              .into_dyn();
 153          Ok((auth_peer, framed))
 154      }
 155  }
 156  
 157  #[async_trait]
 158  impl<M> Connector<M> for TlsTcpConnector
 159  where
 160      M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
 161  {
 162      async fn connect_framed(&self, destination: SafeUrl, peer: PeerId) -> ConnectResult<M> {
 163          let cfg = rustls::ClientConfig::builder()
 164              .with_safe_defaults()
 165              .with_root_certificates(self.cert_store.clone())
 166              .with_client_auth_cert(
 167                  vec![self.our_certificate.clone()],
 168                  self.our_private_key.clone(),
 169              )
 170              .expect("Failed to create TLS config");
 171  
 172          let fake_domain =
 173              rustls::ServerName::try_from(dns_sanitize(&self.peer_names[&peer]).as_str())
 174                  .expect("Always a valid DNS name");
 175  
 176          let connector = TlsConnector::from(Arc::new(cfg));
 177          let tls_conn = connector
 178              .connect(
 179                  fake_domain,
 180                  TcpStream::connect(parse_host_port(destination)?).await?,
 181              )
 182              .await?;
 183  
 184          let (_, tls_session) = tls_conn.get_ref();
 185          let auth_peer = self
 186              .peer_certs
 187              .authenticate_peer(tls_session.peer_certificates())?;
 188  
 189          if auth_peer != peer {
 190              return Err(anyhow::anyhow!("Connected to unexpected peer"));
 191          }
 192  
 193          let framed =
 194              BidiFramed::<_, WriteHalf<TlsStream<TcpStream>>, ReadHalf<TlsStream<TcpStream>>>::new(
 195                  tls_conn,
 196              )
 197              .into_dyn();
 198  
 199          Ok((peer, framed))
 200      }
 201  
 202      async fn listen(&self, bind_addr: SocketAddr) -> Result<ConnectionListener<M>, anyhow::Error> {
 203          let verifier = AllowAnyAuthenticatedClient::new(self.cert_store.clone());
 204          let config = rustls::ServerConfig::builder()
 205              .with_safe_defaults()
 206              .with_client_cert_verifier(Arc::from(verifier))
 207              .with_single_cert(
 208                  vec![self.our_certificate.clone()],
 209                  self.our_private_key.clone(),
 210              )
 211              .unwrap();
 212          let listener = TcpListener::bind(bind_addr).await?;
 213          let peer_certs = self.peer_certs.clone();
 214  
 215          let stream = futures::stream::unfold(listener, move |mut listener| {
 216              let acceptor = TlsAcceptor::from(Arc::new(config.clone()));
 217              let peer_certs = peer_certs.clone();
 218  
 219              Box::pin(async move {
 220                  let res = peer_certs.accept_connection(&mut listener, &acceptor).await;
 221                  Some((res, listener))
 222              })
 223          });
 224          Ok(Box::pin(stream))
 225      }
 226  }
 227  
 228  /// Sanitizes name as valid domain name
 229  pub fn dns_sanitize(name: &str) -> String {
 230      let sanitized = name.replace(|c: char| !c.is_ascii_alphanumeric(), "_");
 231      format!("peer{sanitized}")
 232  }
 233  
 234  /// Parses the host and port from a url
 235  pub fn parse_host_port(url: SafeUrl) -> anyhow::Result<String> {
 236      let host = url
 237          .host_str()
 238          .ok_or_else(|| format_err!("Missing host in {url}"))?;
 239      let port = url
 240          .port()
 241          .ok_or_else(|| format_err!("Missing port in {url}"))?;
 242  
 243      Ok(format!("{host}:{port}"))
 244  }
 245  
 246  /// Fake network stack used in tests
 247  #[allow(unused_imports)]
 248  pub mod mock {
 249      use std::collections::HashMap;
 250      use std::fmt::Debug;
 251      use std::future::Future;
 252      use std::net::SocketAddr;
 253      use std::pin::Pin;
 254      use std::sync::atomic::{AtomicBool, Ordering};
 255      use std::sync::Arc;
 256      use std::time::Duration;
 257  
 258      use anyhow::{anyhow, Error};
 259      use fedimint_core::runtime::spawn;
 260      use fedimint_core::task::sleep;
 261      use fedimint_core::util::SafeUrl;
 262      use fedimint_core::{task, PeerId};
 263      use futures::{pin_mut, FutureExt, SinkExt, Stream, StreamExt};
 264      use rand::Rng;
 265      use tokio::io::{
 266          AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf,
 267      };
 268      use tokio::sync::mpsc::Sender;
 269      use tokio::sync::Mutex;
 270      use tokio_util::sync::CancellationToken;
 271      use tracing::error;
 272  
 273      use crate::net::connect::{parse_host_port, ConnectResult, Connector};
 274      use crate::net::framed::{BidiFramed, FramedTransport};
 275  
    /// In-memory [`DuplexStream`] wrapper used by the mock network to inject latency and
    /// random failures into reads, writes, flushes and shutdowns.
    struct UnreliableDuplexStream {
        /// Underlying in-memory pipe that carries the actual bytes.
        inner: DuplexStream,
        /// Cancelled when an injected failure occurs. From then on, every operation fails
        /// with "Stream is broken".
        broken: CancellationToken,
        /// Fault injector for reads; `None` passes reads straight through to `inner`.
        read_generator: Option<UnreliabilityGenerator>,
        /// Fault injector for writes; `None` passes writes straight through.
        write_generator: Option<UnreliabilityGenerator>,
        /// Fault injector for flushes; `None` passes flushes straight through.
        flush_generator: Option<UnreliabilityGenerator>,
        /// Fault injector for shutdowns; `None` passes shutdowns straight through.
        shutdown_generator: Option<UnreliabilityGenerator>,
    }
 284  
 285      impl UnreliableDuplexStream {
 286          fn new(inner: DuplexStream, reliability: StreamReliability) -> UnreliableDuplexStream {
 287              match reliability {
 288                  StreamReliability::FullyReliable => Self {
 289                      inner,
 290                      broken: CancellationToken::new(),
 291                      read_generator: None,
 292                      write_generator: None,
 293                      flush_generator: None,
 294                      shutdown_generator: None,
 295                  },
 296                  StreamReliability::RandomlyUnreliable {
 297                      read_failure_rate,
 298                      write_failure_rate,
 299                      flush_failure_rate,
 300                      shutdown_failure_rate,
 301                      read_latency,
 302                      write_latency,
 303                      flush_latency,
 304                      shutdown_latency,
 305                  } => Self {
 306                      inner,
 307                      broken: CancellationToken::new(),
 308                      read_generator: Some(UnreliabilityGenerator::new(
 309                          read_latency,
 310                          read_failure_rate,
 311                      )),
 312                      write_generator: Some(UnreliabilityGenerator::new(
 313                          write_latency,
 314                          write_failure_rate,
 315                      )),
 316                      flush_generator: Some(UnreliabilityGenerator::new(
 317                          flush_latency,
 318                          flush_failure_rate,
 319                      )),
 320                      shutdown_generator: Some(UnreliabilityGenerator::new(
 321                          shutdown_latency,
 322                          shutdown_failure_rate,
 323                      )),
 324                  },
 325              }
 326          }
 327  
 328          fn poll_broken(&self, cx: &mut std::task::Context<'_>) -> bool {
 329              let await_cancellation = self.broken.cancelled();
 330              pin_mut!(await_cancellation);
 331              await_cancellation.poll(cx).is_ready()
 332          }
 333      }
 334  
 335      impl Debug for UnreliableDuplexStream {
 336          fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 337              f.debug_struct("UnreliableDuplexStream").finish()
 338          }
 339      }
 340  
    /// Per-operation fault injector: delays each operation by a random latency, then fails
    /// it with a configured probability.
    struct UnreliabilityGenerator {
        /// Range from which each operation's delay is drawn.
        latency: LatencyInterval,
        /// Probability that an operation fails once its delay has elapsed.
        failure_rate: FailureRate,
        /// Delay of the operation currently in progress, if one has been started.
        sleep_future: Option<Pin<Box<tokio::time::Sleep>>>,
        /// Number of operations that have passed so far; only used in debug logging.
        successes: u64,
    }
 347  
 348      impl UnreliabilityGenerator {
 349          fn new(latency: LatencyInterval, failure_rate: FailureRate) -> UnreliabilityGenerator {
 350              Self {
 351                  latency,
 352                  failure_rate,
 353                  sleep_future: None,
 354                  successes: 0,
 355              }
 356          }
 357  
 358          pub fn generate(
 359              &mut self,
 360              cx: &mut std::task::Context<'_>,
 361          ) -> std::task::Poll<std::io::Result<()>> {
 362              let sleep = self.sleep_future.get_or_insert_with(|| {
 363                  Box::pin(
 364                      // nosemgrep: ban-tokio-sleep
 365                      tokio::time::sleep(self.latency.random()),
 366                  )
 367              });
 368              match sleep.poll_unpin(cx) {
 369                  std::task::Poll::Ready(()) => {
 370                      self.sleep_future = None;
 371                  }
 372                  std::task::Poll::Pending => return std::task::Poll::Pending,
 373              }
 374              if self.failure_rate.random_fail() {
 375                  tracing::debug!(
 376                      "Returning random error on unreliable stream after {} successes",
 377                      self.successes
 378                  );
 379                  std::task::Poll::Ready(Err(std::io::Error::new(
 380                      std::io::ErrorKind::Other,
 381                      "Randomly failed",
 382                  )))
 383              } else {
 384                  self.successes += 1;
 385                  std::task::Poll::Ready(Ok(()))
 386              }
 387          }
 388      }
 389  
 390      impl AsyncRead for UnreliableDuplexStream {
 391          fn poll_read(
 392              mut self: Pin<&mut Self>,
 393              cx: &mut std::task::Context<'_>,
 394              buf: &mut tokio::io::ReadBuf<'_>,
 395          ) -> std::task::Poll<std::io::Result<()>> {
 396              if self.poll_broken(cx) {
 397                  return std::task::Poll::Ready(Err(std::io::Error::new(
 398                      std::io::ErrorKind::Other,
 399                      "Stream is broken",
 400                  )));
 401              }
 402  
 403              match self.read_generator.as_mut().map(|g| g.generate(cx)) {
 404                  Some(std::task::Poll::Ready(Err(e))) => {
 405                      self.broken.cancel();
 406                      std::task::Poll::Ready(Err(e))
 407                  }
 408                  Some(std::task::Poll::Pending) => std::task::Poll::Pending,
 409                  Some(std::task::Poll::Ready(Ok(()))) | None => {
 410                      Pin::new(&mut self.inner).poll_read(cx, buf)
 411                  }
 412              }
 413          }
 414      }
 415  
 416      impl AsyncWrite for UnreliableDuplexStream {
 417          fn poll_write(
 418              mut self: Pin<&mut Self>,
 419              cx: &mut std::task::Context<'_>,
 420              buf: &[u8],
 421          ) -> std::task::Poll<Result<usize, std::io::Error>> {
 422              if self.poll_broken(cx) {
 423                  return std::task::Poll::Ready(Err(std::io::Error::new(
 424                      std::io::ErrorKind::Other,
 425                      "Stream is broken",
 426                  )));
 427              }
 428  
 429              match self.write_generator.as_mut().map(|g| g.generate(cx)) {
 430                  Some(std::task::Poll::Ready(Err(e))) => {
 431                      self.broken.cancel();
 432                      std::task::Poll::Ready(Err(e))
 433                  }
 434                  Some(std::task::Poll::Pending) => std::task::Poll::Pending,
 435                  Some(std::task::Poll::Ready(Ok(()))) | None => {
 436                      Pin::new(&mut self.inner).poll_write(cx, buf)
 437                  }
 438              }
 439          }
 440  
 441          fn poll_flush(
 442              mut self: Pin<&mut Self>,
 443              cx: &mut std::task::Context<'_>,
 444          ) -> std::task::Poll<Result<(), std::io::Error>> {
 445              if self.poll_broken(cx) {
 446                  return std::task::Poll::Ready(Err(std::io::Error::new(
 447                      std::io::ErrorKind::Other,
 448                      "Stream is broken",
 449                  )));
 450              }
 451  
 452              match self.flush_generator.as_mut().map(|g| g.generate(cx)) {
 453                  Some(std::task::Poll::Ready(Err(e))) => {
 454                      self.broken.cancel();
 455                      std::task::Poll::Ready(Err(e))
 456                  }
 457                  Some(std::task::Poll::Pending) => std::task::Poll::Pending,
 458                  Some(std::task::Poll::Ready(Ok(()))) | None => {
 459                      Pin::new(&mut self.inner).poll_flush(cx)
 460                  }
 461              }
 462          }
 463  
 464          fn poll_shutdown(
 465              mut self: Pin<&mut Self>,
 466              cx: &mut std::task::Context<'_>,
 467          ) -> std::task::Poll<Result<(), std::io::Error>> {
 468              if self.poll_broken(cx) {
 469                  return std::task::Poll::Ready(Err(std::io::Error::new(
 470                      std::io::ErrorKind::Other,
 471                      "Stream is broken",
 472                  )));
 473              }
 474  
 475              match self.shutdown_generator.as_mut().map(|g| g.generate(cx)) {
 476                  Some(std::task::Poll::Ready(Err(e))) => {
 477                      self.broken.cancel();
 478                      std::task::Poll::Ready(Err(e))
 479                  }
 480                  Some(std::task::Poll::Pending) => std::task::Poll::Pending,
 481                  Some(std::task::Poll::Ready(Ok(()))) | None => {
 482                      Pin::new(&mut self.inner).poll_shutdown(cx)
 483                  }
 484              }
 485          }
 486      }
 487  
    /// In-memory network shared by [`MockConnector`]s.
    ///
    /// Maps each bound address (`"host:port"`) to the channel on which the listener at that
    /// address receives new connections.
    pub struct MockNetwork {
        clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>,
    }
 491  
    /// [`Connector`] for a single peer on a [`MockNetwork`].
    pub struct MockConnector {
        /// Id this connector announces in the mock handshake.
        id: PeerId,
        /// Listener registry shared with the owning [`MockNetwork`].
        clients: Arc<Mutex<HashMap<String, Sender<UnreliableDuplexStream>>>>,
        /// Fault-injection profile applied to both ends of every connection this connector
        /// dials.
        reliability: StreamReliability,
    }
 497  
 498      impl MockNetwork {
 499          #[allow(clippy::new_without_default)]
 500          pub fn new() -> MockNetwork {
 501              MockNetwork {
 502                  clients: Arc::new(Default::default()),
 503              }
 504          }
 505  
 506          pub fn connector(&self, id: PeerId, reliability: StreamReliability) -> MockConnector {
 507              MockConnector {
 508                  id,
 509                  clients: self.clients.clone(),
 510                  reliability,
 511              }
 512          }
 513      }
 514  
    /// Inclusive range of simulated latencies, stored in whole milliseconds.
    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
    pub struct LatencyInterval {
        min_millis: u64,
        max_millis: u64,
    }
 520  
 521      impl LatencyInterval {
 522          const ZERO: LatencyInterval = LatencyInterval {
 523              min_millis: 0,
 524              max_millis: 0,
 525          };
 526  
 527          pub fn new(min: Duration, max: Duration) -> LatencyInterval {
 528              assert!(min <= max);
 529              LatencyInterval {
 530                  min_millis: min
 531                      .as_millis()
 532                      .try_into()
 533                      .expect("min duration as millis to fit in a u64"),
 534                  max_millis: max
 535                      .as_millis()
 536                      .try_into()
 537                      .expect("max duration as millis to fit in a u64"),
 538              }
 539          }
 540  
 541          pub fn random(&self) -> Duration {
 542              let mut rng = rand::thread_rng();
 543              Duration::from_millis(rng.gen_range(self.min_millis..=self.max_millis))
 544          }
 545      }
 546  
 547      #[derive(Debug, Copy, Clone)]
 548      pub struct FailureRate(f64);
 549      impl FailureRate {
 550          const MAX: FailureRate = FailureRate(1.0);
 551          pub fn new(failure_rate: f64) -> Self {
 552              assert!((0.0..=1.0).contains(&failure_rate));
 553              Self(failure_rate)
 554          }
 555  
 556          pub fn random_fail(&self) -> bool {
 557              let mut rng = rand::thread_rng();
 558              rng.gen_range(0.0..1.0) < self.0
 559          }
 560      }
 561  
    /// Fault-injection profile for mock streams.
    #[derive(Debug, Copy, Clone)]
    pub enum StreamReliability {
        /// Pass every operation straight through, with no added latency and no failures.
        FullyReliable,
        /// Delay each read/write/flush/shutdown by a random latency from the matching
        /// interval, then fail it with the matching probability.
        RandomlyUnreliable {
            read_failure_rate: FailureRate,
            write_failure_rate: FailureRate,
            flush_failure_rate: FailureRate,
            shutdown_failure_rate: FailureRate,
            read_latency: LatencyInterval,
            write_latency: LatencyInterval,
            flush_latency: LatencyInterval,
            shutdown_latency: LatencyInterval,
        },
    }
 576  
    impl StreamReliability {
        /// Adds 1–10 ms of latency to every operation but never injects failures.
        pub const MILDLY_UNRELIABLE: StreamReliability = {
            let failure_rate = FailureRate(0.0);
            let latency = LatencyInterval {
                min_millis: 1,
                max_millis: 10,
            };
            Self::RandomlyUnreliable {
                read_failure_rate: failure_rate,
                write_failure_rate: failure_rate,
                flush_failure_rate: failure_rate,
                shutdown_failure_rate: failure_rate,
                read_latency: latency,
                write_latency: latency,
                flush_latency: latency,
                shutdown_latency: latency,
            }
        };

        /// Preset for integration tests: 1–10 ms latency per operation, plus occasional
        /// injected failures whenever `failure_rate_base` is non-zero.
        pub const INTEGRATION_TEST: StreamReliability = {
            // Based on empirical testing: creates errors without causing tests to take
            // additional time compared to StreamReliability::FullyReliable
            // If an order of magnitude higher, tests may take unreasonable amounts of time.
            // If an order of magnitude lower, a test may run without any error actually
            // happening
            // NOTE(review): `failure_rate_base` is currently 0.0, so this preset injects
            // latency only and never fails; the guidance above describes a non-zero base.
            // Confirm whether failure injection was disabled on purpose.
            let failure_rate_base = 0.0;
            let latency = LatencyInterval {
                min_millis: 1,
                max_millis: 10,
            };
            Self::RandomlyUnreliable {
                // Try to make read_failure_rate = write_failure_rate + flush_failure_rate
                read_failure_rate: FailureRate(failure_rate_base * 2.0),
                write_failure_rate: FailureRate(failure_rate_base),
                flush_failure_rate: FailureRate(failure_rate_base),
                shutdown_failure_rate: FailureRate(failure_rate_base),
                read_latency: latency,
                write_latency: latency,
                flush_latency: latency,
                shutdown_latency: latency,
            }
        };

        /// Every operation fails (failure rate 1.0), with no added latency.
        pub const BROKEN: StreamReliability = {
            Self::RandomlyUnreliable {
                read_failure_rate: FailureRate::MAX,
                write_failure_rate: FailureRate::MAX,
                flush_failure_rate: FailureRate::MAX,
                shutdown_failure_rate: FailureRate::MAX,
                read_latency: LatencyInterval::ZERO,
                write_latency: LatencyInterval::ZERO,
                flush_latency: LatencyInterval::ZERO,
                shutdown_latency: LatencyInterval::ZERO,
            }
        };
    }
 633  
 634      #[async_trait::async_trait]
 635      impl<M> Connector<M> for MockConnector
 636      where
 637          M: Debug + serde::Serialize + serde::de::DeserializeOwned + Send + Unpin + 'static,
 638      {
 639          async fn connect_framed(&self, destination: SafeUrl, _peer: PeerId) -> ConnectResult<M> {
 640              let mut clients_lock = self.clients.try_lock().map_err(|e| {
 641                  anyhow!("Mock network mutex busy or poisoned, the network stack will re-try anyway: {e:?}")
 642              })?;
 643              if let Some(client) = clients_lock.get_mut(&parse_host_port(destination)?) {
 644                  let (stream_our, stream_theirs) = tokio::io::duplex(43_689);
 645                  let mut stream_our = UnreliableDuplexStream::new(stream_our, self.reliability);
 646                  let stream_theirs = UnreliableDuplexStream::new(stream_theirs, self.reliability);
 647                  client.send(stream_theirs).await?;
 648                  let peer = do_handshake(self.id, &mut stream_our).await?;
 649                  let framed = BidiFramed::<
 650                      M,
 651                      WriteHalf<UnreliableDuplexStream>,
 652                      ReadHalf<UnreliableDuplexStream>,
 653                  >::new(stream_our)
 654                  .into_dyn();
 655                  Ok((peer, framed))
 656              } else {
 657                  return Err(anyhow::anyhow!("can't connect"));
 658              }
 659          }
 660  
 661          async fn listen(
 662              &self,
 663              bind_addr: SocketAddr,
 664          ) -> Result<Pin<Box<dyn Stream<Item = ConnectResult<M>> + Send + Unpin + 'static>>, Error>
 665          {
 666              let (send, receive) = tokio::sync::mpsc::channel(16);
 667  
 668              if self
 669                  .clients
 670                  .lock()
 671                  .await
 672                  .insert(bind_addr.to_string(), send)
 673                  .is_some()
 674              {
 675                  return Err(anyhow::anyhow!("Address already bound"));
 676              }
 677  
 678              let our_id = self.id;
 679              let stream = futures::stream::unfold(receive, move |mut receive| {
 680                  Box::pin(async move {
 681                      let mut connection = receive.recv().await.unwrap();
 682                      let peer = match do_handshake(our_id, &mut connection).await {
 683                          Ok(peer) => peer,
 684                          Err(e) => {
 685                              tracing::debug!("Error during handshake: {e:?}");
 686                              return Some((Err(e), receive));
 687                          }
 688                      };
 689                      let framed =
 690                          BidiFramed::<M, WriteHalf<DuplexStream>, ReadHalf<DuplexStream>>::new(
 691                              connection,
 692                          )
 693                          .into_dyn();
 694  
 695                      Some((Ok((peer, framed)), receive))
 696                  })
 697              });
 698              Ok(Box::pin(stream))
 699          }
 700      }
 701  
 702      async fn do_handshake<S>(our_id: PeerId, stream: &mut S) -> Result<PeerId, anyhow::Error>
 703      where
 704          S: AsyncRead + AsyncWrite + Unpin,
 705      {
 706          // Send our id
 707          let our_id = our_id.to_usize() as u16;
 708          stream.write_all(&our_id.to_be_bytes()[..]).await?;
 709  
 710          // Receive peer id
 711          let mut peer_id = [0u8; 2];
 712          stream.read_exact(&mut peer_id[..]).await?;
 713          Ok(PeerId::from(u16::from_be_bytes(peer_id)))
 714      }
 715  
 716      #[tokio::test]
 717      async fn test_mock_network() {
 718          let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
 719          let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
 720          let peer_a = PeerId::from(1);
 721          let peer_b = PeerId::from(2);
 722  
 723          let net = MockNetwork::new();
 724          let conn_a = net.connector(peer_a, StreamReliability::FullyReliable);
 725          let conn_b = net.connector(peer_b, StreamReliability::FullyReliable);
 726  
 727          let mut listener = Connector::<u64>::listen(&conn_a, bind_addr).await.unwrap();
 728          let conn_a_fut = spawn("listener next await", async move {
 729              listener.next().await.unwrap().unwrap()
 730          });
 731  
 732          let (auth_peer_b, mut conn_b) = Connector::<u64>::connect_framed(&conn_b, url, peer_a)
 733              .await
 734              .unwrap();
 735          let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap();
 736  
 737          assert_eq!(auth_peer_a, peer_b);
 738          assert_eq!(auth_peer_b, peer_a);
 739  
 740          conn_a.send(42).await.unwrap();
 741          conn_b.send(21).await.unwrap();
 742  
 743          assert_eq!(conn_a.next().await.unwrap().unwrap(), 21);
 744          assert_eq!(conn_b.next().await.unwrap().unwrap(), 42);
 745      }
 746  
 747      #[tokio::test]
 748      async fn test_unreliable_components() {
 749          assert!(!FailureRate::new(0f64).random_fail());
 750          assert!(FailureRate::new(1f64).random_fail());
 751  
 752          let good_interval = (0..=3).contains(
 753              &LatencyInterval::new(Duration::from_millis(0), Duration::from_millis(3))
 754                  .random()
 755                  .as_millis(),
 756          );
 757          assert!(good_interval);
 758  
 759          let (a, b) = tokio::io::duplex(43_689);
 760          let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable);
 761          let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable);
 762          assert!(a_stream.write(&[1, 2, 3]).await.is_ok());
 763          assert!(a_stream.flush().await.is_ok());
 764          assert_eq!(b_stream.read_u8().await.unwrap(), 1);
 765          assert_eq!(b_stream.read_u8().await.unwrap(), 2);
 766          assert_eq!(b_stream.read_u8().await.unwrap(), 3);
 767  
 768          let (a, b) = tokio::io::duplex(43_689);
 769          let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::FullyReliable);
 770          let mut b_stream = UnreliableDuplexStream::new(b, StreamReliability::BROKEN);
 771          assert!(a_stream.write(&[1, 2, 3]).await.is_ok());
 772          assert!(a_stream.flush().await.is_ok());
 773          assert!(b_stream.read_u8().await.is_err());
 774  
 775          let (a, b) = tokio::io::duplex(43_689);
 776          let mut a_stream = UnreliableDuplexStream::new(a, StreamReliability::BROKEN);
 777          let mut _b_stream = UnreliableDuplexStream::new(b, StreamReliability::FullyReliable);
 778          assert!(a_stream.write(&[1, 2, 3]).await.is_err());
 779          // a read on _b_stream would block...
 780      }
 781  
 782      #[allow(dead_code)]
 783      async fn timeout<F, T>(f: F) -> Option<T>
 784      where
 785          F: Future<Output = T>,
 786      {
 787          tokio::time::timeout(Duration::from_secs(1), f).await.ok()
 788      }
 789  
 790      #[tokio::test]
 791      async fn test_large_messages() {
 792          let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
 793          let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
 794          let peer_a = PeerId::from(1);
 795          let peer_b = PeerId::from(2);
 796  
 797          let net = MockNetwork::new();
 798          let conn_a = net.connector(peer_a, StreamReliability::FullyReliable);
 799          let conn_b = net.connector(peer_b, StreamReliability::FullyReliable);
 800  
 801          let mut listener = Connector::<Vec<u8>>::listen(&conn_a, bind_addr)
 802              .await
 803              .unwrap();
 804          let conn_a_fut = spawn("listener next await", async move {
 805              listener.next().await.unwrap().unwrap()
 806          });
 807  
 808          let (auth_peer_b, mut conn_b) = Connector::<Vec<u8>>::connect_framed(&conn_b, url, peer_a)
 809              .await
 810              .unwrap();
 811          let (auth_peer_a, mut conn_a) = conn_a_fut.await.unwrap();
 812  
 813          assert_eq!(auth_peer_a, peer_b);
 814          assert_eq!(auth_peer_b, peer_a);
 815  
 816          let send_future = async move {
 817              conn_a.send(vec![42; 16000]).await.unwrap();
 818          }
 819          .boxed();
 820          let receive_future = async move {
 821              assert_eq!(
 822                  timeout(conn_b.next()).await.unwrap().unwrap().unwrap(),
 823                  vec![42; 16000]
 824              );
 825          }
 826          .boxed();
 827  
 828          tokio::join!(send_future, receive_future);
 829      }
 830  }
 831  
 832  #[cfg(test)]
 833  mod tests {
 834      use std::net::SocketAddr;
 835  
 836      use fedimint_core::runtime::spawn;
 837      use fedimint_core::util::SafeUrl;
 838      use fedimint_core::PeerId;
 839      use futures::{SinkExt, StreamExt};
 840  
 841      use crate::config::gen_cert_and_key;
 842      use crate::net::connect::{ConnectionListener, Connector, TlsConfig};
 843      use crate::net::framed::AnyFramedTransport;
 844      use crate::TlsTcpConnector;
 845  
 846      fn gen_connector_config(count: usize) -> Vec<TlsConfig> {
 847          let peer_keys = (0..count)
 848              .map(|id| {
 849                  let peer = PeerId::from(id as u16);
 850                  gen_cert_and_key(&format!("peer-{}", peer.to_usize())).unwrap()
 851              })
 852              .collect::<Vec<_>>();
 853  
 854          peer_keys
 855              .iter()
 856              .map(|(_cert, key)| TlsConfig {
 857                  our_private_key: key.clone(),
 858                  peer_certs: peer_keys
 859                      .iter()
 860                      .enumerate()
 861                      .map(|(peer, (cert, _))| (PeerId::from(peer as u16), cert.clone()))
 862                      .collect(),
 863                  peer_names: peer_keys
 864                      .iter()
 865                      .enumerate()
 866                      .map(|(peer, (_, _))| (PeerId::from(peer as u16), format!("peer-{peer}")))
 867                      .collect(),
 868              })
 869              .collect()
 870      }
 871  
 872      #[tokio::test]
 873      async fn connect_success() {
 874          // FIXME: don't actually bind here, probably requires yet another Box<dyn Trait>
 875          // layer :(
 876          let bind_addr: SocketAddr = "127.0.0.1:7000".parse().unwrap();
 877          let url: SafeUrl = "ws://127.0.0.1:7000".parse().unwrap();
 878          let connectors = gen_connector_config(5)
 879              .into_iter()
 880              .enumerate()
 881              .map(|(id, cfg)| TlsTcpConnector::new(cfg, PeerId::from(id as u16)))
 882              .collect::<Vec<_>>();
 883  
 884          let mut server: ConnectionListener<u64> = connectors[0].listen(bind_addr).await.unwrap();
 885  
 886          let server_task = spawn("server next await", async move {
 887              let (peer, mut conn) = server.next().await.unwrap().unwrap();
 888              assert_eq!(peer.to_usize(), 2);
 889              let received = conn.next().await.unwrap().unwrap();
 890              assert_eq!(received, 42);
 891              conn.send(21).await.unwrap();
 892              assert!(conn.next().await.unwrap().is_err());
 893          });
 894  
 895          let (peer_of_a, mut client_a): (_, AnyFramedTransport<u64>) = connectors[2]
 896              .connect_framed(url.clone(), PeerId::from(0))
 897              .await
 898              .unwrap();
 899          assert_eq!(peer_of_a.to_usize(), 0);
 900          client_a.send(42).await.unwrap();
 901          let received = client_a.next().await.unwrap().unwrap();
 902          assert_eq!(received, 21);
 903          drop(client_a);
 904  
 905          server_task.await.unwrap();
 906      }
 907  
 908      #[tokio::test]
 909      async fn connect_reject() {
 910          let bind_addr: SocketAddr = "127.0.0.1:7001".parse().unwrap();
 911          let url: SafeUrl = "wss://127.0.0.1:7001".parse().unwrap();
 912          let cfg = gen_connector_config(5);
 913  
 914          let honest = TlsTcpConnector::new(cfg[0].clone(), PeerId::from(0));
 915  
 916          let mut malicious_wrong_key_cfg = cfg[1].clone();
 917          malicious_wrong_key_cfg.our_private_key = cfg[2].our_private_key.clone();
 918          let malicious_wrong_key = TlsTcpConnector::new(malicious_wrong_key_cfg, PeerId::from(1));
 919  
 920          // Honest server, malicious client with wrong private key
 921          {
 922              let mut server: ConnectionListener<u64> = honest.listen(bind_addr).await.unwrap();
 923  
 924              let server_task = spawn("server next await", async move {
 925                  let conn_res = server.next().await.unwrap();
 926                  assert_eq!(
 927                      conn_res.err().unwrap().to_string().as_str(),
 928                      "invalid peer certificate: BadSignature"
 929                  );
 930              });
 931  
 932              let err_anytime = async {
 933                  let (_peer, mut conn): (_, AnyFramedTransport<u64>) = malicious_wrong_key
 934                      .connect_framed(url.clone(), PeerId::from(0))
 935                      .await?;
 936  
 937                  conn.send(42).await?;
 938                  conn.flush().await?;
 939                  conn.next().await.unwrap()?;
 940  
 941                  Result::<_, anyhow::Error>::Ok(())
 942              };
 943  
 944              let conn_res = err_anytime.await;
 945              assert_eq!(
 946                  conn_res.err().unwrap().to_string().as_str(),
 947                  "received fatal alert: DecryptError"
 948              );
 949  
 950              server_task.await.unwrap();
 951          }
 952  
 953          // Malicious server with wrong key, honest client
 954          {
 955              let mut server: ConnectionListener<u64> =
 956                  malicious_wrong_key.listen(bind_addr).await.unwrap();
 957  
 958              let server_task = spawn("server next await", async move {
 959                  let conn_res = server.next().await.unwrap();
 960                  assert_eq!(
 961                      conn_res.err().unwrap().to_string().as_str(),
 962                      "received fatal alert: DecryptError"
 963                  );
 964              });
 965  
 966              let err_anytime = async {
 967                  let (_peer, mut conn): (_, AnyFramedTransport<u64>) =
 968                      honest.connect_framed(url.clone(), PeerId::from(1)).await?;
 969  
 970                  conn.send(42).await?;
 971                  conn.flush().await?;
 972                  conn.next().await.unwrap()?;
 973  
 974                  Result::<_, anyhow::Error>::Ok(())
 975              };
 976  
 977              let conn_res = err_anytime.await;
 978              assert_eq!(
 979                  conn_res.err().unwrap().to_string().as_str(),
 980                  "invalid peer certificate: BadSignature"
 981              );
 982  
 983              server_task.await.unwrap();
 984          }
 985  
 986          // Server with wrong certificate, honest client
 987          {
 988              let mut server: ConnectionListener<u64> =
 989                  TlsTcpConnector::new(cfg[2].clone(), PeerId::from(2))
 990                      .listen(bind_addr)
 991                      .await
 992                      .unwrap();
 993  
 994              let server_task = spawn("server next await", async move {
 995                  let conn_res = server.next().await.unwrap();
 996                  assert_eq!(
 997                      conn_res.err().unwrap().to_string().as_str(),
 998                      "received fatal alert: BadCertificate"
 999                  );
1000              });
1001  
1002              let err_anytime = async {
1003                  let (_peer, mut conn): (_, AnyFramedTransport<u64>) =
1004                      honest.connect_framed(url.clone(), PeerId::from(0)).await?;
1005  
1006                  conn.send(42).await?;
1007                  conn.flush().await?;
1008                  conn.next().await.unwrap()?;
1009  
1010                  Result::<_, anyhow::Error>::Ok(())
1011              };
1012  
1013              let conn_res = err_anytime.await;
1014              assert_eq!(
1015                  conn_res.err().unwrap().to_string().as_str(),
1016                  "invalid peer certificate: NotValidForName"
1017              );
1018  
1019              server_task.await.unwrap();
1020          }
1021      }
1022  }