/ fedimint-core / src / task.rs
task.rs
  1  #![cfg_attr(target_family = "wasm", allow(dead_code))]
  2  
  3  /// Just-in-time initialization
  4  pub mod jit;
  5  pub mod waiter;
  6  
  7  use std::collections::VecDeque;
  8  use std::future::Future;
  9  use std::pin::{pin, Pin};
 10  use std::sync::Arc;
 11  use std::time::{Duration, SystemTime};
 12  
 13  use anyhow::bail;
 14  use fedimint_core::time::now;
 15  use fedimint_logging::{LOG_TASK, LOG_TEST};
 16  use futures::future::{self, Either};
 17  use thiserror::Error;
 18  use tokio::sync::{oneshot, watch};
 19  use tracing::{debug, error, info, warn};
 20  
 21  use crate::runtime;
 22  // TODO: stop using `task::*`, and use `runtime::*` in the code
 23  // lots of churn though
 24  pub use crate::runtime::*;
 25  
/// State shared between a [`TaskGroup`], its clones and the [`TaskHandle`]s
/// made from it.
#[derive(Debug)]
struct TaskGroupInner {
    /// Sending half of the shutdown flag; the value is `true` once shutdown
    /// has been requested.
    on_shutdown_tx: watch::Sender<bool>,
    // It is necessary to keep at least one `Receiver` around,
    // otherwise shutdown writes are lost.
    /// Receiving half of the shutdown flag; cloned to build shutdown tokens.
    on_shutdown_rx: watch::Receiver<bool>,
    // using blocking Mutex to avoid `async` in `spawn`
    // it's OK as we don't ever need to yield
    /// Names and join handles of spawned tasks that were not joined yet.
    join: std::sync::Mutex<VecDeque<(String, JoinHandle<()>)>>,
    // using blocking Mutex to avoid `async` in `shutdown`
    // it's OK as we don't ever need to yield
    /// Sub-groups the shutdown signal is propagated to.
    subgroups: std::sync::Mutex<Vec<TaskGroup>>,
}
 39  
 40  impl Default for TaskGroupInner {
 41      fn default() -> Self {
 42          let (on_shutdown_tx, on_shutdown_rx) = watch::channel(false);
 43          Self {
 44              on_shutdown_tx,
 45              on_shutdown_rx,
 46              join: std::sync::Mutex::new(Default::default()),
 47              subgroups: std::sync::Mutex::new(vec![]),
 48          }
 49      }
 50  }
 51  
 52  impl TaskGroupInner {
 53      pub fn shutdown(&self) {
 54          // Note: set the flag before starting to call shutdown handlers
 55          // to avoid confusion.
 56          self.on_shutdown_tx
 57              .send(true)
 58              .expect("We must have on_shutdown_rx around so this never fails");
 59  
 60          let subgroups = self.subgroups.lock().expect("locking failed").clone();
 61          for subgroup in subgroups {
 62              subgroup.inner.shutdown();
 63          }
 64      }
 65  }
