// state.rs
1 use std::any::Any; 2 use std::fmt::Debug; 3 use std::future::Future; 4 use std::hash; 5 use std::io::{Error, Read, Write}; 6 use std::pin::Pin; 7 use std::sync::Arc; 8 9 use fedimint_core::core::{IntoDynInstance, ModuleInstanceId, OperationId}; 10 use fedimint_core::encoding::{Decodable, DecodeError, DynEncodable, Encodable}; 11 use fedimint_core::module::registry::ModuleDecoderRegistry; 12 use fedimint_core::task::{MaybeSend, MaybeSync}; 13 use fedimint_core::util::BoxFuture; 14 use fedimint_core::{maybe_add_send, maybe_add_send_sync, module_plugin_dyn_newtype_define}; 15 16 use crate::sm::ClientSMDatabaseTransaction; 17 use crate::DynGlobalClientContext; 18 19 /// Implementors act as state machines that can be executed 20 pub trait State: 21 Debug 22 + Clone 23 + Eq 24 + PartialEq 25 + std::hash::Hash 26 + Encodable 27 + Decodable 28 + MaybeSend 29 + MaybeSync 30 + 'static 31 { 32 /// Additional resources made available in this module's state transitions 33 type ModuleContext: Context; 34 35 /// All possible transitions from the current state to other states. See 36 /// [`StateTransition`] for details. 37 fn transitions( 38 &self, 39 context: &Self::ModuleContext, 40 global_context: &DynGlobalClientContext, 41 ) -> Vec<StateTransition<Self>>; 42 43 // TODO: move out of this interface into wrapper struct (see OperationState) 44 /// Operation this state machine belongs to. See [`OperationId`] for 45 /// details. 46 fn operation_id(&self) -> OperationId; 47 } 48 49 /// Object-safe version of [`State`] 50 pub trait IState: Debug + DynEncodable + MaybeSend + MaybeSync { 51 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)); 52 53 /// All possible transitions from the state 54 fn transitions( 55 &self, 56 context: &DynContext, 57 global_context: &DynGlobalClientContext, 58 ) -> Vec<StateTransition<DynState>>; 59 60 /// Operation this state machine belongs to. See [`OperationId`] for 61 /// details. 
62 fn operation_id(&self) -> OperationId; 63 64 /// Clone state 65 fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState; 66 67 fn erased_eq_no_instance_id(&self, other: &DynState) -> bool; 68 69 fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher); 70 } 71 72 /// Something that can be a [`DynContext`] for a state machine 73 /// 74 /// General purpose code should use [`DynContext`] instead 75 pub trait IContext: Debug { 76 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)); 77 } 78 79 module_plugin_dyn_newtype_define! { 80 /// A shared context for a module client state machine 81 #[derive(Clone)] 82 pub DynContext(Arc<IContext>) 83 } 84 85 /// Additional data made available to state machines of a module (e.g. API 86 /// clients) 87 pub trait Context: std::fmt::Debug + MaybeSend + MaybeSync + 'static {} 88 89 impl Context for () {} 90 91 /// Type-erased version of [`Context`] 92 impl<T> IContext for T 93 where 94 T: Context + 'static + MaybeSend + MaybeSync, 95 { 96 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) { 97 self 98 } 99 } 100 101 type TriggerFuture = Pin<Box<maybe_add_send!(dyn Future<Output = serde_json::Value> + 'static)>>; 102 103 // TODO: remove Arc, maybe make it a fn pointer? 104 pub(super) type StateTransitionFunction<S> = Arc< 105 maybe_add_send_sync!( 106 dyn for<'a> Fn( 107 &'a mut ClientSMDatabaseTransaction<'_, '_>, 108 serde_json::Value, 109 S, 110 ) -> BoxFuture<'a, S> 111 ), 112 >; 113 114 /// Represents one or multiple possible state transitions triggered in a common 115 /// way 116 pub struct StateTransition<S> { 117 /// Future that will block until a state transition is possible. 118 /// 119 /// **The trigger future must be idempotent since it might be re-run if the 120 /// client is restarted.** 121 /// 122 /// To wait for a possible state transition it can query external APIs, 123 /// subscribe to events emitted by other state machines, etc. 
124 /// Optionally, it can also return some data that will be given to the 125 /// state transition function, see the `transition` docs for details. 126 pub trigger: TriggerFuture, 127 /// State transition function that, using the output of the `trigger`, 128 /// performs the appropriate state transition. 129 /// 130 /// **This function shall not block on network IO or similar things as all 131 /// actual state transitions are run serially.** 132 /// 133 /// Since the this function can return different output states depending on 134 /// the `Value` returned by the `trigger` future it can be used to model 135 /// multiple possible state transition at once. E.g. instead of having 136 /// two state transitions querying the same API endpoint and each waiting 137 /// for a specific value to be returned to trigger their respective state 138 /// transition we can have one `trigger` future querying the API and 139 /// depending on the return value run different state transitions, 140 /// saving network requests. 141 pub transition: StateTransitionFunction<S>, 142 } 143 144 impl<S> StateTransition<S> { 145 /// Creates a new `StateTransition` where the `trigger` future returns a 146 /// value of type `V` that is then given to the `transition` function. 
147 pub fn new<V, Trigger, TransitionFn>( 148 trigger: Trigger, 149 transition: TransitionFn, 150 ) -> StateTransition<S> 151 where 152 S: MaybeSend + MaybeSync + Clone + 'static, 153 V: serde::Serialize + serde::de::DeserializeOwned + Send, 154 Trigger: Future<Output = V> + MaybeSend + 'static, 155 TransitionFn: for<'a> Fn(&'a mut ClientSMDatabaseTransaction<'_, '_>, V, S) -> BoxFuture<'a, S> 156 + MaybeSend 157 + MaybeSync 158 + Clone 159 + 'static, 160 { 161 StateTransition { 162 trigger: Box::pin(async move { 163 let val = trigger.await; 164 serde_json::to_value(val).expect("Value could not be serialized") 165 }), 166 transition: Arc::new(move |dbtx, val, state| { 167 let transition = transition.clone(); 168 Box::pin(async move { 169 let typed_val: V = serde_json::from_value(val) 170 .expect("Deserialize trigger return value failed"); 171 transition(dbtx, typed_val, state.clone()).await 172 }) 173 }), 174 } 175 } 176 } 177 178 impl<T> IState for T 179 where 180 T: State, 181 { 182 fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) { 183 self 184 } 185 186 fn transitions( 187 &self, 188 context: &DynContext, 189 global_context: &DynGlobalClientContext, 190 ) -> Vec<StateTransition<DynState>> { 191 <T as State>::transitions( 192 self, 193 context.as_any().downcast_ref().expect("Wrong module"), 194 global_context, 195 ) 196 .into_iter() 197 .map(|st| StateTransition { 198 trigger: st.trigger, 199 transition: Arc::new( 200 move |dbtx: &mut ClientSMDatabaseTransaction<'_, '_>, val, state: DynState| { 201 let transition = st.transition.clone(); 202 Box::pin(async move { 203 let new_state = transition( 204 dbtx, 205 val, 206 state 207 .as_any() 208 .downcast_ref::<T>() 209 .expect("Wrong module") 210 .clone(), 211 ) 212 .await; 213 DynState::from_typed(state.module_instance_id(), new_state) 214 }) 215 }, 216 ), 217 }) 218 .collect() 219 } 220 221 fn operation_id(&self) -> OperationId { 222 <T as State>::operation_id(self) 223 } 224 225 fn clone(&self, 
module_instance_id: ModuleInstanceId) -> DynState { 226 DynState::from_typed(module_instance_id, <T as Clone>::clone(self)) 227 } 228 229 fn erased_eq_no_instance_id(&self, other: &DynState) -> bool { 230 let other: &T = other 231 .as_any() 232 .downcast_ref() 233 .expect("Type is ensured in previous step"); 234 235 self == other 236 } 237 238 fn erased_hash_no_instance_id(&self, mut hasher: &mut dyn std::hash::Hasher) { 239 self.hash(&mut hasher); 240 } 241 } 242 243 /// A type-erased state of a state machine belonging to a module instance, see 244 /// [`State`] 245 pub struct DynState( 246 Box<maybe_add_send_sync!(dyn IState + 'static)>, 247 ModuleInstanceId, 248 ); 249 250 impl std::ops::Deref for DynState { 251 type Target = maybe_add_send_sync!(dyn IState + 'static); 252 253 fn deref(&self) -> &<Self as std::ops::Deref>::Target { 254 &*self.0 255 } 256 } 257 258 impl hash::Hash for DynState { 259 fn hash<H: hash::Hasher>(&self, hasher: &mut H) { 260 self.1.hash(hasher); 261 self.0.erased_hash_no_instance_id(hasher); 262 } 263 } 264 265 impl DynState { 266 pub fn module_instance_id(&self) -> ModuleInstanceId { 267 self.1 268 } 269 270 pub fn from_typed<I>(module_instance_id: ModuleInstanceId, typed: I) -> Self 271 where 272 I: IState + 'static, 273 { 274 Self(Box::new(typed), module_instance_id) 275 } 276 277 pub fn from_parts( 278 module_instance_id: ::fedimint_core::core::ModuleInstanceId, 279 dynbox: Box<maybe_add_send_sync!(dyn IState + 'static)>, 280 ) -> Self { 281 Self(dynbox, module_instance_id) 282 } 283 } 284 285 impl std::fmt::Debug for DynState { 286 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 287 std::fmt::Debug::fmt(&self.0, f) 288 } 289 } 290 291 impl std::ops::DerefMut for DynState { 292 fn deref_mut(&mut self) -> &mut <Self as std::ops::Deref>::Target { 293 &mut *self.0 294 } 295 } 296 297 impl Clone for DynState { 298 fn clone(&self) -> Self { 299 self.0.clone(self.1) 300 } 301 } 302 303 impl PartialEq for DynState { 
304 fn eq(&self, other: &Self) -> bool { 305 if self.1 != other.1 { 306 return false; 307 } 308 self.erased_eq_no_instance_id(other) 309 } 310 } 311 impl Eq for DynState {} 312 313 impl Encodable for DynState { 314 fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> { 315 self.1.consensus_encode(writer)?; 316 self.0.consensus_encode_dyn(writer) 317 } 318 } 319 impl Decodable for DynState { 320 fn consensus_decode<R: std::io::Read>( 321 reader: &mut R, 322 decoders: &::fedimint_core::module::registry::ModuleDecoderRegistry, 323 ) -> Result<Self, fedimint_core::encoding::DecodeError> { 324 let module_id = fedimint_core::core::ModuleInstanceId::consensus_decode(reader, decoders)?; 325 decoders 326 .get_expect(module_id) 327 .decode_partial(reader, module_id, decoders) 328 } 329 } 330 331 impl DynState { 332 /// `true` if this state allows no further transitions 333 pub fn is_terminal( 334 &self, 335 context: &DynContext, 336 global_context: &DynGlobalClientContext, 337 ) -> bool { 338 self.transitions(context, global_context).is_empty() 339 } 340 } 341 342 #[derive(Debug)] 343 pub struct OperationState<S> { 344 pub operation_id: OperationId, 345 pub state: S, 346 } 347 348 /// Wrapper for states that don't want to carry around their operation id. `S` 349 /// is allowed to panic when `operation_id` is called. 
350 impl<S> State for OperationState<S> 351 where 352 S: State, 353 { 354 type ModuleContext = S::ModuleContext; 355 356 fn transitions( 357 &self, 358 context: &Self::ModuleContext, 359 global_context: &DynGlobalClientContext, 360 ) -> Vec<StateTransition<Self>> { 361 let transitions: Vec<StateTransition<OperationState<S>>> = self 362 .state 363 .transitions(context, global_context) 364 .into_iter() 365 .map( 366 |StateTransition { 367 trigger, 368 transition, 369 }| { 370 let op_transition: StateTransitionFunction<Self> = 371 Arc::new(move |dbtx, value, op_state| { 372 let transition = transition.clone(); 373 Box::pin(async move { 374 let state = transition(dbtx, value, op_state.state).await; 375 OperationState { 376 operation_id: op_state.operation_id, 377 state, 378 } 379 }) 380 }); 381 382 StateTransition { 383 trigger, 384 transition: op_transition, 385 } 386 }, 387 ) 388 .collect(); 389 transitions 390 } 391 392 fn operation_id(&self) -> OperationId { 393 self.operation_id 394 } 395 } 396 397 // TODO: can we get rid of `GC`? Maybe make it an associated type of `State` 398 // instead? 
399 impl<S> IntoDynInstance for OperationState<S> 400 where 401 S: State, 402 { 403 type DynType = DynState; 404 405 fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType { 406 DynState::from_typed(instance_id, self) 407 } 408 } 409 410 impl<S> Encodable for OperationState<S> 411 where 412 S: State, 413 { 414 fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<usize, Error> { 415 let mut len = 0; 416 len += self.operation_id.consensus_encode(writer)?; 417 len += self.state.consensus_encode(writer)?; 418 Ok(len) 419 } 420 } 421 422 impl<S> Decodable for OperationState<S> 423 where 424 S: State, 425 { 426 fn consensus_decode<R: Read>( 427 read: &mut R, 428 modules: &ModuleDecoderRegistry, 429 ) -> Result<Self, DecodeError> { 430 let operation_id = OperationId::consensus_decode(read, modules)?; 431 let state = S::consensus_decode(read, modules)?; 432 433 Ok(OperationState { 434 operation_id, 435 state, 436 }) 437 } 438 } 439 440 // TODO: derive after getting rid of `GC` type arg 441 impl<S> PartialEq for OperationState<S> 442 where 443 S: State, 444 { 445 fn eq(&self, other: &Self) -> bool { 446 self.operation_id.eq(&other.operation_id) && self.state.eq(&other.state) 447 } 448 } 449 450 impl<S> Eq for OperationState<S> where S: State {} 451 452 impl<S> hash::Hash for OperationState<S> 453 where 454 S: hash::Hash, 455 { 456 fn hash<H: hash::Hasher>(&self, hasher: &mut H) { 457 self.operation_id.hash(hasher); 458 self.state.hash(hasher); 459 } 460 } 461 462 impl<S> Clone for OperationState<S> 463 where 464 S: State, 465 { 466 fn clone(&self) -> Self { 467 OperationState { 468 operation_id: self.operation_id, 469 state: self.state.clone(), 470 } 471 } 472 }