/ fedimint-client / src / sm / state.rs
state.rs
  1  use std::any::Any;
  2  use std::fmt::Debug;
  3  use std::future::Future;
  4  use std::hash;
  5  use std::io::{Error, Read, Write};
  6  use std::pin::Pin;
  7  use std::sync::Arc;
  8  
  9  use fedimint_core::core::{IntoDynInstance, ModuleInstanceId, OperationId};
 10  use fedimint_core::encoding::{Decodable, DecodeError, DynEncodable, Encodable};
 11  use fedimint_core::module::registry::ModuleDecoderRegistry;
 12  use fedimint_core::task::{MaybeSend, MaybeSync};
 13  use fedimint_core::util::BoxFuture;
 14  use fedimint_core::{maybe_add_send, maybe_add_send_sync, module_plugin_dyn_newtype_define};
 15  
 16  use crate::sm::ClientSMDatabaseTransaction;
 17  use crate::DynGlobalClientContext;
 18  
/// Implementors act as state machines that can be executed.
///
/// The supertrait bounds require every state to be debuggable, clonable,
/// comparable, hashable and (de)serializable through [`Encodable`] and
/// [`Decodable`].
pub trait State:
    Debug
    + Clone
    + Eq
    + PartialEq
    + std::hash::Hash
    + Encodable
    + Decodable
    + MaybeSend
    + MaybeSync
    + 'static
{
    /// Additional resources made available in this module's state transitions
    type ModuleContext: Context;

    /// All possible transitions from the current state to other states. See
    /// [`StateTransition`] for details.
    ///
    /// A state returning an empty vector is reported as terminal by
    /// [`DynState::is_terminal`].
    fn transitions(
        &self,
        context: &Self::ModuleContext,
        global_context: &DynGlobalClientContext,
    ) -> Vec<StateTransition<Self>>;

    // TODO: move out of this interface into wrapper struct (see OperationState)
    /// Operation this state machine belongs to. See [`OperationId`] for
    /// details.
    fn operation_id(&self) -> OperationId;
}
 48  
/// Object-safe version of [`State`]
///
/// A blanket implementation exists for every type implementing [`State`].
pub trait IState: Debug + DynEncodable + MaybeSend + MaybeSync {
    /// Exposes the state as [`Any`] so it can be downcast to its concrete
    /// type again.
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));

    /// All possible transitions from the state
    fn transitions(
        &self,
        context: &DynContext,
        global_context: &DynGlobalClientContext,
    ) -> Vec<StateTransition<DynState>>;

    /// Operation this state machine belongs to. See [`OperationId`] for
    /// details.
    fn operation_id(&self) -> OperationId;

    /// Clone state
    ///
    /// The copy is returned as a [`DynState`] tagged with
    /// `module_instance_id`.
    fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState;

    /// Compares this state with the state wrapped by `other`, ignoring the
    /// module instance id of `other`.
    ///
    /// The blanket implementation for [`State`] types panics if `other` does
    /// not wrap the same concrete type as `self`.
    fn erased_eq_no_instance_id(&self, other: &DynState) -> bool;

    /// Feeds this state (without any module instance id) into `hasher`.
    fn erased_hash_no_instance_id(&self, hasher: &mut dyn std::hash::Hasher);
}
 71  
/// Something that can be a [`DynContext`] for a state machine
///
/// General purpose code should use [`DynContext`] instead
pub trait IContext: Debug {
    /// Exposes the context as [`Any`] so it can be downcast to its concrete
    /// type again.
    fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any));
}
 78  
// Defines the `DynContext` newtype through a `fedimint_core` helper macro.
// NOTE(review): presumably this wraps an `Arc<dyn IContext>` and derefs to
// it (this file calls `as_any()` on a `&DynContext`) — confirm against the
// macro definition.
module_plugin_dyn_newtype_define! {
    /// A shared context for a module client state machine
    #[derive(Clone)]
    pub DynContext(Arc<IContext>)
}
 84  
/// Additional data made available to state machines of a module (e.g. API
/// clients)
///
/// This is a marker trait without methods; every implementor automatically
/// implements [`IContext`] through a blanket implementation.
pub trait Context: std::fmt::Debug + MaybeSend + MaybeSync + 'static {}

