multiplexed.rs
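//! Multiplexing of module-addressed messages over a single set of peer
//! connections, buffering messages that arrive out of order per key until
//! they are requested.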
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;
use std::hash::Hash;

use async_trait::async_trait;
use fedimint_core::net::peers::{IMuxPeerConnections, PeerConnections};
use fedimint_core::runtime::spawn;
use fedimint_core::task::{Cancellable, Cancelled};
use fedimint_core::PeerId;
use fedimint_logging::LOG_NET_PEER;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::oneshot;
use tracing::{debug, warn};

/// TODO: Use proper ModuleId after modularization is complete
pub type ModuleId = String;
pub type ModuleIdRef<'a> = &'a str;

/// Number of pending messages per peer above which we start dropping new ones.
///
/// It's hard to predict how many messages is too many, but we have
/// to draw the line somewhere.
pub const MAX_PEER_OUT_OF_ORDER_MESSAGES: u64 = 10000;

/// A `Msg` that can target a specific destination module
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ModuleMultiplexed<MuxKey, Msg> {
    pub key: MuxKey,
    pub msg: Msg,
}

struct ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    /// Cached messages per `MuxKey` waiting for a callback
    msgs: HashMap<MuxKey, VecDeque<(PeerId, Msg)>>,
    /// Callback queue from tasks that want to receive
    callbacks: HashMap<MuxKey, VecDeque<oneshot::Sender<(PeerId, Msg)>>>,
    /// Track pending messages per peer to avoid a potential DoS
    peer_counts: HashMap<PeerId, u64>,
}

impl<MuxKey, Msg> Default for ModuleMultiplexerOutOfOrder<MuxKey, Msg> {
    fn default() -> Self {
        Self {
            msgs: Default::default(),
            callbacks: Default::default(),
            peer_counts: Default::default(),
        }
    }
}

/// A wrapper around `PeerConnections` multiplexing communication between
/// multiple modules over it
///
/// This works by addressing each module when sending, and by buffering
/// messages received out of order until they are requested.
///
/// This type is thread-safe and can be cheaply cloned.
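///
/// A minimal usage sketch (not compiled as a doctest; it assumes `connections`
/// is an already-established `PeerConnections<ModuleMultiplexed<String, String>>`
/// and `peer` is a known `PeerId`):
///
/// ```ignore
/// let mux = PeerConnectionMultiplexer::new(connections);
/// // Sends are tagged with a key so the receiving side can route them.
/// mux.send(&[peer], "module-a".to_string(), "hello".to_string()).await?;
/// // Receives block until a message for this key arrives, no matter how many
/// // messages for other keys were delivered in between.
/// let (sender, msg) = mux.receive("module-a".to_string()).await?;
/// ```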
#[derive(Clone)]
pub struct PeerConnectionMultiplexer<MuxKey, Msg> {
    /// Sender of send requests
    send_requests_tx: Sender<(Vec<PeerId>, MuxKey, Msg)>,
    /// Sender of receive callbacks
    receive_callbacks_tx: Sender<Callback<MuxKey, Msg>>,
    /// Sender of peer bans
    peer_bans_tx: Sender<PeerId>,
}

type Callback<MuxKey, Msg> = (MuxKey, oneshot::Sender<(PeerId, Msg)>);

impl<MuxKey, Msg> PeerConnectionMultiplexer<MuxKey, Msg>
where
    Msg: Serialize + DeserializeOwned + Unpin + Send + Debug + 'static,
    MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone + 'static,
{
    pub fn new(connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>) -> Self {
        let (send_requests_tx, send_requests_rx) = channel(1000);
        let (receive_callbacks_tx, receive_callbacks_rx) = channel(1000);
        let (peer_bans_tx, peer_bans_rx) = channel(1000);

        spawn(
            "peer connection multiplexer",
            Self::run(
                connections,
                Default::default(),
                send_requests_rx,
                receive_callbacks_rx,
                peer_bans_rx,
            ),
        );

        Self {
            send_requests_tx,
            receive_callbacks_tx,
            peer_bans_tx,
        }
    }

    async fn run(
        mut connections: PeerConnections<ModuleMultiplexed<MuxKey, Msg>>,
        mut out_of_order: ModuleMultiplexerOutOfOrder<MuxKey, Msg>,
        mut send_requests_rx: Receiver<(Vec<PeerId>, MuxKey, Msg)>,
        mut receive_callbacks_rx: Receiver<Callback<MuxKey, Msg>>,
        mut peer_bans_rx: Receiver<PeerId>,
    ) -> Cancellable<()> {
        loop {
            let mut key_inserted: Option<MuxKey> = None;
            tokio::select! {
                // Send requests are forwarded to the underlying connections
                send_request = send_requests_rx.recv() => {
                    let (peers, key, msg) = send_request.ok_or(Cancelled)?;
                    connections.send(&peers, ModuleMultiplexed { key, msg }).await?;
                }
                // Ban requests are forwarded to the underlying connections
                peer_ban = peer_bans_rx.recv() => {
                    let peer = peer_ban.ok_or(Cancelled)?;
                    connections.ban_peer(peer).await;
                }
                // Receive callbacks are added to the callback queue by key
                receive_callback = receive_callbacks_rx.recv() => {
                    let (key, callback) = receive_callback.ok_or(Cancelled)?;
                    out_of_order.callbacks.entry(key.clone()).or_default().push_back(callback);
                    key_inserted = Some(key);
                }
                // Received messages are added to the message queue by key
                receive = connections.receive() => {
                    let (peer, ModuleMultiplexed { key, msg }) = receive?;
                    let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
                    // We limit the number of queued messages from any given peer
                    // to avoid running out of memory. In practice, hitting this
                    // limit would halt DKG.
                    if *peer_pending > MAX_PEER_OUT_OF_ORDER_MESSAGES {
                        warn!(
                            target: LOG_NET_PEER,
                            "Peer {peer} has {peer_pending} pending messages. Dropping new message."
                        );
                    } else {
                        *peer_pending += 1;
                        out_of_order.msgs.entry(key.clone()).or_default().push_back((peer, msg));
                        key_inserted = Some(key);
                    }
                }
            }

            // If a key was inserted, check whether we can fulfill a callback
            if let Some(key) = key_inserted {
                let callbacks = out_of_order.callbacks.entry(key.clone()).or_default();
                let msgs = out_of_order.msgs.entry(key.clone()).or_default();

                if !callbacks.is_empty() && !msgs.is_empty() {
                    let callback = callbacks.pop_front().expect("checked");
                    let (peer, msg) = msgs.pop_front().expect("checked");
                    let peer_pending = out_of_order.peer_counts.entry(peer).or_default();
                    *peer_pending -= 1;
                    callback.send((peer, msg)).map_err(|_| Cancelled)?;
                }
            }
        }
    }
}

#[async_trait]
impl<MuxKey, Msg> IMuxPeerConnections<MuxKey, Msg> for PeerConnectionMultiplexer<MuxKey, Msg>
where
    Msg: Serialize + DeserializeOwned + Unpin + Send + Debug,
    MuxKey: Serialize + DeserializeOwned + Unpin + Send + Debug + Eq + Hash + Clone,
{
    async fn send(&self, peers: &[PeerId], key: MuxKey, msg: Msg) -> Cancellable<()> {
        debug!("Sending to {peers:?}/{key:?}, {msg:?}");
        self.send_requests_tx
            .send((peers.to_vec(), key, msg))
            .await
            .map_err(|_e| Cancelled)
    }

    /// Await receipt of a message from any connected peer.
    async fn receive(&self, key: MuxKey) -> Cancellable<(PeerId, Msg)> {
        let (callback_tx, callback_rx) = oneshot::channel();
        self.receive_callbacks_tx
            .send((key, callback_tx))
            .await
            .map_err(|_e| Cancelled)?;
        callback_rx.await.map_err(|_e| Cancelled)
    }

    async fn ban_peer(&self, peer: PeerId) {
        // We don't return a `Cancellable` for bans
        let _ = self.peer_bans_tx.send(peer).await;
    }
}

#[cfg(test)]
pub mod test {
    use std::time::Duration;

    use fedimint_core::net::peers::fake::make_fake_peer_connection;
    use fedimint_core::net::peers::IMuxPeerConnections;
    use fedimint_core::task::{self, TaskGroup};
    use fedimint_core::PeerId;
    use rand::rngs::OsRng;
    use rand::seq::SliceRandom;
    use rand::{thread_rng, Rng};

    use crate::multiplexed::PeerConnectionMultiplexer;

    /// Send many messages over a multiplexed fake link
    ///
    /// Some things this is checking for:
    ///
    /// * no messages were missed
    /// * messages arrived in order (from the PoV of each module)
    /// * nothing deadlocked anywhere.
    #[test_log::test(tokio::test)]
    async fn test_multiplexer() {
        const NUM_MODULES: usize = 128;
        const NUM_MSGS_PER_MODULE: usize = 128;
        const NUM_REPEAT_TEST: usize = 10;

        for _ in 0..NUM_REPEAT_TEST {
            let task_group = TaskGroup::new();
            let task_handle = task_group.make_handle();

            let peer1 = PeerId::from(0);
            let peer2 = PeerId::from(1);

            let (conn1, conn2) = make_fake_peer_connection(peer1, peer2, 1000, task_handle.clone());
            let (conn1, conn2) = (
                PeerConnectionMultiplexer::new(conn1).into_dyn(),
                PeerConnectionMultiplexer::new(conn2).into_dyn(),
            );

            let mut modules: Vec<_> = (0..NUM_MODULES).collect();
            modules.shuffle(&mut thread_rng());

            for mux_key in modules.clone() {
                let conn1 = conn1.clone();
                let task_handle = task_handle.clone();
                task_group.spawn(format!("sender-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        // add some random jitter
                        if OsRng.gen() {
                            // Note that the randomized sleep in the sender is
                            // larger than in the receiver, to avoid just running
                            // with always-full queues.
                            task::sleep(Duration::from_millis(2)).await;
                        }
                        if task_handle.is_shutting_down() {
                            break;
                        }
                        conn1.send(&[peer2], mux_key, msg_i).await.unwrap();
                    }
                });
            }

            modules.shuffle(&mut thread_rng());
            for mux_key in modules.clone() {
                let conn2 = conn2.clone();
                task_group.spawn(format!("receiver-{mux_key}"), move |_| async move {
                    for msg_i in 0..NUM_MSGS_PER_MODULE {
                        // add some random jitter
                        if OsRng.gen() {
                            task::sleep(Duration::from_millis(1)).await;
                        }
                        assert_eq!(conn2.receive(mux_key).await.unwrap(), (peer1, msg_i));
                    }
                });
            }

            task_group.join_all(None).await.expect("no failures");
        }
    }
}