/ fedimint-core / src / runtime.rs
runtime.rs
  1  //! Copyright 2021 The Matrix.org Foundation C.I.C.
  2  //! Abstraction over an executor so we can spawn tasks under WASM the same way
  3  //! we do usually.
  4  
  5  // Adapted from https://github.com/matrix-org/matrix-rust-sdk
  6  
  7  use std::future::Future;
  8  use std::time::Duration;
  9  
 10  use fedimint_logging::LOG_RUNTIME;
 11  use thiserror::Error;
 12  use tokio::time::Instant;
 13  use tracing::Instrument;
 14  
 15  #[derive(Debug, Error)]
 16  #[error("deadline has elapsed")]
 17  pub struct Elapsed;
 18  
 19  pub use self::r#impl::*;
 20  
 21  #[cfg(not(target_family = "wasm"))]
 22  mod r#impl {
 23      pub use tokio::task::{JoinError, JoinHandle};
 24  
 25      use super::*;
 26  
 27      pub fn spawn<F, T>(name: &str, future: F) -> tokio::task::JoinHandle<T>
 28      where
 29          F: Future<Output = T> + 'static + Send,
 30          T: Send + 'static,
 31      {
 32          let span = tracing::debug_span!(target: LOG_RUNTIME, parent: None, "spawn", task = name);
 33          // nosemgrep: ban-tokio-spawn
 34          tokio::spawn(future.instrument(span))
 35      }
 36  
 37      pub(crate) fn spawn_local<F>(name: &str, future: F) -> JoinHandle<()>
 38      where
 39          F: Future<Output = ()> + 'static,
 40      {
 41          let span =
 42              tracing::debug_span!(target: LOG_RUNTIME, parent: None, "spawn_local", task = name);
 43          // nosemgrep: ban-tokio-spawn
 44          tokio::task::spawn_local(future.instrument(span))
 45      }
 46  
 47      // note: this call does not exist on wasm and you need to handle it
 48      // conditionally at the call site of packages that compile on wasm
 49      pub fn block_in_place<F, R>(f: F) -> R
 50      where
 51          F: FnOnce() -> R,
 52      {
 53          // nosemgrep: ban-raw-block-in-place
 54          tokio::task::block_in_place(f)
 55      }
 56  
 57      // note: this call does not exist on wasm and you need to handle it
 58      // conditionally at the call site of packages that compile on wasm
 59      pub fn block_on<F: Future>(future: F) -> F::Output {
 60          // nosemgrep: ban-raw-block-on
 61          tokio::runtime::Handle::current().block_on(future)
 62      }
 63  
 64      pub async fn sleep(duration: Duration) {
 65          // nosemgrep: ban-tokio-sleep
 66          tokio::time::sleep(duration).await
 67      }
 68  
 69      pub async fn sleep_until(deadline: Instant) {
 70          tokio::time::sleep_until(deadline).await
 71      }
 72  
 73      pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
 74      where
 75          T: Future,
 76      {
 77          tokio::time::timeout(duration, future)
 78              .await
 79              .map_err(|_| Elapsed)
 80      }
 81  }
 82  
 83  #[cfg(target_family = "wasm")]
 84  mod r#impl {
 85  
 86      pub use std::convert::Infallible as JoinError;
 87      use std::pin::Pin;
 88      use std::task::{Context, Poll};
 89  
 90      use async_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard};
 91      use futures_util::future::RemoteHandle;
 92      use futures_util::FutureExt;
 93  
 94      use super::*;
 95  
 96      #[derive(Debug)]
 97      pub struct JoinHandle<T> {
 98          handle: Option<RemoteHandle<T>>,
 99      }
100  
101      impl<T> JoinHandle<T> {
102          pub fn abort(&mut self) {
103              drop(self.handle.take());
104          }
105      }
106  
107      impl<T> Drop for JoinHandle<T> {
108          fn drop(&mut self) {
109              // don't abort the spawned future
110              if let Some(h) = self.handle.take() {
111                  h.forget();
112              }
113          }
114      }
115      impl<T: 'static> Future for JoinHandle<T> {
116          type Output = Result<T, JoinError>;
117  
118          fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
119              if let Some(handle) = self.handle.as_mut() {
120                  Pin::new(handle).poll(cx).map(Ok)
121              } else {
122                  Poll::Pending
123              }
124          }
125      }
126  
127      pub fn spawn<F, T>(name: &str, future: F) -> JoinHandle<T>
128      where
129          F: Future<Output = T> + 'static,
130      {
131          let span = tracing::debug_span!(target: LOG_RUNTIME, "spawn", task = name);
132          let (fut, handle) = future.remote_handle();
133          wasm_bindgen_futures::spawn_local(fut);
134  
135          JoinHandle {
136              handle: Some(handle),
137          }
138      }
139  
140      pub(crate) fn spawn_local<F>(name: &str, future: F) -> JoinHandle<()>
141      where
142          // No Send needed on wasm
143          F: Future<Output = ()> + 'static,
144      {
145          spawn(name, future)
146      }
147  
148      pub async fn sleep(duration: Duration) {
149          gloo_timers::future::sleep(duration.min(Duration::from_millis(i32::MAX as _))).await
150      }
151  
152      pub async fn sleep_until(deadline: Instant) {
153          // nosemgrep: ban-system-time-now
154          // nosemgrep: ban-instant-now
155          sleep(deadline.saturating_duration_since(Instant::now())).await
156      }
157  
158      pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed>
159      where
160          T: Future,
161      {
162          futures::pin_mut!(future);
163          futures::select_biased! {
164              value = future.fuse() => Ok(value),
165              _ = sleep(duration).fuse() => Err(Elapsed),
166          }
167      }
168  }