/// A group of tasks working together
///
/// Using this struct it is possible to spawn one or more
/// main tasks collaborating, which can cooperatively and gracefully
/// shut down, either due to an external request, or a failure of
/// one of them.
///
/// Each task should periodically check its [`TaskHandle`] or rely
/// on a condition like channel disconnection to detect when it is time
/// to finish.
#[derive(Clone, Default, Debug)]
pub struct TaskGroup {
    /// Shared state; every clone of a `TaskGroup` points at the same state.
    inner: Arc<TaskGroupInner>,
}
 80  
 81  impl TaskGroup {
    /// Create a new task group with no tasks, no sub-groups and the shutdown
    /// flag unset.
    pub fn new() -> Self {
        Self::default()
    }
 85  
 86      pub fn make_handle(&self) -> TaskHandle {
 87          TaskHandle {
 88              inner: self.inner.clone(),
 89          }
 90      }
 91  
 92      /// Create a sub-group
 93      ///
 94      /// Task subgroup works like an independent [`TaskGroup`], but the parent
 95      /// `TaskGroup` will propagate the shut down signal to a sub-group.
 96      ///
 97      /// In contrast to using the parent group directly, a subgroup allows
 98      /// calling [`Self::join_all`] and detecting any panics on just a
 99      /// subset of tasks.
100      ///
101      /// The code create a subgroup is responsible for calling
102      /// [`Self::join_all`]. If it won't, the parent subgroup **will not**
103      /// detect any panics in the tasks spawned by the subgroup.
104      pub fn make_subgroup(&self) -> TaskGroup {
105          let new_tg = Self::new();
106          self.inner
107              .subgroups
108              .lock()
109              .expect("locking failed")
110              .push(new_tg.clone());
111          new_tg
112      }
113  
    /// Signal shutdown to this group and, recursively, to its sub-groups.
    ///
    /// This only raises the shutdown flag; it does not wait for any task.
    pub fn shutdown(&self) {
        self.inner.shutdown()
    }
117  
118      pub async fn shutdown_join_all(
119          self,
120          join_timeout: impl Into<Option<Duration>>,
121      ) -> Result<(), anyhow::Error> {
122          self.shutdown();
123          self.join_all(join_timeout.into()).await
124      }
125  
126      #[cfg(not(target_family = "wasm"))]
127      pub fn install_kill_handler(&self) {
128          use tokio::signal;
129  
130          async fn wait_for_shutdown_signal() {
131              let ctrl_c = async {
132                  signal::ctrl_c()
133                      .await
134                      .expect("failed to install Ctrl+C handler");
135              };
136  
137              #[cfg(unix)]
138              let terminate = async {
139                  signal::unix::signal(signal::unix::SignalKind::terminate())
140                      .expect("failed to install signal handler")
141                      .recv()
142                      .await;
143              };
144  
145              #[cfg(not(unix))]
146              let terminate = std::future::pending::<()>();
147  
148              tokio::select! {
149                  _ = ctrl_c => {},
150                  _ = terminate => {},
151              }
152          }
153          runtime::spawn("kill handlers", {
154              let task_group = self.clone();
155              async move {
156                  wait_for_shutdown_signal().await;
157                  info!(
158                      target: LOG_TASK,
159                      "signal received, starting graceful shutdown"
160                  );
161                  task_group.shutdown();
162              }
163          });
164      }
165  
166      pub fn spawn<Fut, R>(
167          &self,
168          name: impl Into<String>,
169          f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
170      ) -> oneshot::Receiver<R>
171      where
172          Fut: Future<Output = R> + MaybeSend + 'static,
173          R: MaybeSend + 'static,
174      {
175          let name = name.into();
176          let mut guard = TaskPanicGuard {
177              name: name.clone(),
178              inner: self.inner.clone(),
179              completed: false,
180          };
181          let handle = self.make_handle();
182  
183          let (tx, rx) = oneshot::channel();
184          let handle = crate::runtime::spawn(&name, {
185              let name = name.clone();
186              async move {
187                  // if receiver is not interested, just drop the message
188                  debug!("Starting task {name}");
189                  let r = f(handle).await;
190                  debug!("Finished task {name}");
191                  let _ = tx.send(r);
192              }
193          });
194          self.inner
195              .join
196              .lock()
197              .expect("lock poison")
198              .push_back((name, handle));
199          guard.completed = true;
200  
201          rx
202      }
203  
204      pub async fn spawn_local<Fut>(
205          &self,
206          name: impl Into<String>,
207          f: impl FnOnce(TaskHandle) -> Fut + 'static,
208      ) where
209          Fut: Future<Output = ()> + 'static,
210      {
211          let name = name.into();
212          let mut guard = TaskPanicGuard {
213              name: name.clone(),
214              inner: self.inner.clone(),
215              completed: false,
216          };
217          let handle = self.make_handle();
218  
219          let handle = runtime::spawn_local(name.as_str(), async move {
220              f(handle).await;
221          });
222          self.inner
223              .join
224              .lock()
225              .expect("lock poison")
226              .push_back((name, handle));
227          guard.completed = true;
228      }
229  
230      /// Spawn a task that will get cancelled automatically on TaskGroup
231      /// shutdown.
232      pub fn spawn_cancellable<R>(
233          &self,
234          name: impl Into<String>,
235          future: impl Future<Output = R> + MaybeSend + 'static,
236      ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
237      where
238          R: MaybeSend + 'static,
239      {
240          self.spawn(name, move |handle| async move {
241              let value = handle.cancel_on_shutdown(future).await;
242              if value.is_err() {
243                  // name will part of span
244                  debug!("task cancelled on shutdown");
245              }
246              value
247          })
248      }
249  
250      pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
251          let deadline = timeout.map(|timeout| now() + timeout);
252          let mut errors = vec![];
253  
254          self.join_all_inner(deadline, &mut errors).await;
255  
256          if errors.is_empty() {
257              Ok(())
258          } else {
259              let num_errors = errors.len();
260              bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
261          }
262      }
263  
    /// Join the tasks of all sub-groups (recursively) and then the tasks of
    /// this group, appending join errors to `errors`.
    ///
    /// If `deadline` is set, each task is waited for only until the deadline
    /// (or for 10ms if the deadline has already passed). A task that does not
    /// finish in time is logged with a warning and its join handle is
    /// dropped; it is *not* added to `errors`.
    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
        // Snapshot the sub-group list so the lock is not held across awaits.
        let subgroups = self.inner.subgroups.lock().expect("locking failed").clone();
        for subgroup in subgroups {
            info!(target: LOG_TASK, "Waiting for subgroup to finish");
            subgroup.join_all_inner(deadline, errors).await;
            info!(target: LOG_TASK, "Subgroup finished");
        }

        // drop lock early
        // (the guard lives only inside the block expression, so it is
        // released before the popped task is awaited)
        while let Some((name, join)) = {
            let mut lock = self.inner.join.lock().expect("lock poison");
            lock.pop_front()
        } {
            debug!(target: LOG_TASK, task=%name, "Waiting for task to finish");

            // Time left until the deadline; once the deadline has passed,
            // `duration_since` fails and a 10ms grace period is used instead.
            let timeout = deadline.map(|deadline| {
                deadline
                    .duration_since(now())
                    .unwrap_or(Duration::from_millis(10))
            });

            // Both branches are boxed to a common type: the outer `Result`
            // reports a timeout, the inner one the join result.
            #[cfg(not(target_family = "wasm"))]
            let join_future: Pin<Box<dyn Future<Output = _> + Send>> =
                if let Some(timeout) = timeout {
                    Box::pin(runtime::timeout(timeout, join))
                } else {
                    Box::pin(async move { Ok(join.await) })
                };

            // Same as above, without the `Send` bound on wasm.
            #[cfg(target_family = "wasm")]
            let join_future: Pin<Box<dyn Future<Output = _>>> = if let Some(timeout) = timeout {
                Box::pin(runtime::timeout(timeout, join))
            } else {
                Box::pin(async move { Ok(join.await) })
            };

            match join_future.await {
                Ok(Ok(())) => {
                    debug!(target: LOG_TASK, task=%name, "Task finished");
                }
                // Joining failed; the error is logged as a panic and collected.
                Ok(Err(e)) => {
                    error!(target: LOG_TASK, task=%name, error=%e, "Task panicked");
                    errors.push(e);
                }
                // Timed out: only logged, nothing is pushed to `errors`.
                Err(_) => {
                    warn!(
                        target: LOG_TASK, task=%name,
                        "Timeout waiting for task to shut down"
                    )
                }
            }
        }
    }