// The unit type serves as the context of state machines that need no
// additional data.
impl Context for () {}
 90  
 91  /// Type-erased version of [`Context`]
 92  impl<T> IContext for T
 93  where
 94      T: Context + 'static + MaybeSend + MaybeSync,
 95  {
 96      fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
 97          self
 98      }
 99  }
100  
/// Boxed and pinned trigger future of a [`StateTransition`]; it resolves to
/// the trigger's output represented as a [`serde_json::Value`].
type TriggerFuture = Pin<Box<maybe_add_send!(dyn Future<Output = serde_json::Value> + 'static)>>;
102  
// TODO: remove Arc, maybe make it a fn pointer?
/// Shared, type-erased state transition function.
///
/// It receives the database transaction, the trigger output as a
/// [`serde_json::Value`] and the current state `S`, and returns a future
/// resolving to the next state.
pub(super) type StateTransitionFunction<S> = Arc<
    maybe_add_send_sync!(
        dyn for<'a> Fn(
            &'a mut ClientSMDatabaseTransaction<'_, '_>,
            serde_json::Value,
            S,
        ) -> BoxFuture<'a, S>
    ),
>;
113  
/// Represents one or multiple possible state transitions triggered in a common
/// way
pub struct StateTransition<S> {
    /// Future that will block until a state transition is possible.
    ///
    /// **The trigger future must be idempotent since it might be re-run if the
    /// client is restarted.**
    ///
    /// To wait for a possible state transition it can query external APIs,
    /// subscribe to events emitted by other state machines, etc.
    /// Optionally, it can also return some data that will be given to the
    /// state transition function, see the `transition` docs for details.
    pub trigger: TriggerFuture,
    /// State transition function that, using the output of the `trigger`,
    /// performs the appropriate state transition.
    ///
    /// **This function shall not block on network IO or similar things as all
    /// actual state transitions are run serially.**
    ///
    /// Since this function can return different output states depending on
    /// the `Value` returned by the `trigger` future, it can be used to model
    /// multiple possible state transitions at once. E.g. instead of having
    /// two state transitions querying the same API endpoint and each waiting
    /// for a specific value to be returned to trigger their respective state
    /// transition, we can have one `trigger` future querying the API and,
    /// depending on the return value, run different state transitions,
    /// saving network requests.
    pub transition: StateTransitionFunction<S>,
}
143  
144  impl<S> StateTransition<S> {
145      /// Creates a new `StateTransition` where the `trigger` future returns a
146      /// value of type `V` that is then given to the `transition` function.
147      pub fn new<V, Trigger, TransitionFn>(
148          trigger: Trigger,
149          transition: TransitionFn,
150      ) -> StateTransition<S>
151      where
152          S: MaybeSend + MaybeSync + Clone + 'static,
153          V: serde::Serialize + serde::de::DeserializeOwned + Send,
154          Trigger: Future<Output = V> + MaybeSend + 'static,
155          TransitionFn: for<'a> Fn(&'a mut ClientSMDatabaseTransaction<'_, '_>, V, S) -> BoxFuture<'a, S>
156              + MaybeSend
157              + MaybeSync
158              + Clone
159              + 'static,
160      {
161          StateTransition {
162              trigger: Box::pin(async move {
163                  let val = trigger.await;
164                  serde_json::to_value(val).expect("Value could not be serialized")
165              }),
166              transition: Arc::new(move |dbtx, val, state| {
167                  let transition = transition.clone();
168                  Box::pin(async move {
169                      let typed_val: V = serde_json::from_value(val)
170                          .expect("Deserialize trigger return value failed");
171                      transition(dbtx, typed_val, state.clone()).await
172                  })
173              }),
174          }
175      }
176  }
177  
178  impl<T> IState for T
179  where
180      T: State,
181  {
182      fn as_any(&self) -> &(maybe_add_send_sync!(dyn Any)) {
183          self
184      }
185  
186      fn transitions(
187          &self,
188          context: &DynContext,
189          global_context: &DynGlobalClientContext,
190      ) -> Vec<StateTransition<DynState>> {
191          <T as State>::transitions(
192              self,
193              context.as_any().downcast_ref().expect("Wrong module"),
194              global_context,
195          )
196          .into_iter()
197          .map(|st| StateTransition {
198              trigger: st.trigger,
199              transition: Arc::new(
200                  move |dbtx: &mut ClientSMDatabaseTransaction<'_, '_>, val, state: DynState| {
201                      let transition = st.transition.clone();
202                      Box::pin(async move {
203                          let new_state = transition(
204                              dbtx,
205                              val,
206                              state
207                                  .as_any()
208                                  .downcast_ref::<T>()
209                                  .expect("Wrong module")
210                                  .clone(),
211                          )
212                          .await;
213                          DynState::from_typed(state.module_instance_id(), new_state)
214                      })
215                  },
216              ),
217          })
218          .collect()
219      }
220  
221      fn operation_id(&self) -> OperationId {
222          <T as State>::operation_id(self)
223      }
224  
225      fn clone(&self, module_instance_id: ModuleInstanceId) -> DynState {
226          DynState::from_typed(module_instance_id, <T as Clone>::clone(self))
227      }
228  
229      fn erased_eq_no_instance_id(&self, other: &DynState) -> bool {
230          let other: &T = other
231              .as_any()
232              .downcast_ref()
233              .expect("Type is ensured in previous step");
234  
235          self == other
236      }
237  
238      fn erased_hash_no_instance_id(&self, mut hasher: &mut dyn std::hash::Hasher) {
239          self.hash(&mut hasher);
240      }
241  }
242  
/// A type-erased state of a state machine belonging to a module instance, see
/// [`State`]
pub struct DynState(
    /// The boxed, type-erased state.
    Box<maybe_add_send_sync!(dyn IState + 'static)>,
    /// Id of the module instance the state belongs to.
    ModuleInstanceId,
);
249  
250  impl std::ops::Deref for DynState {
251      type Target = maybe_add_send_sync!(dyn IState + 'static);
252  
253      fn deref(&self) -> &<Self as std::ops::Deref>::Target {
254          &*self.0
255      }
256  }
257  
258  impl hash::Hash for DynState {
259      fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
260          self.1.hash(hasher);
261          self.0.erased_hash_no_instance_id(hasher);
262      }
263  }
264  
265  impl DynState {
266      pub fn module_instance_id(&self) -> ModuleInstanceId {
267          self.1
268      }
269  
270      pub fn from_typed<I>(module_instance_id: ModuleInstanceId, typed: I) -> Self
271      where
272          I: IState + 'static,
273      {
274          Self(Box::new(typed), module_instance_id)
275      }
276  
277      pub fn from_parts(
278          module_instance_id: ::fedimint_core::core::ModuleInstanceId,
279          dynbox: Box<maybe_add_send_sync!(dyn IState + 'static)>,
280      ) -> Self {
281          Self(dynbox, module_instance_id)
282      }
283  }
284  
285  impl std::fmt::Debug for DynState {
286      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
287          std::fmt::Debug::fmt(&self.0, f)
288      }
289  }
290  
291  impl std::ops::DerefMut for DynState {
292      fn deref_mut(&mut self) -> &mut <Self as std::ops::Deref>::Target {
293          &mut *self.0
294      }
295  }
296  
297  impl Clone for DynState {
298      fn clone(&self) -> Self {
299          self.0.clone(self.1)
300      }
301  }
302  
303  impl PartialEq for DynState {
304      fn eq(&self, other: &Self) -> bool {
305          if self.1 != other.1 {
306              return false;
307          }
308          self.erased_eq_no_instance_id(other)
309      }
310  }
311  impl Eq for DynState {}
312  
313  impl Encodable for DynState {
314      fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
315          self.1.consensus_encode(writer)?;
316          self.0.consensus_encode_dyn(writer)
317      }
318  }
319  impl Decodable for DynState {
320      fn consensus_decode<R: std::io::Read>(
321          reader: &mut R,
322          decoders: &::fedimint_core::module::registry::ModuleDecoderRegistry,
323      ) -> Result<Self, fedimint_core::encoding::DecodeError> {
324          let module_id = fedimint_core::core::ModuleInstanceId::consensus_decode(reader, decoders)?;
325          decoders
326              .get_expect(module_id)
327              .decode_partial(reader, module_id, decoders)
328      }
329  }
330  
331  impl DynState {
332      /// `true` if this state allows no further transitions
333      pub fn is_terminal(
334          &self,
335          context: &DynContext,
336          global_context: &DynGlobalClientContext,
337      ) -> bool {
338          self.transitions(context, global_context).is_empty()
339      }
340  }
341  
/// Pairs a state `S` with the [`OperationId`] of the operation it belongs to.
///
/// The [`State`] implementation of this wrapper reports the stored
/// `operation_id` instead of asking `S` for it.
#[derive(Debug)]
pub struct OperationState<S> {
    /// Operation the wrapped state belongs to.
    pub operation_id: OperationId,
    /// The wrapped state.
    pub state: S,
}
347  
348  /// Wrapper for states that don't want to carry around their operation id. `S`
349  /// is allowed to panic when `operation_id` is called.
350  impl<S> State for OperationState<S>
351  where
352      S: State,
353  {
354      type ModuleContext = S::ModuleContext;
355  
356      fn transitions(
357          &self,
358          context: &Self::ModuleContext,
359          global_context: &DynGlobalClientContext,
360      ) -> Vec<StateTransition<Self>> {
361          let transitions: Vec<StateTransition<OperationState<S>>> = self
362              .state
363              .transitions(context, global_context)
364              .into_iter()
365              .map(
366                  |StateTransition {
367                       trigger,
368                       transition,
369                   }| {
370                      let op_transition: StateTransitionFunction<Self> =
371                          Arc::new(move |dbtx, value, op_state| {
372                              let transition = transition.clone();
373                              Box::pin(async move {
374                                  let state = transition(dbtx, value, op_state.state).await;
375                                  OperationState {
376                                      operation_id: op_state.operation_id,
377                                      state,
378                                  }
379                              })
380                          });
381  
382                      StateTransition {
383                          trigger,
384                          transition: op_transition,
385                      }
386                  },
387              )
388              .collect();
389          transitions
390      }
391  
392      fn operation_id(&self) -> OperationId {
393          self.operation_id
394      }
395  }
396  
397  // TODO: can we get rid of `GC`? Maybe make it an associated type of `State`
398  // instead?
399  impl<S> IntoDynInstance for OperationState<S>
400  where
401      S: State,
402  {
403      type DynType = DynState;
404  
405      fn into_dyn(self, instance_id: ModuleInstanceId) -> Self::DynType {
406          DynState::from_typed(instance_id, self)
407      }
408  }
409  
410  impl<S> Encodable for OperationState<S>
411  where
412      S: State,
413  {
414      fn consensus_encode<W: Write>(&self, writer: &mut W) -> Result<usize, Error> {
415          let mut len = 0;
416          len += self.operation_id.consensus_encode(writer)?;
417          len += self.state.consensus_encode(writer)?;
418          Ok(len)
419      }
420  }
421  
422  impl<S> Decodable for OperationState<S>
423  where
424      S: State,
425  {
426      fn consensus_decode<R: Read>(
427          read: &mut R,
428          modules: &ModuleDecoderRegistry,
429      ) -> Result<Self, DecodeError> {
430          let operation_id = OperationId::consensus_decode(read, modules)?;
431          let state = S::consensus_decode(read, modules)?;
432  
433          Ok(OperationState {
434              operation_id,
435              state,
436          })
437      }
438  }
439  
440  // TODO: derive after getting rid of `GC` type arg
441  impl<S> PartialEq for OperationState<S>
442  where
443      S: State,
444  {
445      fn eq(&self, other: &Self) -> bool {
446          self.operation_id.eq(&other.operation_id) && self.state.eq(&other.state)
447      }
448  }
449  
450  impl<S> Eq for OperationState<S> where S: State {}
451  
452  impl<S> hash::Hash for OperationState<S>
453  where
454      S: hash::Hash,
455  {
456      fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
457          self.operation_id.hash(hasher);
458          self.state.hash(hasher);
459      }
460  }
461  
462  impl<S> Clone for OperationState<S>
463  where
464      S: State,
465  {
466      fn clone(&self) -> Self {
467          OperationState {
468              operation_id: self.operation_id,
469              state: self.state.clone(),
470          }
471      }
472  }