runtime.rs
1 //! Copyright 2021 The Matrix.org Foundation C.I.C. 2 //! Abstraction over an executor so we can spawn tasks under WASM the same way 3 //! we do usually. 4 5 // Adapted from https://github.com/matrix-org/matrix-rust-sdk 6 7 use std::future::Future; 8 use std::time::Duration; 9 10 use fedimint_logging::LOG_RUNTIME; 11 use thiserror::Error; 12 use tokio::time::Instant; 13 use tracing::Instrument; 14 15 #[derive(Debug, Error)] 16 #[error("deadline has elapsed")] 17 pub struct Elapsed; 18 19 pub use self::r#impl::*; 20 21 #[cfg(not(target_family = "wasm"))] 22 mod r#impl { 23 pub use tokio::task::{JoinError, JoinHandle}; 24 25 use super::*; 26 27 pub fn spawn<F, T>(name: &str, future: F) -> tokio::task::JoinHandle<T> 28 where 29 F: Future<Output = T> + 'static + Send, 30 T: Send + 'static, 31 { 32 let span = tracing::debug_span!(target: LOG_RUNTIME, parent: None, "spawn", task = name); 33 // nosemgrep: ban-tokio-spawn 34 tokio::spawn(future.instrument(span)) 35 } 36 37 pub(crate) fn spawn_local<F>(name: &str, future: F) -> JoinHandle<()> 38 where 39 F: Future<Output = ()> + 'static, 40 { 41 let span = 42 tracing::debug_span!(target: LOG_RUNTIME, parent: None, "spawn_local", task = name); 43 // nosemgrep: ban-tokio-spawn 44 tokio::task::spawn_local(future.instrument(span)) 45 } 46 47 // note: this call does not exist on wasm and you need to handle it 48 // conditionally at the call site of packages that compile on wasm 49 pub fn block_in_place<F, R>(f: F) -> R 50 where 51 F: FnOnce() -> R, 52 { 53 // nosemgrep: ban-raw-block-in-place 54 tokio::task::block_in_place(f) 55 } 56 57 // note: this call does not exist on wasm and you need to handle it 58 // conditionally at the call site of packages that compile on wasm 59 pub fn block_on<F: Future>(future: F) -> F::Output { 60 // nosemgrep: ban-raw-block-on 61 tokio::runtime::Handle::current().block_on(future) 62 } 63 64 pub async fn sleep(duration: Duration) { 65 // nosemgrep: ban-tokio-sleep 66 tokio::time::sleep(duration).await 67 } 68 69 pub async fn sleep_until(deadline: Instant) { 70 tokio::time::sleep_until(deadline).await 71 } 72 73 pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed> 74 where 75 T: Future, 76 { 77 tokio::time::timeout(duration, future) 78 .await 79 .map_err(|_| Elapsed) 80 } 81 } 82 83 #[cfg(target_family = "wasm")] 84 mod r#impl { 85 86 pub use std::convert::Infallible as JoinError; 87 use std::pin::Pin; 88 use std::task::{Context, Poll}; 89 90 use async_lock::{RwLock, RwLockReadGuard, RwLockWriteGuard}; 91 use futures_util::future::RemoteHandle; 92 use futures_util::FutureExt; 93 94 use super::*; 95 96 #[derive(Debug)] 97 pub struct JoinHandle<T> { 98 handle: Option<RemoteHandle<T>>, 99 } 100 101 impl<T> JoinHandle<T> { 102 pub fn abort(&mut self) { 103 drop(self.handle.take()); 104 } 105 } 106 107 impl<T> Drop for JoinHandle<T> { 108 fn drop(&mut self) { 109 // don't abort the spawned future 110 if let Some(h) = self.handle.take() { 111 h.forget(); 112 } 113 } 114 } 115 impl<T: 'static> Future for JoinHandle<T> { 116 type Output = Result<T, JoinError>; 117 118 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 119 if let Some(handle) = self.handle.as_mut() { 120 Pin::new(handle).poll(cx).map(Ok) 121 } else { 122 Poll::Pending 123 } 124 } 125 } 126 127 pub fn spawn<F, T>(name: &str, future: F) -> JoinHandle<T> 128 where 129 F: Future<Output = T> + 'static, 130 { 131 let span = tracing::debug_span!(target: LOG_RUNTIME, "spawn", task = name); 132 let (fut, handle) = future.remote_handle(); 133 wasm_bindgen_futures::spawn_local(fut); 134 135 JoinHandle { 136 handle: Some(handle), 137 } 138 } 139 140 pub(crate) fn spawn_local<F>(name: &str, future: F) -> JoinHandle<()> 141 where 142 // No Send needed on wasm 143 F: Future<Output = ()> + 'static, 144 { 145 spawn(name, future) 146 } 147 148 pub async fn sleep(duration: Duration) { 149 gloo_timers::future::sleep(duration.min(Duration::from_millis(i32::MAX as _))).await 150 } 151 152 pub async fn sleep_until(deadline: Instant) { 153 // nosemgrep: ban-system-time-now 154 // nosemgrep: ban-instant-now 155 sleep(deadline.saturating_duration_since(Instant::now())).await 156 } 157 158 pub async fn timeout<T>(duration: Duration, future: T) -> Result<T::Output, Elapsed> 159 where 160 T: Future, 161 { 162 futures::pin_mut!(future); 163 futures::select_biased! { 164 value = future.fuse() => Ok(value), 165 _ = sleep(duration).fuse() => Err(Elapsed), 166 } 167 } 168 }