319  }
320  
/// Drop guard that shuts the task group down unless it was marked as
/// completed before being dropped.
pub struct TaskPanicGuard {
    /// Task name, used in the log message emitted on an unclean drop.
    name: String,
    /// State of the task group to shut down.
    inner: Arc<TaskGroupInner>,
    /// Did the guarded code complete successfully (no panic)
    completed: bool,
}
327  
328  impl TaskPanicGuard {
329      pub fn is_shutting_down(&self) -> bool {
330          *self.inner.on_shutdown_tx.borrow()
331      }
332  }
333  
334  impl Drop for TaskPanicGuard {
335      fn drop(&mut self) {
336          if !self.completed {
337              info!(
338                  target: LOG_TASK,
339                  "Task {} shut down uncleanly. Shutting down task group.", self.name
340              );
341              self.inner.shutdown();
342          }
343      }
344  }
345  
/// Handle given to a task, used to observe the shutdown of its
/// [`TaskGroup`].
#[derive(Clone, Debug)]
pub struct TaskHandle {
    /// State shared with the task group this handle was made from.
    inner: Arc<TaskGroupInner>,
}
350  
/// Error returned when an operation was cancelled because the task group
/// is shutting down.
#[derive(thiserror::Error, Debug, Clone)]
#[error("Task group is shutting down")]
#[non_exhaustive]
pub struct ShuttingDownError {}
355  
356  impl TaskHandle {
357      /// Is task group shutting down?
358      ///
359      /// Every task in a task group should detect and stop if `true`.
360      pub fn is_shutting_down(&self) -> bool {
361          *self.inner.on_shutdown_tx.borrow()
362      }
363  
364      /// Make a [`oneshot::Receiver`] that will fire on shutdown
365      ///
366      /// Tasks can use `select` on the return value to handle shutdown
367      /// signal during otherwise blocking operation.
368      pub async fn make_shutdown_rx(&self) -> TaskShutdownToken {
369          TaskShutdownToken::new(self.inner.on_shutdown_rx.clone())
370      }
371  
372      /// Run the future or cancel it if the [`TaskGroup`] shuts down.
373      pub async fn cancel_on_shutdown<F: Future>(
374          &self,
375          fut: F,
376      ) -> Result<F::Output, ShuttingDownError> {
377          let rx = TaskShutdownToken::new(self.inner.on_shutdown_rx.clone());
378          match future::select(pin!(rx), pin!(fut)).await {
379              Either::Left(((), _)) => Err(ShuttingDownError {}),
380              Either::Right((value, _)) => Ok(value),
381          }
382      }
383  }
384  
385  pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
386  
387  impl TaskShutdownToken {
388      fn new(mut rx: watch::Receiver<bool>) -> Self {
389          Self(Box::pin(async move {
390              let _ = rx.wait_for(|v| *v).await;
391          }))
392      }
393  }
394  
395  impl Future for TaskShutdownToken {
396      type Output = ();
397  
398      fn poll(
399          mut self: Pin<&mut Self>,
400          cx: &mut std::task::Context<'_>,
401      ) -> std::task::Poll<Self::Output> {
402          self.0.as_mut().poll(cx)
403      }
404  }
405  
/// async trait that uses `MaybeSend`
///
/// Applies `#[async_trait]` to the given item on non-wasm targets and
/// `#[async_trait(?Send)]` on wasm targets.
///
/// # Example
///
/// ```rust
/// use fedimint_core::{apply, async_trait_maybe_send};
/// #[apply(async_trait_maybe_send!)]
/// trait Foo {
///     // methods
/// }
///
/// #[apply(async_trait_maybe_send!)]
/// impl Foo for () {
///     // methods
/// }
/// ```
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($tt:tt)*) => {
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        $($tt)*
    };
}
430  
/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend`, so this macro
/// appends `+ Send` to the given type instead (non-wasm targets only).
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::{apply, maybe_add_send};
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)* + Send
    };
}

/// `MaybeSend` can not be used in `dyn $Trait + MaybeSend`; on wasm targets
/// this macro returns the given type unchanged.
///
/// # Example
///
/// ```rust
/// use std::any::Any;
///
/// use fedimint_core::maybe_add_send;
/// type Foo = maybe_add_send!(dyn Any);
/// ```
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($tt:tt)*) => {
        $($tt)*
    };
}
463  
/// See `maybe_add_send`
///
/// Appends `+ Send + Sync` to the given type (non-wasm targets only).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($tt:tt)*) => {
        $($tt)* + Send + Sync
    };
}

