/ fedimint-core / src / task / jit.rs
jit.rs
  1  use std::convert::Infallible;
  2  use std::sync::Arc;
  3  use std::{fmt, panic};
  4  
  5  use fedimint_core::runtime::JoinHandle;
  6  use fedimint_logging::LOG_TASK;
  7  use futures::Future;
  8  use tokio::sync;
  9  use tracing::warn;
 10  
 11  use super::MaybeSend;
 12  
 13  pub type Jit<T> = JitCore<T, Infallible>;
 14  pub type JitTry<T, E> = JitCore<T, E>;
 15  pub type JitTryAnyhow<T> = JitCore<T, anyhow::Error>;
 16  
 17  /// Error that could have been returned before
 18  ///
 19  /// Newtype over `Option<E>` that allows better user (error conversion mostly)
 20  /// experience
 21  #[derive(Debug)]
 22  pub enum OneTimeError<E> {
 23      Original(E),
 24      Copy(anyhow::Error),
 25  }
 26  
 27  impl<E> std::error::Error for OneTimeError<E>
 28  where
 29      E: fmt::Debug,
 30      E: fmt::Display,
 31  {
 32      fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
 33          None
 34      }
 35  
 36      fn cause(&self) -> Option<&dyn std::error::Error> {
 37          self.source()
 38      }
 39  }
 40  
 41  impl<E> fmt::Display for OneTimeError<E>
 42  where
 43      E: fmt::Display,
 44  {
 45      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 46          match self {
 47              OneTimeError::Original(o) => o.fmt(f),
 48              OneTimeError::Copy(c) => c.fmt(f),
 49          }
 50      }
 51  }
 52  
 53  /// A value that initializes eagerly in parallel in a falliable way
 54  #[derive(Debug)]
 55  pub struct JitCore<T, E> {
 56      inner: Arc<JitInner<T, E>>,
 57  }
 58  
 59  #[derive(Debug)]
 60  struct JitInner<T, E> {
 61      handle: sync::Mutex<JoinHandle<Result<T, E>>>,
 62      val: sync::OnceCell<Result<T, String>>,
 63  }
 64  
 65  impl<T, E> Clone for JitCore<T, E>
 66  where
 67      T: Clone,
 68  {
 69      fn clone(&self) -> Self {
 70          Self {
 71              inner: self.inner.clone(),
 72          }
 73      }
 74  }
 75  impl<T, E> Drop for JitInner<T, E> {
 76      fn drop(&mut self) {
 77          self.handle.get_mut().abort();
 78      }
 79  }
 80  impl<T, E> JitCore<T, E>
 81  where
 82      T: MaybeSend + 'static,
 83      E: MaybeSend + 'static + fmt::Display,
 84  {
 85      /// Create `JitTry` value, and spawn a future `f` that computes its value
 86      ///
 87      /// Unlike normal Rust futures, the `f` executes eagerly (is spawned as a
 88      /// tokio task).
 89      pub fn new_try<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self
 90      where
 91          Fut: Future<Output = std::result::Result<T, E>> + 'static + MaybeSend,
 92      {
 93          let handle = crate::runtime::spawn("jit-value", async move { f().await });
 94  
 95          Self {
 96              inner: JitInner {
 97                  handle: handle.into(),
 98                  val: sync::OnceCell::new(),
 99              }
100              .into(),
101          }
102      }
103  
104      /// Get the reference to the value, potentially blocking for the
105      /// initialization future to complete
106      pub async fn get_try(&self) -> Result<&T, OneTimeError<E>> {
107          let mut init_error = None;
108          let value = self
109              .inner
110              .val
111              .get_or_init(|| async {
112                  let handle: &mut _ = &mut *self.inner.handle.lock().await;
113                  match handle.await {
114                          Ok(Ok(o)) => Ok(o),
115                          Ok(Err(err)) => {
116                              let err_str = err.to_string();
117                              init_error = Some(err);
118                              Err(err_str)
119                          },
120                          Err(err) => {
121  
122                              #[cfg(not(target_family = "wasm"))]
123                              if err.is_panic() {
124                                  warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value panicked");
125                                  // Resume the panic on the main task
126                                  panic::resume_unwind(err.into_panic());
127                              }
128                              #[cfg(not(target_family = "wasm"))]
129                              if err.is_cancelled() {
130                                  warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value task canceled:");
131                              }
132                              Err(format!("Jit value {} failed unexpectedly with: {}", std::any::type_name::<T>(), err))
133                          },
134                      }
135              })
136              .await;
137          if let Some(err) = init_error {
138              return Err(OneTimeError::Original(err));
139          }
140          value
141              .as_ref()
142              .map_err(|err_str| OneTimeError::Copy(anyhow::Error::msg(err_str.to_owned())))
143      }
144  }
145  impl<T> JitCore<T, Infallible>
146  where
147      T: MaybeSend + 'static,
148  {
149      pub fn new<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self
150      where
151          Fut: Future<Output = T> + 'static + MaybeSend,
152          T: 'static,
153      {
154          Self::new_try(|| async { Ok(f().await) })
155      }
156  
157      pub async fn get(&self) -> &T {
158          self.get_try().await.expect("can't fail")
159      }
160  }
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use anyhow::bail;

    use super::{Jit, JitTry, JitTryAnyhow, OneTimeError};

    #[test_log::test(tokio::test)]
    async fn sanity_jit() {
        let v = Jit::new(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            3
        });

        assert_eq!(*v.get().await, 3);
        assert_eq!(*v.get().await, 3);
        assert_eq!(*v.clone().get().await, 3);
    }

    #[test_log::test(tokio::test)]
    async fn sanity_jit_try_ok() {
        let v = JitTryAnyhow::new_try(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            Ok(3)
        });

        assert_eq!(*v.get_try().await.expect("ok"), 3);
        assert_eq!(*v.get_try().await.expect("ok"), 3);
        assert_eq!(*v.clone().get_try().await.expect("ok"), 3);
    }

    #[test_log::test(tokio::test)]
    async fn sanity_jit_try_err() {
        let v = JitTry::new_try(|| async {
            fedimint_core::runtime::sleep(Duration::from_millis(0)).await;
            bail!("BOOM");
            // Never reached; only present so the closure's `Ok` type can be
            // inferred.
            #[allow(unreachable_code)]
            Ok(3)
        });

        // The call that runs the initialization receives the original error...
        let first = v.get_try().await.expect_err("initialization must fail");
        assert!(matches!(first, OneTimeError::Original(_)));
        assert_eq!(first.to_string(), "BOOM");

        // ...and every later call, including through a clone, gets a copy
        // with the same message.
        let second = v.get_try().await.expect_err("failure must be cached");
        assert!(matches!(second, OneTimeError::Copy(_)));
        assert_eq!(second.to_string(), "BOOM");

        let from_clone = v
            .clone()
            .get_try()
            .await
            .expect_err("failure must be shared with clones");
        assert!(matches!(from_clone, OneTimeError::Copy(_)));
        assert_eq!(from_clone.to_string(), "BOOM");
    }
}