/ crates / distrox-instance / src / instance.rs
instance.rs
  1  use std::sync::Arc;
  2  
  3  use dashmap::DashMap;
  4  use distrox_network::connection::runner::Blocklist;
  5  use iroh::Endpoint;
  6  use iroh::EndpointAddr;
  7  use iroh::EndpointId;
  8  use iroh::discovery::EndpointData;
  9  use tokio_util::sync::CancellationToken;
 10  use tracing::Instrument;
 11  
 12  use crate::command::InstanceCommand;
 13  use crate::error::Error;
 14  use crate::event::InstanceEvent;
 15  
/// A distrox instance: a bound iroh [`Endpoint`] plus the state needed to talk to remotes.
///
/// Create one with [`Instance::new`], obtain handles with [`Instance::handle`] and start
/// processing commands and incoming connections with [`Instance::run`]. Dropping the
/// instance cancels `instance_cancellation_token`.
pub struct Instance {
    /// Cancelled when the instance is dropped. Clones are given to the mDNS observer and to
    /// instance handles; child tokens are given to outgoing connections.
    instance_cancellation_token: CancellationToken,
    /// The bound iroh endpoint, configured with the distrox ALPN.
    endpoint: Endpoint,

    /// Connection handles keyed by remote endpoint id. `connect_to` inserts into it and the
    /// map is shared with every [`crate::handle::InstanceHandle`].
    connected_nodes: Arc<DashMap<EndpointId, distrox_network::connection::ConnectionHandle>>,
    /// Endpoints currently known via mDNS, maintained by `observe_mdns_endpoints`.
    known_mdns_endpoints: Arc<DashMap<EndpointId, EndpointData>>,

    /// Sending half of the command channel; cloned into instance handles. Because the
    /// instance keeps this sender itself, the channel cannot close while the instance lives.
    instance_command_sender: tokio::sync::mpsc::Sender<InstanceCommand>,
    /// Receiving half of the command channel, drained by the loop spawned in `run`.
    instance_command_recv: tokio::sync::mpsc::Receiver<InstanceCommand>,

