instance.rs
1 use std::sync::Arc; 2 3 use dashmap::DashMap; 4 use distrox_network::connection::runner::Blocklist; 5 use iroh::Endpoint; 6 use iroh::EndpointAddr; 7 use iroh::EndpointId; 8 use iroh::discovery::EndpointData; 9 use tokio_util::sync::CancellationToken; 10 use tracing::Instrument; 11 12 use crate::command::InstanceCommand; 13 use crate::error::Error; 14 use crate::event::InstanceEvent; 15 16 pub struct Instance { 17 instance_cancellation_token: CancellationToken, 18 endpoint: Endpoint, 19 20 connected_nodes: Arc<DashMap<EndpointId, distrox_network::connection::ConnectionHandle>>, 21 known_mdns_endpoints: Arc<DashMap<EndpointId, EndpointData>>, 22 23 instance_command_sender: tokio::sync::mpsc::Sender<InstanceCommand>, 24 instance_command_recv: tokio::sync::mpsc::Receiver<InstanceCommand>, 25 26 blocklist: Blocklist, 27 } 28 29 impl Drop for Instance { 30 fn drop(&mut self) { 31 self.instance_cancellation_token.cancel(); 32 } 33 } 34 35 impl std::fmt::Debug for Instance { 36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 37 f.debug_struct("Instance").finish_non_exhaustive() 38 } 39 } 40 41 impl Instance { 42 pub async fn new( 43 secret_key: distrox_keys::private::PrivateSigningKey, 44 relay_mode: iroh::RelayMode, 45 blocklist: Blocklist, 46 ) -> Result<Self, Error> { 47 let instance_cancellation_token = CancellationToken::new(); 48 49 tracing::info!(relay_mode = ?relay_mode, "Booting endpoint"); 50 51 let endpoint = secret_key 52 .set_into(Endpoint::empty_builder(iroh::RelayMode::Disabled)) 53 .relay_mode(relay_mode) 54 // TODO For now this should be enough 55 .discovery(iroh::discovery::mdns::MdnsDiscovery::builder().advertise(true)) 56 .alpns(vec![distrox_network::protocol::ALPN.to_vec()]) 57 .bind() 58 .await 59 .map_err(Error::EndpointBind)?; 60 61 let mdns = iroh::discovery::mdns::MdnsDiscovery::builder() 62 .advertise(true) 63 .service_name("distrox") 64 .build(endpoint.id()) 65 .map_err(Error::BuildingMdnsDiscovery)?; 66 67 endpoint.discovery().add(mdns.clone()); 68 69 let known_mdns_endpoints = Arc::new(DashMap::new()); 70 71 tokio::spawn(observe_mdns_endpoints( 72 mdns, 73 known_mdns_endpoints.clone(), 74 instance_cancellation_token.clone(), 75 )); 76 77 tracing::info!(endpoint_id = %endpoint.id().fmt_short(), "Booted endpoint"); 78 79 tracing::debug!("Spawning background worker"); 80 81 tracing::debug!("Instance started"); 82 let (instance_command_sender, instance_command_recv) = tokio::sync::mpsc::channel(100); 83 84 Ok(Self { 85 instance_cancellation_token, 86 endpoint, 87 instance_command_sender, 88 instance_command_recv, 89 90 connected_nodes: Arc::new(DashMap::new()), 91 known_mdns_endpoints, 92 93 blocklist, 94 }) 95 } 96 97 pub fn handle(&self) -> crate::handle::InstanceHandle { 98 crate::handle::InstanceHandle { 99 connections: self.connected_nodes.clone(), 100 endpoint_id: self.endpoint.id(), 101 endpoint_addr: self.endpoint.addr(), 102 sender: self.instance_command_sender.clone(), 103 blocklist: self.blocklist.clone(), 104 instance_cancellation_token: self.instance_cancellation_token.clone(), 105 } 106 } 107 108 pub async fn run(mut self) -> impl futures::Stream<Item = InstanceEvent> { 109 let endpoint = self.endpoint.clone(); 110 111 let (event_sender, event_recv) = tokio::sync::mpsc::channel(1024); 112 tokio::task::spawn({ 113 let instance_cancellation_token = self.instance_cancellation_token.clone(); 114 115 async move { 116 loop { 117 tokio::select! { 118 instance_command = self.instance_command_recv.recv() => { 119 if let Some(instance_command) = instance_command 120 && self.handle_instance_command(instance_command).await == std::ops::ControlFlow::Break(()) { 121 break 122 } 123 } 124 125 incoming = endpoint.accept() => { 126 if let Some(incoming) = incoming { 127 let event_sender = event_sender.clone(); 128 let instance_cancellation_token = self.instance_cancellation_token.clone(); 129 130 tokio::task::spawn(async move { 131 match instance_cancellation_token.run_until_cancelled(event_sender.send(InstanceEvent::Incoming(incoming))).await { 132 Some(Ok(_)) => tracing::trace!("Successfully send event to processing loop"), 133 Some(Err(_)) => tracing::error!("Internal channel error"), 134 None => tracing::debug!("Instance cancelled"), 135 } 136 }); 137 } else { 138 tracing::error!("Accepting failed, Instance cannot continue to run"); 139 break 140 } 141 } 142 143 _cancelled = instance_cancellation_token.cancelled() => { 144 tracing::info!("Instance cancelled, finish processing"); 145 break 146 } 147 } 148 } 149 } 150 }); 151 152 tokio_stream::wrappers::ReceiverStream::new(event_recv) 153 } 154 155 async fn handle_instance_command(&self, command: InstanceCommand) -> std::ops::ControlFlow<()> { 156 match command { 157 InstanceCommand::Shutdown(reply_sender) => { 158 if reply_sender.send(()).is_err() { 159 tracing::error!("Internal channel error"); 160 } 161 return std::ops::ControlFlow::Break(()); 162 } 163 164 InstanceCommand::WaitOnline(reply_sender) => { 165 self.endpoint 166 .online() 167 .instrument(tracing::error_span!("instance.wait_until_online")) 168 .await; 169 170 if reply_sender.send(()).is_err() { 171 tracing::error!("Internal channel error"); 172 } 173 } 174 InstanceCommand::ConnectTo { 175 target, 176 result_sender, 177 services, 178 } => { 179 let result = self.connect_to(target, services).await; 180 if result_sender.send(result).is_err() { 181 tracing::error!("Internal channel error"); 182 } 183 } 184 } 185 186 std::ops::ControlFlow::Continue(()) 187 } 188 189 /// Connect to a remote instance 190 /// 191 /// # Note 192 /// 193 /// If the instance is already connected to the remote pointed to by `node_addr`, the 194 /// [distrox_network::connection::ConnectionHandle] for that connection is returned. 195 async fn connect_to( 196 &self, 197 node_addr: EndpointAddr, 198 services: distrox_network::connection::Services, 199 ) -> Result<distrox_network::connection::ConnectionHandle, Error> { 200 if let Some(connection_handle) = self.connected_nodes.get(&node_addr.id) { 201 return Ok(connection_handle.clone()); 202 } 203 204 let endpoint_id_str = data_encoding::BASE32_NOPAD 205 .encode(node_addr.id.as_bytes()) 206 .to_ascii_lowercase(); 207 208 let connection_span = 209 tracing::error_span!("connection", to = ?node_addr, encoded_id = endpoint_id_str); 210 tracing::info!(parent: &connection_span, "Trying to connect to remote"); 211 212 let node_addr = if let Some(endpoint_data) = 213 self.known_mdns_endpoints.get(&node_addr.id).as_deref() 214 { 215 node_addr.with_addrs(endpoint_data.addrs().cloned().inspect(|addr| { 216 tracing::debug!(parent: &connection_span, ?addr, "Adding known address to EndpointAddr for connecting") 217 })) 218 } else { 219 tracing::debug!(parent: &connection_span, "No more known addresses for that EndpointAddr"); 220 node_addr 221 }; 222 223 tracing::debug!(parent: &connection_span, ?node_addr, "Connecting to node"); 224 let connection = self 225 .endpoint 226 .connect(node_addr.clone(), distrox_network::protocol::ALPN) 227 .instrument(connection_span.clone()) 228 .await 229 .inspect_err(|error| { 230 tracing::warn!( 231 parent: &connection_span, 232 ?error, 233 alpn = str::from_utf8(distrox_network::protocol::ALPN).unwrap_or_default(), 234 "Failed to connect to endpoint: {error:#}" 235 ); 236 }) 237 .map_err(|source| Error::Connecting { 238 node_addr: node_addr.clone(), 239 source, 240 })?; 241 242 let remote_endpoint_id = connection.remote_id(); 243 244 let (sink, stream) = connection 245 .open_bi() 246 .instrument(connection_span.clone()) 247 .await 248 .inspect_err(|error| { 249 tracing::warn!( 250 parent: &connection_span, 251 ?error, 252 "Failed to open channel to endpoint" 253 ); 254 }) 255 .map_err(|source| Error::OpeningChannel { 256 node_addr: node_addr.clone(), 257 source, 258 })?; 259 260 let channel = distrox_network::connection::Channel { 261 sink: Box::pin(tokio_util::codec::FramedWrite::new( 262 sink, 263 distrox_network::protocol::codec::Encoder, 264 )), 265 stream: Box::pin(tokio_util::codec::FramedRead::new( 266 stream, 267 distrox_network::protocol::codec::Decoder, 268 )), 269 }; 270 271 let connection = distrox_network::connection::Connection::new( 272 connection_span.clone(), 273 self.instance_cancellation_token.child_token(), 274 remote_endpoint_id, 275 channel, 276 services, 277 ); 278 279 self.connected_nodes 280 .insert(node_addr.id, connection.handle()); 281 282 tracing::info!(parent: connection_span, "Successfully connected"); 283 let connection_handle = connection.handle(); 284 tokio::task::spawn( 285 connection 286 .run() 287 .instrument(tracing::error_span!("connection")), 288 ); 289 Ok(connection_handle) 290 } 291 292 pub fn known_mdns_endpoints(&self) -> Arc<DashMap<iroh::PublicKey, EndpointData>> { 293 self.known_mdns_endpoints.clone() 294 } 295 296 pub fn connected_nodes( 297 &self, 298 ) -> Arc<DashMap<EndpointId, distrox_network::connection::ConnectionHandle>> { 299 self.connected_nodes.clone() 300 } 301 302 pub fn cancellation_token(&self) -> &CancellationToken { 303 &self.instance_cancellation_token 304 } 305 } 306 307 async fn observe_mdns_endpoints( 308 mdns: iroh::discovery::mdns::MdnsDiscovery, 309 known_mdns_endpoints: Arc<DashMap<EndpointId, EndpointData>>, 310 instance_cancellation_token: CancellationToken, 311 ) { 312 use futures::StreamExt; 313 314 let span = tracing::error_span!("instance.mdns.observation"); 315 316 let known_mdns_endpoints = known_mdns_endpoints.clone(); 317 let mut events = mdns.subscribe().instrument(span.clone()).await; 318 319 loop { 320 let Some(Some(event)) = instance_cancellation_token 321 .run_until_cancelled(events.next()) 322 .instrument(span.clone()) 323 .await 324 else { 325 tracing::info!("Stopping MDNS observation"); 326 break; 327 }; 328 329 match event { 330 iroh::discovery::mdns::DiscoveryEvent::Discovered { 331 endpoint_info: iroh::discovery::EndpointInfo { endpoint_id, data }, 332 .. 333 } => { 334 tracing::info!(parent: &span, "MDNS discovered: {endpoint_id}: {data:?}"); 335 known_mdns_endpoints.insert(endpoint_id, data); 336 } 337 iroh::discovery::mdns::DiscoveryEvent::Expired { endpoint_id } => { 338 tracing::info!(parent: &span, "MDNS expired: {endpoint_id}"); 339 known_mdns_endpoints.remove(&endpoint_id); 340 } 341 } 342 } 343 344 tracing::info!(parent: span, "MDNS observation ended"); 345 }