/ overwatch / src / overwatch / handle.rs
handle.rs
  1  use std::fmt::{Debug, Display};
  2  
  3  use tokio::{
  4      runtime::Handle,
  5      sync::mpsc::{Sender, error::SendError},
  6  };
  7  #[cfg(feature = "instrumentation")]
  8  use tracing::instrument;
  9  use tracing::{debug, error, info};
 10  
 11  use crate::{
 12      overwatch::{
 13          Error, Services,
 14          commands::{
 15              OverwatchCommand, OverwatchManagementCommand, RelayCommand, ReplyChannel,
 16              ServiceAllCommand, ServiceLifecycleCommand, ServiceSequenceCommand,
 17              ServiceSingleCommand, SettingsCommand, StatusCommand,
 18          },
 19          errors::OverwatchManagementError,
 20      },
 21      services::{
 22          AsServiceId, ServiceData,
 23          lifecycle::ServiceLifecycleError,
 24          relay::{OutboundRelay, RelayError},
 25          status::StatusWatcher,
 26      },
 27      utils::finished_signal,
 28  };
 29  
 30  /// Handler object over the main [`crate::overwatch::Overwatch`] runner.
 31  ///
 32  /// It handles communications to the main
 33  /// [`OverwatchRunner`](crate::overwatch::OverwatchRunner) for services that are
 34  /// part of the same runtime, i.e., aggregated under the same
 35  /// `RuntimeServiceId`.
 36  #[derive(Clone, Debug)]
 37  pub struct OverwatchHandle<RuntimeServiceId> {
 38      runtime_handle: Handle,
 39      sender: Sender<OverwatchCommand<RuntimeServiceId>>,
 40  }
 41  
 42  impl<RuntimeServiceId> OverwatchHandle<RuntimeServiceId> {
 43      #[must_use]
 44      pub const fn new(
 45          runtime_handle: Handle,
 46          sender: Sender<OverwatchCommand<RuntimeServiceId>>,
 47      ) -> Self {
 48          Self {
 49              runtime_handle,
 50              sender,
 51          }
 52      }
 53  
 54      #[must_use]
 55      pub const fn runtime(&self) -> &Handle {
 56          &self.runtime_handle
 57      }
 58  }
 59  
 60  impl<RuntimeServiceId> OverwatchHandle<RuntimeServiceId>
 61  where
 62      RuntimeServiceId: Debug + Sync + Display,
 63  {
 64      /// Request a relay with a service.
 65      ///
 66      /// # Errors
 67      ///
 68      /// If the relay cannot be created, or if the service is not available.
 69      pub async fn relay<Service>(&self) -> Result<OutboundRelay<Service::Message>, RelayError>
 70      where
 71          Service: ServiceData,
 72          Service::Message: 'static,
 73          RuntimeServiceId: AsServiceId<Service>,
 74      {
 75          info!("Requesting relay with {}", RuntimeServiceId::SERVICE_ID);
 76          let (sender, receiver) = tokio::sync::oneshot::channel();
 77  
 78          let Ok(()) = self
 79              .send(OverwatchCommand::Relay(RelayCommand {
 80                  service_id: RuntimeServiceId::SERVICE_ID,
 81                  reply_channel: ReplyChannel::from(sender),
 82              }))
 83              .await
 84          else {
 85              unreachable!("Service relay should always be available");
 86          };
 87          let message = receiver
 88              .await
 89              .map_err(|e| RelayError::Receiver(Box::new(e)))?;
 90          let Ok(downcasted_message) = message.downcast::<OutboundRelay<Service::Message>>() else {
 91              unreachable!("Statically should always be of the correct type");
 92          };
 93          Ok(*downcasted_message)
 94      }
 95  
 96      /// Request a [`StatusWatcher`] for a service
 97      ///
 98      /// # Panics
 99      ///
100      /// If the service watcher is not available, although this should never
101      /// happen.
102      pub async fn status_watcher<Service>(&self) -> StatusWatcher
103      where
104          RuntimeServiceId: AsServiceId<Service>,
105      {
106          info!(
107              "Requesting status watcher for {}",
108              RuntimeServiceId::SERVICE_ID
109          );
110          let (sender, receiver) = tokio::sync::oneshot::channel();
111          let Ok(()) = self
112              .send(OverwatchCommand::Status(StatusCommand {
113                  service_id: RuntimeServiceId::SERVICE_ID,
114                  reply_channel: ReplyChannel::from(sender),
115              }))
116              .await
117          else {
118              unreachable!("Service watcher should always be available");
119          };
120          receiver.await.unwrap_or_else(|_| {
121              panic!(
122                  "Service {} watcher should always be available",
123                  RuntimeServiceId::SERVICE_ID
124              )
125          })
126      }
127  
128      /// Send a [`ServiceLifecycleCommand::StartService`] command to the
129      /// [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
130      ///
131      /// # Arguments
132      ///
133      /// * `Service` - The service type to start.
134      ///
135      /// # Errors
136      ///
137      /// If the command cannot be sent, or if the
138      /// [`Signal`](finished_signal::Signal) is not received.
139      pub async fn start_service<Service>(&self) -> Result<(), Error>
140      where
141          RuntimeServiceId: AsServiceId<Service>,
142      {
143          info!("Starting Service with ID {}", RuntimeServiceId::SERVICE_ID);
144  
145          let (sender, receiver) = finished_signal::channel();
146          let command = OverwatchCommand::ServiceLifecycle(ServiceLifecycleCommand::StartService(
147              ServiceSingleCommand {
148                  service_id: RuntimeServiceId::SERVICE_ID,
149                  sender,
150              },
151          ));
152  
153          self.send(command)
154              .await
155              .map_err(|_error| ServiceLifecycleError::Start)?;
156  
157          receiver.await.map_err(|error| {
158              debug!("{error:?}");
159              ServiceLifecycleError::Start.into()
160          })
161      }
162  
163      /// Send a [`ServiceLifecycleCommand::StartServiceSequence`] command to
164      /// the [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
165      ///
166      /// # Arguments
167      ///
168      /// * `service_ids` - A list of service IDs to start.
169      ///
170      /// # Errors
171      ///
172      /// If the command cannot be sent, or if the
173      /// [`Signal`](finished_signal::Signal) is not received.
174      pub async fn start_service_sequence(
175          &self,
176          service_ids: impl IntoIterator<Item = RuntimeServiceId>,
177      ) -> Result<(), Error> {
178          let service_ids = service_ids.into_iter().collect::<Vec<RuntimeServiceId>>();
179          info!("Starting Service Sequence with IDs: {:?}", service_ids);
180  
181          let (sender, receiver) = finished_signal::channel();
182          let command = OverwatchCommand::ServiceLifecycle(
183              ServiceLifecycleCommand::StartServiceSequence(ServiceSequenceCommand {
184                  service_ids,
185                  sender,
186              }),
187          );
188  
189          self.send(command)
190              .await
191              .map_err(|_error| ServiceLifecycleError::StartSequence)?;
192  
193          receiver.await.map_err(|error| {
194              debug!("{error:?}");
195              ServiceLifecycleError::StartSequence.into()
196          })
197      }
198  
199      /// Send a [`ServiceLifecycleCommand::StartAllServices`] command to the
200      /// [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
201      ///
202      /// # Errors
203      ///
204      /// If the command cannot be sent, or if the
205      /// [`Signal`](finished_signal::Signal) is not received.
206      pub async fn start_all_services(&self) -> Result<(), Error> {
207          info!("Starting all services");
208  
209          let (sender, receiver) = finished_signal::channel();
210          let command = OverwatchCommand::ServiceLifecycle(
211              ServiceLifecycleCommand::StartAllServices(ServiceAllCommand { sender }),
212          );
213  
214          self.send(command)
215              .await
216              .map_err(|_error| ServiceLifecycleError::StartAll)?;
217  
218          receiver.await.map_err(|error| {
219              debug!("{error:?}");
220              ServiceLifecycleError::StartAll.into()
221          })
222      }
223  
224      /// Send a [`ServiceLifecycleCommand::StopService`] command to the
225      /// [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
226      ///
227      /// # Arguments
228      ///
229      /// * `Service` - The service type to stop.
230      ///
231      /// # Errors
232      ///
233      /// If the stop signal cannot be sent, or if the
234      /// [`Signal`](finished_signal::Signal) is not received.
235      pub async fn stop_service<Service>(&self) -> Result<(), Error>
236      where
237          RuntimeServiceId: AsServiceId<Service>,
238      {
239          info!("Stopping Service with ID {}", RuntimeServiceId::SERVICE_ID);
240  
241          let (sender, receiver) = tokio::sync::oneshot::channel();
242          let command = OverwatchCommand::ServiceLifecycle(ServiceLifecycleCommand::StopService(
243              ServiceSingleCommand {
244                  service_id: RuntimeServiceId::SERVICE_ID,
245                  sender,
246              },
247          ));
248  
249          self.send(command)
250              .await
251              .map_err(|_error| ServiceLifecycleError::Stop)?;
252  
253          receiver.await.map_err(|error| {
254              debug!("{error:?}");
255              ServiceLifecycleError::Stop.into()
256          })
257      }
258  
259      /// Send a [`ServiceLifecycleCommand::StopServiceSequence`] command to
260      /// the [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
261      ///
262      /// # Arguments
263      ///
264      /// * `service_ids` - A list of service IDs to stop.
265      ///
266      /// # Errors
267      ///
268      /// If the stop signal cannot be sent, or if the
269      /// [`Signal`](finished_signal::Signal) is not received.
270      pub async fn stop_service_sequence(
271          &self,
272          service_ids: impl IntoIterator<Item = RuntimeServiceId>,
273      ) -> Result<(), Error> {
274          let service_ids = service_ids.into_iter().collect::<Vec<RuntimeServiceId>>();
275          info!("Stopping Service Sequence with IDs: {:?}", service_ids);
276  
277          let (sender, receiver) = finished_signal::channel();
278          let command = OverwatchCommand::ServiceLifecycle(
279              ServiceLifecycleCommand::StopServiceSequence(ServiceSequenceCommand {
280                  service_ids,
281                  sender,
282              }),
283          );
284  
285          self.send(command)
286              .await
287              .map_err(|_error| ServiceLifecycleError::StopSequence)?;
288  
289          receiver.await.map_err(|error| {
290              debug!("{error:?}");
291              ServiceLifecycleError::StopSequence.into()
292          })
293      }
294  
295      /// Send a [`ServiceLifecycleCommand::StopAllServices`] command to the
296      /// [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
297      ///
298      /// # Errors
299      ///
300      /// If the command cannot be sent, or if the
301      /// [`Signal`](finished_signal::Signal) is not received.
302      pub async fn stop_all_services(&self) -> Result<(), Error> {
303          info!("Stopping all services");
304  
305          let (sender, receiver) = finished_signal::channel();
306          let command = OverwatchCommand::ServiceLifecycle(ServiceLifecycleCommand::StopAllServices(
307              ServiceAllCommand { sender },
308          ));
309  
310          self.send(command)
311              .await
312              .map_err(|_error| ServiceLifecycleError::StopAll)?;
313  
314          receiver.await.map_err(|error| {
315              debug!("{error:?}");
316              ServiceLifecycleError::StopAll.into()
317          })
318      }
319  
320      /// Send a [`ServiceLifecycleCommand::Shutdown`] command to the
321      /// [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
322      ///
323      /// This triggers sending the `finish_runner_signal` to
324      /// [`Overwatch`](crate::overwatch::Overwatch). It's the signal which
325      /// [`Overwatch::wait_finished`](crate::overwatch::Overwatch::blocking_wait_finished)
326      /// waits for.
327      ///
328      /// # Errors
329      ///
330      /// If the command cannot be sent, or if the
331      /// [`Signal`](finished_signal::Signal) is not received.
332      pub async fn shutdown(&self) -> Result<(), Error> {
333          info!("Shutting down Overwatch");
334  
335          let (sender, receiver) = finished_signal::channel();
336          let command =
337              OverwatchCommand::OverwatchManagement(OverwatchManagementCommand::Shutdown(sender));
338  
339          self.send(command)
340              .await
341              .map_err(|_error| OverwatchManagementError::Shutdown)?;
342  
343          receiver.await.map_err(|error| {
344              debug!("{error:?}");
345              OverwatchManagementError::Shutdown.into()
346          })
347      }
348  
349      /// Retrieve all `Service`'s `RuntimeServiceId`'s.
350      ///
351      /// # Errors
352      ///
353      /// If the service IDs cannot be retrieved.
354      pub async fn retrieve_service_ids(&self) -> Result<Vec<RuntimeServiceId>, Error> {
355          info!("Retrieving all service IDs.");
356          let (sender, receiver) = tokio::sync::oneshot::channel();
357          let reply_channel = ReplyChannel::from(sender);
358          let command = OverwatchCommand::OverwatchManagement(
359              OverwatchManagementCommand::RetrieveServiceIds(reply_channel),
360          );
361  
362          self.send(command)
363              .await
364              .map_err(|_error| OverwatchManagementError::RetrieveServiceIds)?;
365  
366          receiver.await.map_err(|error| {
367              error!(error=?error, "Error while retrieving service IDs");
368              OverwatchManagementError::RetrieveServiceIds.into()
369          })
370      }
371  
372      /// Send a command to the
373      /// [`OverwatchRunner`](crate::overwatch::OverwatchRunner).
374      ///
375      /// # Errors
376      ///
377      /// If the received side of the channel is closed and the message cannot be
378      /// sent.
379      #[cfg_attr(
380          feature = "instrumentation",
381          instrument(name = "overwatch-command-send", skip(self))
382      )]
383      pub async fn send(
384          &self,
385          command: OverwatchCommand<RuntimeServiceId>,
386      ) -> Result<(), SendError<OverwatchCommand<RuntimeServiceId>>> {
387          self.sender.send(command).await.map_err(|error| {
388              error!(error=?error, "Error while sending an Overwatch command");
389              error
390          })
391      }
392  
393      #[cfg_attr(feature = "instrumentation", instrument(skip(self)))]
394      pub async fn update_settings<S: Services>(&self, settings: S::Settings)
395      where
396          S::Settings: Send + Debug + 'static,
397      {
398          let _: Result<(), _> = self
399              .send(OverwatchCommand::Settings(SettingsCommand(Box::new(
400                  settings,
401              ))))
402              .await
403              .map_err(|e| error!(error=?e, "Error updating settings"));
404      }
405  }