// task.rs
1 #![cfg_attr(target_family = "wasm", allow(dead_code))] 2 3 /// Just-in-time initialization 4 pub mod jit; 5 pub mod waiter; 6 7 use std::collections::VecDeque; 8 use std::future::Future; 9 use std::pin::{pin, Pin}; 10 use std::sync::Arc; 11 use std::time::{Duration, SystemTime}; 12 13 use anyhow::bail; 14 use fedimint_core::time::now; 15 use fedimint_logging::{LOG_TASK, LOG_TEST}; 16 use futures::future::{self, Either}; 17 use thiserror::Error; 18 use tokio::sync::{oneshot, watch}; 19 use tracing::{debug, error, info, warn}; 20 21 use crate::runtime; 22 // TODO: stop using `task::*`, and use `runtime::*` in the code 23 // lots of churn though 24 pub use crate::runtime::*; 25 26 #[derive(Debug)] 27 struct TaskGroupInner { 28 on_shutdown_tx: watch::Sender<bool>, 29 // It is necessary to keep at least one `Receiver` around, 30 // otherwise shutdown writes are lost. 31 on_shutdown_rx: watch::Receiver<bool>, 32 // using blocking Mutex to avoid `async` in `spawn` 33 // it's OK as we don't ever need to yield 34 join: std::sync::Mutex<VecDeque<(String, JoinHandle<()>)>>, 35 // using blocking Mutex to avoid `async` in `shutdown` 36 // it's OK as we don't ever need to yield 37 subgroups: std::sync::Mutex<Vec<TaskGroup>>, 38 } 39 40 impl Default for TaskGroupInner { 41 fn default() -> Self { 42 let (on_shutdown_tx, on_shutdown_rx) = watch::channel(false); 43 Self { 44 on_shutdown_tx, 45 on_shutdown_rx, 46 join: std::sync::Mutex::new(Default::default()), 47 subgroups: std::sync::Mutex::new(vec![]), 48 } 49 } 50 } 51 52 impl TaskGroupInner { 53 pub fn shutdown(&self) { 54 // Note: set the flag before starting to call shutdown handlers 55 // to avoid confusion. 
56 self.on_shutdown_tx 57 .send(true) 58 .expect("We must have on_shutdown_rx around so this never fails"); 59 60 let subgroups = self.subgroups.lock().expect("locking failed").clone(); 61 for subgroup in subgroups { 62 subgroup.inner.shutdown(); 63 } 64 } 65 } 66 /// A group of task working together 67 /// 68 /// Using this struct it is possible to spawn one or more 69 /// main thread collaborating, which can cooperatively gracefully 70 /// shut down, either due to external request, or failure of 71 /// one of them. 72 /// 73 /// Each thread should periodically check [`TaskHandle`] or rely 74 /// on condition like channel disconnection to detect when it is time 75 /// to finish. 76 #[derive(Clone, Default, Debug)] 77 pub struct TaskGroup { 78 inner: Arc<TaskGroupInner>, 79 } 80 81 impl TaskGroup { 82 pub fn new() -> Self { 83 Self::default() 84 } 85 86 pub fn make_handle(&self) -> TaskHandle { 87 TaskHandle { 88 inner: self.inner.clone(), 89 } 90 } 91 92 /// Create a sub-group 93 /// 94 /// Task subgroup works like an independent [`TaskGroup`], but the parent 95 /// `TaskGroup` will propagate the shut down signal to a sub-group. 96 /// 97 /// In contrast to using the parent group directly, a subgroup allows 98 /// calling [`Self::join_all`] and detecting any panics on just a 99 /// subset of tasks. 100 /// 101 /// The code create a subgroup is responsible for calling 102 /// [`Self::join_all`]. If it won't, the parent subgroup **will not** 103 /// detect any panics in the tasks spawned by the subgroup. 
104 pub fn make_subgroup(&self) -> TaskGroup { 105 let new_tg = Self::new(); 106 self.inner 107 .subgroups 108 .lock() 109 .expect("locking failed") 110 .push(new_tg.clone()); 111 new_tg 112 } 113 114 pub fn shutdown(&self) { 115 self.inner.shutdown() 116 } 117 118 pub async fn shutdown_join_all( 119 self, 120 join_timeout: impl Into<Option<Duration>>, 121 ) -> Result<(), anyhow::Error> { 122 self.shutdown(); 123 self.join_all(join_timeout.into()).await 124 } 125 126 #[cfg(not(target_family = "wasm"))] 127 pub fn install_kill_handler(&self) { 128 use tokio::signal; 129 130 async fn wait_for_shutdown_signal() { 131 let ctrl_c = async { 132 signal::ctrl_c() 133 .await 134 .expect("failed to install Ctrl+C handler"); 135 }; 136 137 #[cfg(unix)] 138 let terminate = async { 139 signal::unix::signal(signal::unix::SignalKind::terminate()) 140 .expect("failed to install signal handler") 141 .recv() 142 .await; 143 }; 144 145 #[cfg(not(unix))] 146 let terminate = std::future::pending::<()>(); 147 148 tokio::select! 
{ 149 _ = ctrl_c => {}, 150 _ = terminate => {}, 151 } 152 } 153 runtime::spawn("kill handlers", { 154 let task_group = self.clone(); 155 async move { 156 wait_for_shutdown_signal().await; 157 info!( 158 target: LOG_TASK, 159 "signal received, starting graceful shutdown" 160 ); 161 task_group.shutdown(); 162 } 163 }); 164 } 165 166 pub fn spawn<Fut, R>( 167 &self, 168 name: impl Into<String>, 169 f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static, 170 ) -> oneshot::Receiver<R> 171 where 172 Fut: Future<Output = R> + MaybeSend + 'static, 173 R: MaybeSend + 'static, 174 { 175 let name = name.into(); 176 let mut guard = TaskPanicGuard { 177 name: name.clone(), 178 inner: self.inner.clone(), 179 completed: false, 180 }; 181 let handle = self.make_handle(); 182 183 let (tx, rx) = oneshot::channel(); 184 let handle = crate::runtime::spawn(&name, { 185 let name = name.clone(); 186 async move { 187 // if receiver is not interested, just drop the message 188 debug!("Starting task {name}"); 189 let r = f(handle).await; 190 debug!("Finished task {name}"); 191 let _ = tx.send(r); 192 } 193 }); 194 self.inner 195 .join 196 .lock() 197 .expect("lock poison") 198 .push_back((name, handle)); 199 guard.completed = true; 200 201 rx 202 } 203 204 pub async fn spawn_local<Fut>( 205 &self, 206 name: impl Into<String>, 207 f: impl FnOnce(TaskHandle) -> Fut + 'static, 208 ) where 209 Fut: Future<Output = ()> + 'static, 210 { 211 let name = name.into(); 212 let mut guard = TaskPanicGuard { 213 name: name.clone(), 214 inner: self.inner.clone(), 215 completed: false, 216 }; 217 let handle = self.make_handle(); 218 219 let handle = runtime::spawn_local(name.as_str(), async move { 220 f(handle).await; 221 }); 222 self.inner 223 .join 224 .lock() 225 .expect("lock poison") 226 .push_back((name, handle)); 227 guard.completed = true; 228 } 229 230 /// Spawn a task that will get cancelled automatically on TaskGroup 231 /// shutdown. 
232 pub fn spawn_cancellable<R>( 233 &self, 234 name: impl Into<String>, 235 future: impl Future<Output = R> + MaybeSend + 'static, 236 ) -> oneshot::Receiver<Result<R, ShuttingDownError>> 237 where 238 R: MaybeSend + 'static, 239 { 240 self.spawn(name, move |handle| async move { 241 let value = handle.cancel_on_shutdown(future).await; 242 if value.is_err() { 243 // name will part of span 244 debug!("task cancelled on shutdown"); 245 } 246 value 247 }) 248 } 249 250 pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> { 251 let deadline = timeout.map(|timeout| now() + timeout); 252 let mut errors = vec![]; 253 254 self.join_all_inner(deadline, &mut errors).await; 255 256 if errors.is_empty() { 257 Ok(()) 258 } else { 259 let num_errors = errors.len(); 260 bail!("{num_errors} tasks did not finish cleanly: {errors:?}") 261 } 262 } 263 264 #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)] 265 #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))] 266 pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) { 267 let subgroups = self.inner.subgroups.lock().expect("locking failed").clone(); 268 for subgroup in subgroups { 269 info!(target: LOG_TASK, "Waiting for subgroup to finish"); 270 subgroup.join_all_inner(deadline, errors).await; 271 info!(target: LOG_TASK, "Subgroup finished"); 272 } 273 274 // drop lock early 275 while let Some((name, join)) = { 276 let mut lock = self.inner.join.lock().expect("lock poison"); 277 lock.pop_front() 278 } { 279 debug!(target: LOG_TASK, task=%name, "Waiting for task to finish"); 280 281 let timeout = deadline.map(|deadline| { 282 deadline 283 .duration_since(now()) 284 .unwrap_or(Duration::from_millis(10)) 285 }); 286 287 #[cfg(not(target_family = "wasm"))] 288 let join_future: Pin<Box<dyn Future<Output = _> + Send>> = 289 if let Some(timeout) = timeout { 290 Box::pin(runtime::timeout(timeout, join)) 291 } else { 292 
Box::pin(async move { Ok(join.await) }) 293 }; 294 295 #[cfg(target_family = "wasm")] 296 let join_future: Pin<Box<dyn Future<Output = _>>> = if let Some(timeout) = timeout { 297 Box::pin(runtime::timeout(timeout, join)) 298 } else { 299 Box::pin(async move { Ok(join.await) }) 300 }; 301 302 match join_future.await { 303 Ok(Ok(())) => { 304 debug!(target: LOG_TASK, task=%name, "Task finished"); 305 } 306 Ok(Err(e)) => { 307 error!(target: LOG_TASK, task=%name, error=%e, "Task panicked"); 308 errors.push(e); 309 } 310 Err(_) => { 311 warn!( 312 target: LOG_TASK, task=%name, 313 "Timeout waiting for task to shut down" 314 ) 315 } 316 } 317 } 318 } 319 } 320 321 pub struct TaskPanicGuard { 322 name: String, 323 inner: Arc<TaskGroupInner>, 324 /// Did the future completed successfully (no panic) 325 completed: bool, 326 } 327 328 impl TaskPanicGuard { 329 pub fn is_shutting_down(&self) -> bool { 330 *self.inner.on_shutdown_tx.borrow() 331 } 332 } 333 334 impl Drop for TaskPanicGuard { 335 fn drop(&mut self) { 336 if !self.completed { 337 info!( 338 target: LOG_TASK, 339 "Task {} shut down uncleanly. Shutting down task group.", self.name 340 ); 341 self.inner.shutdown(); 342 } 343 } 344 } 345 346 #[derive(Clone, Debug)] 347 pub struct TaskHandle { 348 inner: Arc<TaskGroupInner>, 349 } 350 351 #[derive(thiserror::Error, Debug, Clone)] 352 #[error("Task group is shutting down")] 353 #[non_exhaustive] 354 pub struct ShuttingDownError {} 355 356 impl TaskHandle { 357 /// Is task group shutting down? 358 /// 359 /// Every task in a task group should detect and stop if `true`. 360 pub fn is_shutting_down(&self) -> bool { 361 *self.inner.on_shutdown_tx.borrow() 362 } 363 364 /// Make a [`oneshot::Receiver`] that will fire on shutdown 365 /// 366 /// Tasks can use `select` on the return value to handle shutdown 367 /// signal during otherwise blocking operation. 
368 pub async fn make_shutdown_rx(&self) -> TaskShutdownToken { 369 TaskShutdownToken::new(self.inner.on_shutdown_rx.clone()) 370 } 371 372 /// Run the future or cancel it if the [`TaskGroup`] shuts down. 373 pub async fn cancel_on_shutdown<F: Future>( 374 &self, 375 fut: F, 376 ) -> Result<F::Output, ShuttingDownError> { 377 let rx = TaskShutdownToken::new(self.inner.on_shutdown_rx.clone()); 378 match future::select(pin!(rx), pin!(fut)).await { 379 Either::Left(((), _)) => Err(ShuttingDownError {}), 380 Either::Right((value, _)) => Ok(value), 381 } 382 } 383 } 384 385 pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>); 386 387 impl TaskShutdownToken { 388 fn new(mut rx: watch::Receiver<bool>) -> Self { 389 Self(Box::pin(async move { 390 let _ = rx.wait_for(|v| *v).await; 391 })) 392 } 393 } 394 395 impl Future for TaskShutdownToken { 396 type Output = (); 397 398 fn poll( 399 mut self: Pin<&mut Self>, 400 cx: &mut std::task::Context<'_>, 401 ) -> std::task::Poll<Self::Output> { 402 self.0.as_mut().poll(cx) 403 } 404 } 405 406 /// async trait that use MaybeSend 407 /// 408 /// # Example 409 /// 410 /// ```rust 411 /// use fedimint_core::{apply, async_trait_maybe_send}; 412 /// #[apply(async_trait_maybe_send!)] 413 /// trait Foo { 414 /// // methods 415 /// } 416 /// 417 /// #[apply(async_trait_maybe_send!)] 418 /// impl Foo for () { 419 /// // methods 420 /// } 421 /// ``` 422 #[macro_export] 423 macro_rules! 
async_trait_maybe_send { 424 ($($tt:tt)*) => { 425 #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)] 426 #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))] 427 $($tt)* 428 }; 429 } 430 431 /// MaybeSync can not be used in `dyn $Trait + MaybeSend` 432 /// 433 /// # Example 434 /// 435 /// ```rust 436 /// use std::any::Any; 437 /// 438 /// use fedimint_core::{apply, maybe_add_send}; 439 /// type Foo = maybe_add_send!(dyn Any); 440 /// ``` 441 #[cfg(not(target_family = "wasm"))] 442 #[macro_export] 443 macro_rules! maybe_add_send { 444 ($($tt:tt)*) => { 445 $($tt)* + Send 446 }; 447 } 448 449 /// MaybeSync can not be used in `dyn $Trait + MaybeSend` 450 /// 451 /// # Example 452 /// 453 /// ```rust 454 /// type Foo = maybe_add_send!(dyn Any); 455 /// ``` 456 #[cfg(target_family = "wasm")] 457 #[macro_export] 458 macro_rules! maybe_add_send { 459 ($($tt:tt)*) => { 460 $($tt)* 461 }; 462 } 463 464 /// See `maybe_add_send` 465 #[cfg(not(target_family = "wasm"))] 466 #[macro_export] 467 macro_rules! maybe_add_send_sync { 468 ($($tt:tt)*) => { 469 $($tt)* + Send + Sync 470 }; 471 } 472 473 /// See `maybe_add_send` 474 #[cfg(target_family = "wasm")] 475 #[macro_export] 476 macro_rules! maybe_add_send_sync { 477 ($($tt:tt)*) => { 478 $($tt)* 479 }; 480 } 481 482 /// `MaybeSend` is no-op on wasm and `Send` on non wasm. 483 /// 484 /// On wasm, most types don't implement `Send` because JS types can not sent 485 /// between workers directly. 486 #[cfg(target_family = "wasm")] 487 pub trait MaybeSend {} 488 489 /// `MaybeSend` is no-op on wasm and `Send` on non wasm. 490 /// 491 /// On wasm, most types don't implement `Send` because JS types can not sent 492 /// between workers directly. 
493 #[cfg(not(target_family = "wasm"))] 494 pub trait MaybeSend: Send {} 495 496 #[cfg(not(target_family = "wasm"))] 497 impl<T: Send> MaybeSend for T {} 498 499 #[cfg(target_family = "wasm")] 500 impl<T> MaybeSend for T {} 501 502 /// `MaybeSync` is no-op on wasm and `Sync` on non wasm. 503 #[cfg(target_family = "wasm")] 504 pub trait MaybeSync {} 505 506 /// `MaybeSync` is no-op on wasm and `Sync` on non wasm. 507 #[cfg(not(target_family = "wasm"))] 508 pub trait MaybeSync: Sync {} 509 510 #[cfg(not(target_family = "wasm"))] 511 impl<T: Sync> MaybeSync for T {} 512 513 #[cfg(target_family = "wasm")] 514 impl<T> MaybeSync for T {} 515 516 // Used in tests when sleep functionality is desired so it can be logged. 517 // Must include comment describing the reason for sleeping. 518 pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) { 519 info!( 520 target: LOG_TEST, 521 "Sleeping for {}.{:03} seconds because: {}", 522 duration.as_secs(), 523 duration.subsec_millis(), 524 comment.as_ref() 525 ); 526 sleep(duration).await; 527 } 528 529 /// An error used as a "cancelled" marker in [`Cancellable`]. 530 #[derive(Error, Debug)] 531 #[error("Operation cancelled")] 532 pub struct Cancelled; 533 534 /// Operation that can potentially get cancelled returning no result (e.g. 535 /// program shutdown). 
536 pub type Cancellable<T> = std::result::Result<T, Cancelled>; 537 538 #[cfg(test)] 539 mod tests { 540 use super::*; 541 542 #[test_log::test(tokio::test)] 543 async fn shutdown_task_group_after() -> anyhow::Result<()> { 544 let tg = TaskGroup::new(); 545 tg.spawn("shutdown waiter", |handle| async move { 546 handle.make_shutdown_rx().await.await 547 }); 548 sleep(Duration::from_millis(10)).await; 549 tg.shutdown_join_all(None).await?; 550 Ok(()) 551 } 552 553 #[test_log::test(tokio::test)] 554 async fn shutdown_task_group_before() -> anyhow::Result<()> { 555 let tg = TaskGroup::new(); 556 tg.spawn("shutdown waiter", |handle| async move { 557 sleep(Duration::from_millis(10)).await; 558 handle.make_shutdown_rx().await.await 559 }); 560 tg.shutdown_join_all(None).await?; 561 Ok(()) 562 } 563 564 #[test_log::test(tokio::test)] 565 async fn shutdown_task_subgroup_after() -> anyhow::Result<()> { 566 let tg = TaskGroup::new(); 567 tg.make_subgroup() 568 .spawn("shutdown waiter", |handle| async move { 569 handle.make_shutdown_rx().await.await 570 }); 571 sleep(Duration::from_millis(10)).await; 572 tg.shutdown_join_all(None).await?; 573 Ok(()) 574 } 575 576 #[test_log::test(tokio::test)] 577 async fn shutdown_task_subgroup_before() -> anyhow::Result<()> { 578 let tg = TaskGroup::new(); 579 tg.make_subgroup() 580 .spawn("shutdown waiter", |handle| async move { 581 sleep(Duration::from_millis(10)).await; 582 handle.make_shutdown_rx().await.await 583 }); 584 tg.shutdown_join_all(None).await?; 585 Ok(()) 586 } 587 }