/// See `maybe_add_send`
///
/// On wasm targets the given type is returned unchanged.
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($tt:tt)*) => {
        $($tt)*
    };
}
481  
/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm.
///
/// On wasm, most types don't implement `Send` because JS types can not be
/// sent between workers directly.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

/// `MaybeSend` is a no-op on wasm and `Send` on non-wasm.
///
/// On wasm, most types don't implement `Send` because JS types can not be
/// sent between workers directly.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

// Blanket impl: off wasm, every `Send` type is `MaybeSend`.
#[cfg(not(target_family = "wasm"))]
impl<T: Send> MaybeSend for T {}

// Blanket impl: on wasm, every type is `MaybeSend`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
501  
/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

/// `MaybeSync` is a no-op on wasm and `Sync` on non-wasm.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

// Blanket impl: off wasm, every `Sync` type is `MaybeSync`.
#[cfg(not(target_family = "wasm"))]
impl<T: Sync> MaybeSync for T {}

// Blanket impl: on wasm, every type is `MaybeSync`.
#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
515  
516  // Used in tests when sleep functionality is desired so it can be logged.
517  // Must include comment describing the reason for sleeping.
518  pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
519      info!(
520          target: LOG_TEST,
521          "Sleeping for {}.{:03} seconds because: {}",
522          duration.as_secs(),
523          duration.subsec_millis(),
524          comment.as_ref()
525      );
526      sleep(duration).await;
527  }
528  
/// An error used as a "cancelled" marker in [`Cancellable`].
///
/// Carries no data; its display message is "Operation cancelled".
#[derive(Error, Debug)]
#[error("Operation cancelled")]
pub struct Cancelled;

