/ fedimint-client / src / sm / notifier.rs
notifier.rs
  1  use std::marker::PhantomData;
  2  use std::sync::Arc;
  3  
  4  use fedimint_core::core::{ModuleInstanceId, OperationId};
  5  use fedimint_core::db::{Database, IDatabaseTransactionOpsCoreTyped};
  6  use fedimint_core::util::broadcaststream::BroadcastStream;
  7  use fedimint_core::util::BoxStream;
  8  use futures::StreamExt;
  9  use tracing::{debug, error, trace};
 10  
 11  use crate::sm::executor::{
 12      ActiveModuleOperationStateKeyPrefix, ActiveStateKey, InactiveModuleOperationStateKeyPrefix,
 13      InactiveStateKey,
 14  };
 15  use crate::sm::{ActiveStateMeta, DynState, InactiveStateMeta, State};
 16  
 17  /// State transition notifier owned by the modularized client used to inform
 18  /// modules of state transitions.
 19  ///
 20  /// To not lose any state transitions that happen before a module subscribes to
 21  /// the operation the notifier loads all belonging past state transitions from
 22  /// the DB. State transitions may be reported multiple times and out of order.
 23  #[derive(Clone)]
 24  pub struct Notifier {
 25      /// Broadcast channel used to send state transitions to all subscribers
 26      broadcast: tokio::sync::broadcast::Sender<DynState>,
 27      /// Database used to load all states that happened before subscribing
 28      db: Database,
 29  }
 30  
 31  impl Notifier {
 32      pub fn new(db: Database) -> Self {
 33          let (sender, _receiver) = tokio::sync::broadcast::channel(10_000);
 34          Self {
 35              broadcast: sender,
 36              db,
 37          }
 38      }
 39  
 40      /// Notify all subscribers of a state transition
 41      pub fn notify(&self, state: DynState) {
 42          let queue_len = self.broadcast.len();
 43          trace!(?state, %queue_len, "Sending notification about state transition");
 44          // FIXME: use more robust notification mechanism
 45          if let Err(e) = self.broadcast.send(state) {
 46              debug!(
 47                  ?e,
 48                  %queue_len,
 49                  receivers=self.broadcast.receiver_count(),
 50                  "Could not send state transition notification, no active receivers"
 51              );
 52          }
 53      }
 54  
 55      /// Create a new notifier for a specific module instance that can only
 56      /// subscribe to the instance's state transitions
 57      pub fn module_notifier<S>(&self, module_instance: ModuleInstanceId) -> ModuleNotifier<S> {
 58          ModuleNotifier {
 59              broadcast: self.broadcast.clone(),
 60              module_instance,
 61              db: self.db.clone(),
 62              _pd: Default::default(),
 63          }
 64      }
 65  
 66      /// Create a [`NotifierSender`] handle that lets the owner trigger
 67      /// notifications without having to hold a full `Notifier`.
 68      pub fn sender(&self) -> NotifierSender {
 69          NotifierSender {
 70              sender: self.broadcast.clone(),
 71          }
 72      }
 73  }
 74  
 75  /// Notifier send handle that can be shared to places where we don't need an
 76  /// entire [`Notifier`] but still need to trigger notifications. The main use
 77  /// case is triggering notifications when a DB transaction was committed
 78  /// successfully.
 79  pub struct NotifierSender {
 80      sender: tokio::sync::broadcast::Sender<DynState>,
 81  }
 82  
 83  impl NotifierSender {
 84      /// Notify all subscribers of a state transition
 85      pub fn notify(&self, state: DynState) {
 86          let _res = self.sender.send(state);
 87      }
 88  }
 89  
 90  /// State transition notifier for a specific module instance that can only
 91  /// subscribe to transitions belonging to that module
 92  #[derive(Debug, Clone)]
 93  pub struct ModuleNotifier<S> {
 94      broadcast: tokio::sync::broadcast::Sender<DynState>,
 95      module_instance: ModuleInstanceId,
 96      /// Database used to load all states that happened before subscribing, see
 97      /// [`Notifier`]
 98      db: Database,
 99      /// `S` limits the type of state that can be subscribed to the one
100      /// associated with the module instance
101      _pd: PhantomData<S>,
102  }
103  
104  impl<S> ModuleNotifier<S>
105  where
106      S: State,
107  {
108      // TODO: remove duplicates and order old transitions
109      /// Subscribe to state transitions belonging to an operation and module
110      /// (module context contained in struct).
111      ///
112      /// The returned stream will contain all past state transitions that
113      /// happened before the subscription and are read from the database, after
114      /// these the stream will contain all future state transitions. The states
115      /// loaded from the database are not returned in a specific order. There may
116      /// also be duplications.
117      pub async fn subscribe(&self, operation_id: OperationId) -> BoxStream<'static, S> {
118          let to_typed_state = |state: DynState| {
119              state
120                  .as_any()
121                  .downcast_ref::<S>()
122                  .expect("Tried to subscribe to wrong state type")
123                  .clone()
124          };
125  
126          // It's important to start the subscription first and then query the database to
127          // not lose any transitions in the meantime.
128          let new_transitions = self.subscribe_all_operations().await;
129  
130          let db_states = {
131              let mut dbtx = self.db.begin_transaction().await;
132              let active_states = dbtx
133                  .find_by_prefix(&ActiveModuleOperationStateKeyPrefix {
134                      operation_id,
135                      module_instance: self.module_instance,
136                  })
137                  .await
138                  .map(|(key, val): (ActiveStateKey, ActiveStateMeta)| {
139                      (to_typed_state(key.state), val.created_at)
140                  })
141                  .collect::<Vec<(S, _)>>()
142                  .await;
143  
144              let inactive_states = dbtx
145                  .find_by_prefix(&InactiveModuleOperationStateKeyPrefix {
146                      operation_id,
147                      module_instance: self.module_instance,
148                  })
149                  .await
150                  .map(|(key, val): (InactiveStateKey, InactiveStateMeta)| {
151                      (to_typed_state(key.state), val.created_at)
152                  })
153                  .collect::<Vec<(S, _)>>()
154                  .await;
155  
156              // FIXME: don't rely on SystemTime for ordering and introduce a state transition
157              // index instead (dpc was right again xD)
158              let mut all_states_timed = active_states
159                  .into_iter()
160                  .chain(inactive_states)
161                  .collect::<Vec<(S, _)>>();
162              all_states_timed.sort_by(|(_, t1), (_, t2)| t1.cmp(t2));
163              debug!(
164                  %operation_id,
165                  "Returning {} state transitions from DB for notifier subscription",
166                  all_states_timed.len()
167              );
168              all_states_timed
169                  .into_iter()
170                  .map(|(s, _)| s)
171                  .collect::<Vec<S>>()
172          };
173  
174          let new_transitions = new_transitions.filter_map({
175              let db_states: Arc<_> = Arc::new(db_states.clone());
176  
177              move |state: S| {
178                  let db_states = db_states.clone();
179                  async move {
180                      if state.operation_id() == operation_id {
181                          trace!(%operation_id, ?state, "Received state transition notification");
182                          // Deduplicate events that might have both come from the DB and streamed,
183                          // due to subscribing to notifier before querying the DB.
184                          //
185                          // Note: linear search should be good enough in practice for many reasons.
186                          // Eg. states tend to have all the states in the DB, or all streamed "live",
187                          // so the overlap here should be minimal.
188                          // And we'll rewrite the whole thing anyway and use only db as a reference.
189                          if db_states.iter().any(|db_s| db_s == &state) {
190                              debug!(%operation_id, ?state, "Ignoring duplicated event");
191                              return None;
192                          }
193                          Some(state)
194                      } else {
195                          None
196                      }
197                  }
198              }
199          });
200          Box::pin(futures::stream::iter(db_states).chain(new_transitions))
201      }
202  
203      /// Subscribe to all state transitions belonging to the module instance.
204      pub async fn subscribe_all_operations(&self) -> BoxStream<'static, S> {
205          let module_instance_id = self.module_instance;
206          Box::pin(
207              BroadcastStream::new(self.broadcast.subscribe())
208                  .take_while(|res| {
209                      let cont = if let Err(err) = res {
210                          error!(?err, "ModuleNotifier stream stopped on error");
211                          false
212                      } else {
213                          true
214                      };
215                      std::future::ready(cont)
216                  })
217                  .filter_map(move |res| async move {
218                      let s = res.expect("We filtered out errors above");
219                      if s.module_instance_id() == module_instance_id {
220                          Some(
221                              s.as_any()
222                                  .downcast_ref::<S>()
223                                  .expect("Tried to subscribe to wrong state type")
224                                  .clone(),
225                          )
226                      } else {
227                          None
228                      }
229                  }),
230          )
231      }
232  }