// fedimint-server/src/multiplexed.rs
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;

use async_trait::async_trait;
use fedimint_core::net::peers::{IMuxPeerConnections, PeerConnections};
use fedimint_core::runtime::spawn;
use fedimint_core::task::{Cancellable, Cancelled};
use fedimint_core::PeerId;
use fedimint_logging::LOG_NET_PEER;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
use tracing::{debug, warn};

/// TODO: Use proper ModuleId after modularization is complete
pub type ModuleId = String;
pub type ModuleIdRef<'a> = &'a str;

/// Number of pending out-of-order messages per peer after which we start
/// dropping new ones.
///
/// It's hard to predict how many messages are too many, but we have
/// to draw the line somewhere.
pub const MAX_PEER_OUT_OF_ORDER_MESSAGES: u64 = 10000;

/// A `Msg` that can target a specific destination module
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModuleMultiplexed<MuxKey, Msg> {
    pub key: MuxKey,
    pub msg: Msg,
}

struct ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    /// Cached messages per `MuxKey` waiting for a callback
    msgs: HashMap<MuxKey, VecDeque<(PeerId, Msg)>>,
    /// Callback queue from tasks that want to receive
    callbacks: HashMap<MuxKey, VecDeque<oneshot::Sender<(PeerId, Msg)>>>,
    /// Track pending messages per peer to avoid a potential DoS
    peer_counts: HashMap<PeerId, u64>,
}

impl<MuxKey, Msg> Default for ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    fn default() -> Self {
        Self {
            msgs: Default::default(),
            callbacks: Default::default(),
            peer_counts: Default::default(),
        }
    }
}

/// A wrapper around `PeerConnections`, multiplexing communication between
/// multiple modules over it
///
/// This works by addressing each module when sending, and by buffering
/// messages received out of order until they are requested.
///
/// This type is thread-safe and can be cheaply cloned.
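///
/// # Example
///
/// A minimal usage sketch; `connections` stands for an already established
/// `PeerConnections` value, and the `"wallet"` key, peer id, and `msg` are
/// purely illustrative (marked `ignore` since it needs live connections):
///
/// ```ignore
/// // `send`/`receive` come from the `IMuxPeerConnections` trait.
/// let mux = PeerConnectionMultiplexer::new(connections);
/// // Address a message to the "wallet" key on peer 1...
/// mux.send(&[PeerId::from(1)], "wallet".to_string(), msg).await?;
/// // ...then wait for any peer's next message under the same key.
/// let (peer, reply) = mux.receive("wallet".to_string()).await?;
/// ```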
#[derive(Clone)]
pub struct PeerConnectionMultiplexer<MuxKey, Msg> {
    /// Sender of send requests
    send_requests_tx: Sender<(Vec<PeerId>, MuxKey, Msg)>,
    /// Sender of receive callbacks
    receive_callbacks_tx: Sender<Callback<MuxKey, Msg>>,
    /// Sender of peer bans
    peer_bans_tx: Sender<PeerId>,
}

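/// A pending receive request: the mux key to wait on, plus a oneshot sender
/// used to deliver the matching `(PeerId, Msg)` back to the caller.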
type Callback<MuxKey, Msg> = (MuxKey, oneshot::Sender<(PeerId, Msg)>);

impl<MuxKey, Msg> PeerConnectionMultiplexer<MuxKey, Msg>
where
    Msg: Serialize + DeserializeOwned + Unpin + Send + Debug + 'static,
    MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone + 'static,
{
    pub fn new(connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>) -> Self {
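        // Bounded channels, so a busy event loop applies backpressure to
        // callers instead of buffering their requests without limit.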
        let (send_requests_tx, send_requests_rx) = channel(1000);
        let (receive_callbacks_tx, receive_callbacks_rx) = channel(1000);
        let (peer_bans_tx, peer_bans_rx) = channel(1000);

        spawn(
            "peer connection multiplexer",
            Self::run(
                connections,
                Default::default(),
                send_requests_rx,
                receive_callbacks_rx,
                peer_bans_rx,
            ),
        );

        Self {
            send_requests_tx,
            receive_callbacks_tx,
            peer_bans_tx,
        }
    }

    async fn run(
        mut connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>,
        mut out_of_order: ModuleMultiplexerOutOfOrder<MuxKey, Msg>,
        mut send_requests_rx: Receiver<(Vec<PeerId>, MuxKey, Msg)>,
        mut receive_callbacks_rx: Receiver<Callback<MuxKey, Msg>>,
        mut peer_bans_rx: Receiver<PeerId>,
    ) -> Cancellable<()> {
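        // This event loop exclusively owns the multiplexer state; the cloned
        // `PeerConnectionMultiplexer` handles hold only channel senders, which
        // is what makes them cheap to clone and thread-safe.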
        loop {
            let mut key_inserted: Option<MuxKey> = None;
            tokio::select! {
                // Send requests are forwarded to the underlying connections
                send_request = send_requests_rx.recv() => {
                    let (peers, key, msg) = send_request.ok_or(Cancelled)?;
                    connections.send(&peers, ModuleMultiplexed { key, msg }).await?;
                }
                // Ban requests are forwarded to the underlying connections
                peer_ban = peer_bans_rx.recv() => {
                    let peer = peer_ban.ok_or(Cancelled)?;
                    connections.ban_peer(peer).await;
                }
                // Receive callbacks are added to the callback queue by key
                receive_callback = receive_callbacks_rx.recv() => {
                    let (key, callback) = receive_callback.ok_or(Cancelled)?;
                    out_of_order.callbacks.entry(key.clone()).or_default().push_back(callback);
                    key_inserted = Some(key);
                }
                // Received messages are added to the message queue by key
                receive = connections.receive() => {
                    let (peer, ModuleMultiplexed { key, msg }) = receive?;
                    let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
                    // We limit buffered messages from any given peer to avoid OOM;
                    // in practice, hitting this limit would halt DKG.
                    if *peer_pending > MAX_PEER_OUT_OF_ORDER_MESSAGES {
                        warn!(
                            target: LOG_NET_PEER,
                            "Peer {peer} has {peer_pending} pending messages. Dropping new message."
                        );
                    } else {
                        *peer_pending += 1;
                        out_of_order.msgs.entry(key.clone()).or_default().push_back((peer, msg));
                        key_inserted = Some(key);
                    }
                }
            }

            // If a key was inserted, check to see if we can fulfill a callback
            if let Some(key) = key_inserted {
                let callbacks = out_of_order.callbacks.entry(key.clone()).or_default();
                let msgs = out_of_order.msgs.entry(key.clone()).or_default();

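                // Each select arm above queues at most one new message or
                // callback, so at most one pairing can become ready here.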
                if !callbacks.is_empty() && !msgs.is_empty() {
                    let callback = callbacks.pop_front().expect("checked");
                    let (peer, msg) = msgs.pop_front().expect("checked");
                    let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
                    *peer_pending -= 1;
                    callback.send((peer, msg)).map_err(|_| Cancelled)?;
                }
            }
        }
    }
}

#[async_trait]
impl<MuxKey, Msg> IMuxPeerConnections<MuxKey, Msg> for PeerConnectionMultiplexer<MuxKey, Msg>
where
    Msg: Serialize + DeserializeOwned + Unpin + Send + Debug,
    MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone,
{
    async fn send(&self, peers: &[PeerId], key: MuxKey, msg: Msg) -> Cancellable<()> {
        debug!("Sending to {peers:?}/{key:?}, {msg:?}");
        self.send_requests_tx
            .send((peers.to_vec(), key, msg))
            .await
            .map_err(|_e| Cancelled)
    }

    /// Await receipt of a message from any connected peer.
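    ///
    /// Messages under the same key are delivered in arrival order; a message
    /// that arrives before anyone asks for it is buffered until requested.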
    async fn receive(&self, key: MuxKey) -> Cancellable<(PeerId, Msg)> {
        let (callback_tx, callback_rx) = oneshot::channel();
        self.receive_callbacks_tx
            .send((key, callback_tx))
            .await
            .map_err(|_e| Cancelled)?;
        callback_rx.await.map_err(|_e| Cancelled)
    }

    async fn ban_peer(&self, peer: PeerId) {
        // We don't return a `Cancellable` for bans
        let _ = self.peer_bans_tx.send(peer).await;
    }
}

#[cfg(test)]
pub mod test {
    use std::time::Duration;

    use fedimint_core::net::peers::fake::make_fake_peer_connection;
    use fedimint_core::net::peers::IMuxPeerConnections;
    use fedimint_core::task::{self, TaskGroup};
    use fedimint_core::PeerId;
    use rand::rngs::OsRng;
    use rand::seq::SliceRandom;
    use rand::{thread_rng, Rng};

    use crate::multiplexed::PeerConnectionMultiplexer;

    /// Send many messages over a multiplexed fake link
    ///
    /// Some things this is checking for:
    ///
    /// * no messages were missed
    /// * messages arrived in order (from the PoV of each module)
    /// * nothing deadlocked anywhere
    #[test_log::test(tokio::test)]
    async fn test_multiplexer() {
        const NUM_MODULES: usize = 128;
        const NUM_MSGS_PER_MODULE: usize = 128;
        const NUM_REPEAT_TEST: usize = 10;

        for _ in 0..NUM_REPEAT_TEST {
            let task_group = TaskGroup::new();
            let task_handle = task_group.make_handle();

            let peer1 = PeerId::from(0);
            let peer2 = PeerId::from(1);

            let (conn1, conn2) = make_fake_peer_connection(peer1, peer2, 1000, task_handle.clone());
            let (conn1, conn2) = (
                PeerConnectionMultiplexer::new(conn1).into_dyn(),
                PeerConnectionMultiplexer::new(conn2).into_dyn(),
            );

            let mut modules: Vec<_> = (0..NUM_MODULES).collect();
            modules.shuffle(&mut thread_rng());

            for mux_key in modules.clone() {
                let conn1 = conn1.clone();
                let task_handle = task_handle.clone();
                task_group.spawn(format!("sender-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        // add some random jitter
                        if OsRng.gen() {
                            // Note that the randomized sleep in the sender is
                            // larger than in the receiver, to avoid just running
                            // with always-full queues.
                            task::sleep(Duration::from_millis(2)).await;
                        }
                        if task_handle.is_shutting_down() {
                            break;
                        }
                        conn1.send(&[peer2], mux_key, msg_i).await.unwrap();
                    }
                });
            }

            modules.shuffle(&mut thread_rng());
            for mux_key in modules.clone() {
                let conn2 = conn2.clone();
                task_group.spawn(format!("receiver-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        // add some random jitter
                        if OsRng.gen() {
                            task::sleep(Duration::from_millis(1)).await;
                        }
                        assert_eq!(conn2.receive(mux_key).await.unwrap(), (peer1, msg_i));
                    }
                });
            }

            task_group.join_all(None).await.expect("no failures");
        }
    }
}