jit.rs
1 use std::convert::Infallible; 2 use std::sync::Arc; 3 use std::{fmt, panic}; 4 5 use fedimint_core::runtime::JoinHandle; 6 use fedimint_logging::LOG_TASK; 7 use futures::Future; 8 use tokio::sync; 9 use tracing::warn; 10 11 use super::MaybeSend; 12 13 pub type Jit<T> = JitCore<T, Infallible>; 14 pub type JitTry<T, E> = JitCore<T, E>; 15 pub type JitTryAnyhow<T> = JitCore<T, anyhow::Error>; 16 17 /// Error that could have been returned before 18 /// 19 /// Newtype over `Option<E>` that allows better user (error conversion mostly) 20 /// experience 21 #[derive(Debug)] 22 pub enum OneTimeError<E> { 23 Original(E), 24 Copy(anyhow::Error), 25 } 26 27 impl<E> std::error::Error for OneTimeError<E> 28 where 29 E: fmt::Debug, 30 E: fmt::Display, 31 { 32 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { 33 None 34 } 35 36 fn cause(&self) -> Option<&dyn std::error::Error> { 37 self.source() 38 } 39 } 40 41 impl<E> fmt::Display for OneTimeError<E> 42 where 43 E: fmt::Display, 44 { 45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 46 match self { 47 OneTimeError::Original(o) => o.fmt(f), 48 OneTimeError::Copy(c) => c.fmt(f), 49 } 50 } 51 } 52 53 /// A value that initializes eagerly in parallel in a falliable way 54 #[derive(Debug)] 55 pub struct JitCore<T, E> { 56 inner: Arc<JitInner<T, E>>, 57 } 58 59 #[derive(Debug)] 60 struct JitInner<T, E> { 61 handle: sync::Mutex<JoinHandle<Result<T, E>>>, 62 val: sync::OnceCell<Result<T, String>>, 63 } 64 65 impl<T, E> Clone for JitCore<T, E> 66 where 67 T: Clone, 68 { 69 fn clone(&self) -> Self { 70 Self { 71 inner: self.inner.clone(), 72 } 73 } 74 } 75 impl<T, E> Drop for JitInner<T, E> { 76 fn drop(&mut self) { 77 self.handle.get_mut().abort(); 78 } 79 } 80 impl<T, E> JitCore<T, E> 81 where 82 T: MaybeSend + 'static, 83 E: MaybeSend + 'static + fmt::Display, 84 { 85 /// Create `JitTry` value, and spawn a future `f` that computes its value 86 /// 87 /// Unlike normal Rust futures, the `f` executes eagerly (is spawned as a 88 /// tokio task). 89 pub fn new_try<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self 90 where 91 Fut: Future<Output = std::result::Result<T, E>> + 'static + MaybeSend, 92 { 93 let handle = crate::runtime::spawn("jit-value", async move { f().await }); 94 95 Self { 96 inner: JitInner { 97 handle: handle.into(), 98 val: sync::OnceCell::new(), 99 } 100 .into(), 101 } 102 } 103 104 /// Get the reference to the value, potentially blocking for the 105 /// initialization future to complete 106 pub async fn get_try(&self) -> Result<&T, OneTimeError<E>> { 107 let mut init_error = None; 108 let value = self 109 .inner 110 .val 111 .get_or_init(|| async { 112 let handle: &mut _ = &mut *self.inner.handle.lock().await; 113 match handle.await { 114 Ok(Ok(o)) => Ok(o), 115 Ok(Err(err)) => { 116 let err_str = err.to_string(); 117 init_error = Some(err); 118 Err(err_str) 119 }, 120 Err(err) => { 121 122 #[cfg(not(target_family = "wasm"))] 123 if err.is_panic() { 124 warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value panicked"); 125 // Resume the panic on the main task 126 panic::resume_unwind(err.into_panic()); 127 } 128 #[cfg(not(target_family = "wasm"))] 129 if err.is_cancelled() { 130 warn!(target: LOG_TASK, %err, type_name = %std::any::type_name::<T>(), "Jit value task canceled:"); 131 } 132 Err(format!("Jit value {} failed unexpectedly with: {}", std::any::type_name::<T>(), err)) 133 }, 134 } 135 }) 136 .await; 137 if let Some(err) = init_error { 138 return Err(OneTimeError::Original(err)); 139 } 140 value 141 .as_ref() 142 .map_err(|err_str| OneTimeError::Copy(anyhow::Error::msg(err_str.to_owned()))) 143 } 144 } 145 impl<T> JitCore<T, Infallible> 146 where 147 T: MaybeSend + 'static, 148 { 149 pub fn new<Fut>(f: impl FnOnce() -> Fut + 'static + MaybeSend) -> Self 150 where 151 Fut: Future<Output = T> + 'static + MaybeSend, 152 T: 'static, 153 { 154 Self::new_try(|| async { Ok(f().await) }) 155 } 156 157 pub async fn get(&self) -> &T { 158 self.get_try().await.expect("can't fail") 159 } 160 } 161 #[cfg(test)] 162 mod tests { 163 use std::time::Duration; 164 165 use anyhow::bail; 166 167 use super::{Jit, JitTry, JitTryAnyhow}; 168 169 #[test_log::test(tokio::test)] 170 async fn sanity_jit() { 171 let v = Jit::new(|| async { 172 fedimint_core::runtime::sleep(Duration::from_millis(0)).await; 173 3 174 }); 175 176 assert_eq!(*v.get().await, 3); 177 assert_eq!(*v.get().await, 3); 178 assert_eq!(*v.clone().get().await, 3); 179 } 180 181 #[test_log::test(tokio::test)] 182 async fn sanity_jit_try_ok() { 183 let v = JitTryAnyhow::new_try(|| async { 184 fedimint_core::runtime::sleep(Duration::from_millis(0)).await; 185 Ok(3) 186 }); 187 188 assert_eq!(*v.get_try().await.expect("ok"), 3); 189 assert_eq!(*v.get_try().await.expect("ok"), 3); 190 assert_eq!(*v.clone().get_try().await.expect("ok"), 3); 191 } 192 193 #[test_log::test(tokio::test)] 194 async fn sanity_jit_try_err() { 195 let v = JitTry::new_try(|| async { 196 fedimint_core::runtime::sleep(Duration::from_millis(0)).await; 197 bail!("BOOM"); 198 #[allow(unreachable_code)] 199 Ok(3) 200 }); 201 202 assert!(v.get_try().await.is_err()); 203 assert!(v.get_try().await.is_err()); 204 assert!(v.clone().get_try().await.is_err()); 205 } 206 }