    /// Blocklist handed in at construction and cloned into instance handles.
    blocklist: Blocklist,
}
 28  
 29  impl Drop for Instance {
 30      fn drop(&mut self) {
 31          self.instance_cancellation_token.cancel();
 32      }
 33  }
 34  
 35  impl std::fmt::Debug for Instance {
 36      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 37          f.debug_struct("Instance").finish_non_exhaustive()
 38      }
 39  }
 40  
 41  impl Instance {
 42      pub async fn new(
 43          secret_key: distrox_keys::private::PrivateSigningKey,
 44          relay_mode: iroh::RelayMode,
 45          blocklist: Blocklist,
 46      ) -> Result<Self, Error> {
 47          let instance_cancellation_token = CancellationToken::new();
 48  
 49          tracing::info!(relay_mode = ?relay_mode, "Booting endpoint");
 50  
 51          let endpoint = secret_key
 52              .set_into(Endpoint::empty_builder(iroh::RelayMode::Disabled))
 53              .relay_mode(relay_mode)
 54              // TODO For now this should be enough
 55              .discovery(iroh::discovery::mdns::MdnsDiscovery::builder().advertise(true))
 56              .alpns(vec![distrox_network::protocol::ALPN.to_vec()])
 57              .bind()
 58              .await
 59              .map_err(Error::EndpointBind)?;
 60  
 61          let mdns = iroh::discovery::mdns::MdnsDiscovery::builder()
 62              .advertise(true)
 63              .service_name("distrox")
 64              .build(endpoint.id())
 65              .map_err(Error::BuildingMdnsDiscovery)?;
 66  
 67          endpoint.discovery().add(mdns.clone());
 68  
 69          let known_mdns_endpoints = Arc::new(DashMap::new());
 70  
 71          tokio::spawn(observe_mdns_endpoints(
 72              mdns,
 73              known_mdns_endpoints.clone(),
 74              instance_cancellation_token.clone(),
 75          ));
 76  
 77          tracing::info!(endpoint_id = %endpoint.id().fmt_short(), "Booted endpoint");
 78  
 79          tracing::debug!("Spawning background worker");
 80  
 81          tracing::debug!("Instance started");
 82          let (instance_command_sender, instance_command_recv) = tokio::sync::mpsc::channel(100);
 83  
 84          Ok(Self {
 85              instance_cancellation_token,
 86              endpoint,
 87              instance_command_sender,
 88              instance_command_recv,
 89  
 90              connected_nodes: Arc::new(DashMap::new()),
 91              known_mdns_endpoints,
 92  
 93              blocklist,
 94          })
 95      }
 96  
 97      pub fn handle(&self) -> crate::handle::InstanceHandle {
 98          crate::handle::InstanceHandle {
 99              connections: self.connected_nodes.clone(),
100              endpoint_id: self.endpoint.id(),
101              endpoint_addr: self.endpoint.addr(),
102              sender: self.instance_command_sender.clone(),
103              blocklist: self.blocklist.clone(),
104              instance_cancellation_token: self.instance_cancellation_token.clone(),
105          }
106      }
107  
108      pub async fn run(mut self) -> impl futures::Stream<Item = InstanceEvent> {
109          let endpoint = self.endpoint.clone();
110  
111          let (event_sender, event_recv) = tokio::sync::mpsc::channel(1024);
112          tokio::task::spawn({
113              let instance_cancellation_token = self.instance_cancellation_token.clone();
114  
115              async move {
116                  loop {
117                      tokio::select! {
118                          instance_command = self.instance_command_recv.recv() => {
119                              if let Some(instance_command) = instance_command
120                                  && self.handle_instance_command(instance_command).await == std::ops::ControlFlow::Break(()) {
121                                      break
122                                  }
123                          }
124  
125                          incoming = endpoint.accept() => {
126                              if let Some(incoming) = incoming {
127                                  let event_sender = event_sender.clone();
128                                  let instance_cancellation_token = self.instance_cancellation_token.clone();
129  
130                                  tokio::task::spawn(async move {
131                                      match instance_cancellation_token.run_until_cancelled(event_sender.send(InstanceEvent::Incoming(incoming))).await {
132                                          Some(Ok(_)) => tracing::trace!("Successfully send event to processing loop"),
133                                          Some(Err(_)) => tracing::error!("Internal channel error"),
134                                          None => tracing::debug!("Instance cancelled"),
135                                      }
136                                  });
137                              } else {
138                                  tracing::error!("Accepting failed, Instance cannot continue to run");
139                                  break
140                              }
141                          }
142  
143                          _cancelled = instance_cancellation_token.cancelled() => {
144                              tracing::info!("Instance cancelled, finish processing");
145                              break
146                          }
147                      }
148                  }
149              }
150          });
151  
152          tokio_stream::wrappers::ReceiverStream::new(event_recv)
153      }
154  
155      async fn handle_instance_command(&self, command: InstanceCommand) -> std::ops::ControlFlow<()> {
156          match command {
157              InstanceCommand::Shutdown(reply_sender) => {
158                  if reply_sender.send(()).is_err() {
159                      tracing::error!("Internal channel error");
160                  }
161                  return std::ops::ControlFlow::Break(());
162              }
163  
164              InstanceCommand::WaitOnline(reply_sender) => {
165                  self.endpoint
166                      .online()
167                      .instrument(tracing::error_span!("instance.wait_until_online"))
168                      .await;
169  
170                  if reply_sender.send(()).is_err() {
171                      tracing::error!("Internal channel error");
172                  }
173              }
174              InstanceCommand::ConnectTo {
175                  target,
176                  result_sender,
177                  services,
178              } => {
179                  let result = self.connect_to(target, services).await;
180                  if result_sender.send(result).is_err() {
181                      tracing::error!("Internal channel error");
182                  }
183              }
184          }
185  
186          std::ops::ControlFlow::Continue(())
187      }
188  
189      /// Connect to a remote instance
190      ///
191      /// # Note
192      ///
193      /// If the instance is already connected to the remote pointed to by `node_addr`, the
194      /// [distrox_network::connection::ConnectionHandle] for that connection is returned.
195      async fn connect_to(
196          &self,
197          node_addr: EndpointAddr,
198          services: distrox_network::connection::Services,
199      ) -> Result<distrox_network::connection::ConnectionHandle, Error> {
200          if let Some(connection_handle) = self.connected_nodes.get(&node_addr.id) {
201              return Ok(connection_handle.clone());
202          }
203  
204          let endpoint_id_str = data_encoding::BASE32_NOPAD
205              .encode(node_addr.id.as_bytes())
206              .to_ascii_lowercase();
207  
208          let connection_span =
209              tracing::error_span!("connection", to = ?node_addr, encoded_id = endpoint_id_str);
210          tracing::info!(parent: &connection_span, "Trying to connect to remote");
211  
212          let node_addr = if let Some(endpoint_data) =
213              self.known_mdns_endpoints.get(&node_addr.id).as_deref()
214          {
215              node_addr.with_addrs(endpoint_data.addrs().cloned().inspect(|addr| {
216                      tracing::debug!(parent: &connection_span, ?addr, "Adding known address to EndpointAddr for connecting")
217                  }))
218          } else {
219              tracing::debug!(parent: &connection_span, "No more known addresses for that EndpointAddr");
220              node_addr
221          };
222  
223          tracing::debug!(parent: &connection_span, ?node_addr, "Connecting to node");
224          let connection = self
225              .endpoint
226              .connect(node_addr.clone(), distrox_network::protocol::ALPN)
227              .instrument(connection_span.clone())
228              .await
229              .inspect_err(|error| {
230                  tracing::warn!(
231                      parent: &connection_span,
232                      ?error,
233                      alpn = str::from_utf8(distrox_network::protocol::ALPN).unwrap_or_default(),
234                      "Failed to connect to endpoint: {error:#}"
235                  );
236              })
237              .map_err(|source| Error::Connecting {
238                  node_addr: node_addr.clone(),
239                  source,
240              })?;
241  
242          let remote_endpoint_id = connection.remote_id();
243  
244          let (sink, stream) = connection
245              .open_bi()
246              .instrument(connection_span.clone())
247              .await
248              .inspect_err(|error| {
249                  tracing::warn!(
250                      parent: &connection_span,
251                      ?error,
252                      "Failed to open channel to endpoint"
253                  );
254              })
255              .map_err(|source| Error::OpeningChannel {
256                  node_addr: node_addr.clone(),
257                  source,
258              })?;
259  
260          let channel = distrox_network::connection::Channel {
261              sink: Box::pin(tokio_util::codec::FramedWrite::new(
262                  sink,
263                  distrox_network::protocol::codec::Encoder,
264              )),
265              stream: Box::pin(tokio_util::codec::FramedRead::new(
266                  stream,
267                  distrox_network::protocol::codec::Decoder,
268              )),
269          };
270  
271          let connection = distrox_network::connection::Connection::new(
272              connection_span.clone(),
273              self.instance_cancellation_token.child_token(),
274              remote_endpoint_id,
275              channel,
276              services,
277          );
278  
279          self.connected_nodes
280              .insert(node_addr.id, connection.handle());
281  
282          tracing::info!(parent: connection_span, "Successfully connected");
283          let connection_handle = connection.handle();
284          tokio::task::spawn(
285              connection
286                  .run()
287                  .instrument(tracing::error_span!("connection")),
288          );
289          Ok(connection_handle)
290      }
291  
292      pub fn known_mdns_endpoints(&self) -> Arc<DashMap<iroh::PublicKey, EndpointData>> {
293          self.known_mdns_endpoints.clone()
294      }
295  
296      pub fn connected_nodes(
297          &self,
298      ) -> Arc<DashMap<EndpointId, distrox_network::connection::ConnectionHandle>> {
299          self.connected_nodes.clone()
300      }
301  
302      pub fn cancellation_token(&self) -> &CancellationToken {
303          &self.instance_cancellation_token
304      }
305  }
306  
307  async fn observe_mdns_endpoints(
308      mdns: iroh::discovery::mdns::MdnsDiscovery,
309      known_mdns_endpoints: Arc<DashMap<EndpointId, EndpointData>>,
310      instance_cancellation_token: CancellationToken,
311  ) {
312      use futures::StreamExt;
313  
314      let span = tracing::error_span!("instance.mdns.observation");
315  
316      let known_mdns_endpoints = known_mdns_endpoints.clone();
317      let mut events = mdns.subscribe().instrument(span.clone()).await;
318  
319      loop {
320          let Some(Some(event)) = instance_cancellation_token
321              .run_until_cancelled(events.next())
322              .instrument(span.clone())
323              .await
324          else {
325              tracing::info!("Stopping MDNS observation");
326              break;
327          };
328  
329          match event {
330              iroh::discovery::mdns::DiscoveryEvent::Discovered {
331                  endpoint_info: iroh::discovery::EndpointInfo { endpoint_id, data },
332                  ..
333              } => {
334                  tracing::info!(parent: &span, "MDNS discovered: {endpoint_id}: {data:?}");
335                  known_mdns_endpoints.insert(endpoint_id, data);
336              }
337              iroh::discovery::mdns::DiscoveryEvent::Expired { endpoint_id } => {
338                  tracing::info!(parent: &span, "MDNS expired: {endpoint_id}");
339                  known_mdns_endpoints.remove(&endpoint_id);
340              }
341          }
342      }
343  
344      tracing::info!(parent: span, "MDNS observation ended");
345  }