/// Result of an operation that can potentially get cancelled, returning no
/// result (e.g. on program shutdown).
pub type Cancellable<T> = std::result::Result<T, Cancelled>;
537  
#[cfg(test)]
mod tests {
    use super::*;

    /// Task body: completes once the shutdown token of the group fires.
    async fn wait_for_shutdown(handle: TaskHandle) {
        handle.make_shutdown_rx().await.await
    }

    /// Task body: sleeps for a moment before waiting for the shutdown token,
    /// so the token is created only after the test has called shutdown.
    async fn sleep_then_wait_for_shutdown(handle: TaskHandle) {
        sleep(Duration::from_millis(10)).await;
        wait_for_shutdown(handle).await
    }

    /// Shutdown is requested after the task had time to start waiting.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.spawn("shutdown waiter", wait_for_shutdown);
        sleep(Duration::from_millis(10)).await;
        tg.shutdown_join_all(None).await
    }

    /// Shutdown is requested while the task is still sleeping.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_group_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.spawn("shutdown waiter", sleep_then_wait_for_shutdown);
        tg.shutdown_join_all(None).await
    }

    /// Like `shutdown_task_group_after`, with the task in a sub-group.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_after() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.make_subgroup()
            .spawn("shutdown waiter", wait_for_shutdown);
        sleep(Duration::from_millis(10)).await;
        tg.shutdown_join_all(None).await
    }

    /// Like `shutdown_task_group_before`, with the task in a sub-group.
    #[test_log::test(tokio::test)]
    async fn shutdown_task_subgroup_before() -> anyhow::Result<()> {
        let tg = TaskGroup::new();
        tg.make_subgroup()
            .spawn("shutdown waiter", sleep_then_wait_for_shutdown);
        tg.shutdown_join_all(None).await
    }
}