handle.rs
  1  use distrox_wire_types::id::FromModelId;
  2  use distrox_wire_types::id::IntoModelId;
  3  use futures::FutureExt;
  4  use tracing::Instrument;
  5  
  6  use crate::connection::command::ConnectionCommand;
  7  
/// A handle to a [Connection] that has been started via [Connection::run].
///
/// Objects of this type can be used to work with a [Connection] that is running in the background.
/// The [ConnectionHandle] can be used to talk to the remote peer, for example to request data from
/// it.
///
/// [ConnectionHandle] implements [tower::Service] for several request types, so throttling, load
/// balancing, etc. are possible via [tower] extensions.
///
/// [ConnectionHandle] is [Clone], so it can easily be passed around.
///
/// Once all [ConnectionHandle] instances are dropped, the [Connection] is dropped as well.
#[derive(Clone)]
pub struct ConnectionHandle {
    /// Parent span for the per-request spans and for the event emitted when a handle is dropped.
    pub(super) span: tracing::Span,
    /// Token handed out by `cancellation_token()`.
    // NOTE(review): presumably cancelling this token stops the background connection task;
    // confirm against the code that runs the connection.
    pub(super) cancellation_token: tokio_util::sync::CancellationToken,
    /// Public key of the remote node, as reported by `remote_node_id()` and the `Debug` output.
    pub(super) remote_node_id: iroh::PublicKey,
    /// Bounded channel over which every `tower::Service::call` submits its [ConnectionCommand].
    // NOTE(review): the receiving end is not visible in this file; presumably it is owned by
    // the running connection.
    pub(super) command_sender: tokio::sync::mpsc::Sender<ConnectionCommand>,
}
 27  
 28  impl Drop for ConnectionHandle {
 29      fn drop(&mut self) {
 30          tracing::trace!(
 31              parent: &self.span,
 32              strong_count = self.command_sender.strong_count(),
 33              "Dropping connection handle"
 34          );
 35      }
 36  }
 37  
 38  impl std::fmt::Debug for ConnectionHandle {
 39      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 40          f.debug_struct("ConnectionHandle")
 41              .field("remote_node_id", &self.remote_node_id)
 42              .finish_non_exhaustive()
 43      }
 44  }
 45  
 46  impl tower::Service<distrox_api::node::GetNodeRequest> for ConnectionHandle {
 47      type Response = distrox_api::node::GetNodeResponse;
 48      type Error = crate::error::Error;
 49      type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
 50  
 51      fn poll_ready(
 52          &mut self,
 53          _cx: &mut std::task::Context<'_>,
 54      ) -> std::task::Poll<Result<(), Self::Error>> {
 55          std::task::Poll::Ready(Ok(()))
 56      }
 57  
 58      fn call(&mut self, req: distrox_api::node::GetNodeRequest) -> Self::Future {
 59          let id = distrox_wire_types::node_id::NodeId::from_model_id(req.id);
 60          let (reply_sender, reply_recv) = tokio::sync::oneshot::channel();
 61          let command_sender = self.command_sender.clone();
 62  
 63          async move {
 64              command_sender
 65                  .send(ConnectionCommand::SendNodeGetRequest { id, reply_sender })
 66                  .await
 67                  .map_err(|_| crate::error::Error::InternalChannelError)?;
 68  
 69              tracing::debug!("Successfully send NodeGetRequest");
 70  
 71              if let Some((node, signatures)) = reply_recv.await?? {
 72                  let node = node.into_model_repr()?;
 73                  let signatures = signatures.map(distrox_model::node::Signature::from);
 74                  let tpl = Some((node, signatures));
 75                  Ok(distrox_api::node::GetNodeResponse(tpl))
 76              } else {
 77                  Ok(distrox_api::node::GetNodeResponse(None))
 78              }
 79          }
 80          .instrument(tracing::error_span!(parent: &self.span, "get-node-request"))
 81          .boxed()
 82      }
 83  }
 84  
 85  impl tower::Service<distrox_api::content::ContentGetRequest> for ConnectionHandle {
 86      type Response = distrox_api::content::ContentGetResponse;
 87      type Error = crate::error::Error;
 88      type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
 89  
 90      fn poll_ready(
 91          &mut self,
 92          _cx: &mut std::task::Context<'_>,
 93      ) -> std::task::Poll<Result<(), Self::Error>> {
 94          std::task::Poll::Ready(Ok(()))
 95      }
 96  
 97      fn call(&mut self, req: distrox_api::content::ContentGetRequest) -> Self::Future {
 98          let id = distrox_wire_types::content_id::ContentId::from_model_id(req.id);
 99          let (reply_sender, reply_recv) = tokio::sync::oneshot::channel();
100          let command_sender = self.command_sender.clone();
101  
102          async move {
103              command_sender
104                  .send(ConnectionCommand::SendContentGetRequest { id, reply_sender })
105                  .await
106                  .map_err(|_| crate::error::Error::InternalChannelError)?;
107  
108              reply_recv
109                  .await?
110                  .and_then(|cont| {
111                      cont.map(|c| c.into_model_repr().map_err(crate::error::Error::from))
112                          .transpose()
113                  })
114                  .map(distrox_api::content::ContentGetResponse)
115          }
116          .instrument(tracing::error_span!(parent: &self.span, "content-get-request"))
117          .boxed()
118      }
119  }
120  
121  impl tower::Service<distrox_api::node::GetHeadRequest> for ConnectionHandle {
122      type Response = distrox_api::node::GetHeadResponse;
123      type Error = crate::error::Error;
124      type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
125  
126      fn poll_ready(
127          &mut self,
128          _cx: &mut std::task::Context<'_>,
129      ) -> std::task::Poll<Result<(), Self::Error>> {
130          std::task::Poll::Ready(Ok(()))
131      }
132  
133      fn call(&mut self, req: distrox_api::node::GetHeadRequest) -> Self::Future {
134          let (reply_sender, reply_recv) = tokio::sync::oneshot::channel();
135          let command_sender = self.command_sender.clone();
136  
137          async move {
138              let key = distrox_wire_types::key::PublicKey::from(req.key);
139              let key = iroh::PublicKey::try_from(key)?;
140  
141              command_sender
142                  .send(ConnectionCommand::SendHeadGetRequest { key, reply_sender })
143                  .await
144                  .map_err(|_| crate::error::Error::InternalChannelError)?;
145  
146              reply_recv
147                  .await?
148                  .map(|nid| nid.map(|n| n.into_model_id()))
149                  .map(distrox_api::node::GetHeadResponse)
150          }
151          .instrument(tracing::error_span!(parent: &self.span, "get-head-request"))
152          .boxed()
153      }
154  }
155  
156  impl tower::Service<distrox_api::payload::PayloadGetRequest> for ConnectionHandle {
157      type Response = distrox_api::payload::PayloadGetResponse;
158      type Error = crate::error::Error;
159      type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;
160  
161      fn poll_ready(
162          &mut self,
163          _cx: &mut std::task::Context<'_>,
164      ) -> std::task::Poll<Result<(), Self::Error>> {
165          std::task::Poll::Ready(Ok(()))
166      }
167  
168      fn call(&mut self, req: distrox_api::payload::PayloadGetRequest) -> Self::Future {
169          let id = distrox_wire_types::payload_id::PayloadId::from_model_id(req.id);
170          let (reply_sender, reply_recv) = tokio::sync::oneshot::channel();
171          let command_sender = self.command_sender.clone();
172  
173          async move {
174              command_sender
175                  .send(ConnectionCommand::SendPayloadGetRequest { id, reply_sender })
176                  .await
177                  .map_err(|_| crate::error::Error::InternalChannelError)?;
178  
179              reply_recv
180                  .await?
181                  .map(|pl| pl.map(|pl| pl.into_model_repr()))
182                  .map(distrox_api::payload::PayloadGetResponse)
183          }
184          .instrument(tracing::error_span!(parent: &self.span, "payload-get-request"))
185          .boxed()
186      }
187  }
188  
189  impl ConnectionHandle {
190      pub fn remote_node_id(&self) -> iroh::PublicKey {
191          self.remote_node_id
192      }
193  
194      pub fn cancellation_token(&self) -> &tokio_util::sync::CancellationToken {
195          &self.cancellation_token
196      }
197  }