/ node / src / message_storage.rs
message_storage.rs
  1  //! Store messages in mdbx
  2  
  3  use std::{marker::PhantomData, sync::Arc};
  4  
  5  use apibara_core::stream::{MessageData, RawMessageData, Sequence};
  6  use libmdbx::{Environment, EnvironmentKind, Error as MdbxError, Transaction, RO, RW};
  7  
  8  use crate::db::{tables, MdbxRWTransactionExt, MdbxTransactionExt, TableCursor};
  9  
 10  pub trait MessageStorage<M: MessageData> {
 11      type Error: std::error::Error + Send + Sync + 'static;
 12  
 13      /// Retrieves a message with the given sequencer number, if any.
 14      fn get(
 15          &self,
 16          sequence: &Sequence,
 17      ) -> std::result::Result<Option<RawMessageData<M>>, Self::Error>;
 18  }
 19  
 20  /// Store messages in mdbx.
 21  pub struct MdbxMessageStorage<E: EnvironmentKind, M: MessageData> {
 22      db: Arc<Environment<E>>,
 23      phantom: PhantomData<M>,
 24  }
 25  
 26  /// [MessageStorage]-related error.
 27  #[derive(Debug, thiserror::Error)]
 28  pub enum MdbxMessageStorageError {
 29      #[error("message has the wrong sequence number")]
 30      InvalidMessageSequence { expected: u64, actual: u64 },
 31      #[error("error originating from database")]
 32      Database(#[from] MdbxError),
 33  }
 34  
 35  pub type Result<T> = std::result::Result<T, MdbxMessageStorageError>;
 36  
 37  pub struct MessageIterator<'txn, E: EnvironmentKind, M: MessageData> {
 38      _txn: Transaction<'txn, RO, E>,
 39      current: Option<Result<M>>,
 40      cursor: TableCursor<'txn, tables::MessageTable<M>, RO>,
 41  }
 42  
 43  impl<E, M> MdbxMessageStorage<E, M>
 44  where
 45      E: EnvironmentKind,
 46      M: MessageData,
 47  {
 48      /// Create a new message store, persisting data to the given mdbx environment.
 49      pub fn new(db: Arc<Environment<E>>) -> Result<Self> {
 50          let txn = db.begin_rw_txn()?;
 51          txn.ensure_table::<tables::MessageTable<M>>(None)?;
 52          txn.commit()?;
 53          Ok(MdbxMessageStorage {
 54              db,
 55              phantom: PhantomData,
 56          })
 57      }
 58  
 59      /// Insert the given `message` in the store.
 60      ///
 61      /// Expect `sequence` to be the successor of the current highest sequence number.
 62      pub fn insert(&self, sequence: &Sequence, message: &M) -> Result<()> {
 63          let txn = self.db.begin_rw_txn()?;
 64          self.insert_with_txn(sequence, message, &txn)?;
 65          txn.commit()?;
 66          Ok(())
 67      }
 68  
 69      /// Same as `insert` but using the given [Transaction].
 70      pub fn insert_with_txn(
 71          &self,
 72          sequence: &Sequence,
 73          message: &M,
 74          txn: &Transaction<RW, E>,
 75      ) -> Result<()> {
 76          let table = txn.open_table::<tables::MessageTable<M>>()?;
 77          let mut cursor = table.cursor()?;
 78  
 79          match cursor.last()? {
 80              None => {
 81                  // First element, assert sequence is 0
 82                  if sequence.as_u64() != 0 {
 83                      return Err(MdbxMessageStorageError::InvalidMessageSequence {
 84                          expected: 0,
 85                          actual: sequence.as_u64(),
 86                      });
 87                  }
 88                  cursor.put(sequence, message)?;
 89                  Ok(())
 90              }
 91              Some((prev_sequence, _)) => {
 92                  if sequence.as_u64() != prev_sequence.as_u64() + 1 {
 93                      return Err(MdbxMessageStorageError::InvalidMessageSequence {
 94                          expected: prev_sequence.as_u64() + 1,
 95                          actual: sequence.as_u64(),
 96                      });
 97                  }
 98                  cursor.put(sequence, message)?;
 99                  Ok(())
100              }
101          }
102      }
103  
104      /// Delete all messages with sequence number greater than or equal the given `sequence`.
105      ///
106      /// Returns the number of messages deleted.
107      pub fn invalidate(&self, sequence: &Sequence) -> Result<usize> {
108          let txn = self.db.begin_rw_txn()?;
109          let invalidated = self.invalidate_with_txn(sequence, &txn)?;
110          txn.commit()?;
111          Ok(invalidated)
112      }
113  
114      /// Same as `invalidate` but using the given [Transaction].
115      pub fn invalidate_with_txn(
116          &self,
117          sequence: &Sequence,
118          txn: &Transaction<RW, E>,
119      ) -> Result<usize> {
120          let table = txn.open_table::<tables::MessageTable<M>>()?;
121          let mut cursor = table.cursor()?;
122  
123          let mut count = 0;
124          loop {
125              match cursor.last()? {
126                  None => break,
127                  Some((key, _)) => {
128                      if key.as_u64() < sequence.as_u64() {
129                          break;
130                      }
131                      cursor.del()?;
132                      count += 1;
133                  }
134              }
135          }
136          Ok(count)
137      }
138  
139      /// Returns an iterator over all messages, starting at the given `start` index.
140      pub fn iter_from(&self, start: &Sequence) -> Result<MessageIterator<'_, E, M>> {
141          let txn = self.db.begin_ro_txn()?;
142          let table = txn.open_table::<tables::MessageTable<M>>()?;
143          let mut cursor = table.cursor()?;
144          let current = cursor.seek_exact(start)?.map(|v| Ok(v.1));
145          Ok(MessageIterator {
146              cursor,
147              _txn: txn,
148              current,
149          })
150      }
151  }
152  
153  impl<E, M> MessageStorage<M> for MdbxMessageStorage<E, M>
154  where
155      E: EnvironmentKind,
156      M: MessageData,
157  {
158      type Error = MdbxMessageStorageError;
159  
160      fn get(&self, sequence: &Sequence) -> Result<Option<RawMessageData<M>>> {
161          let txn = self.db.begin_rw_txn()?;
162          let table = txn.open_table::<tables::MessageTable<M>>()?;
163          let mut cursor = table.cursor()?;
164          let data = cursor.seek_exact_raw(sequence)?.map(|t| t.1);
165          Ok(data)
166      }
167  }
168  
169  impl<E, M> MessageStorage<M> for Arc<MdbxMessageStorage<E, M>>
170  where
171      E: EnvironmentKind,
172      M: MessageData,
173  {
174      type Error = MdbxMessageStorageError;
175  
176      fn get(&self, sequence: &Sequence) -> Result<Option<RawMessageData<M>>> {
177          let txn = self.db.begin_rw_txn()?;
178          let table = txn.open_table::<tables::MessageTable<M>>()?;
179          let mut cursor = table.cursor()?;
180          let data = cursor.seek_exact_raw(sequence)?.map(|t| t.1);
181          Ok(data)
182      }
183  }
184  
185  impl<'txn, E, M> Iterator for MessageIterator<'txn, E, M>
186  where
187      E: EnvironmentKind,
188      M: MessageData,
189  {
190      type Item = Result<M>;
191  
192      fn next(&mut self) -> Option<Self::Item> {
193          match self.current.take() {
194              None => None,
195              Some(value) => {
196                  self.current = match self.cursor.next() {
197                      Err(err) => Some(Err(err.into())),
198                      Ok(None) => None,
199                      Ok(Some(value)) => Some(Ok(value.1)),
200                  };
201                  Some(value)
202              }
203          }
204      }
205  }
206  
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use apibara_core::stream::Sequence;
    use libmdbx::{Environment, NoWriteMap};
    use tempfile::tempdir;

    use crate::db::MdbxEnvironmentExt;

    use super::{MdbxMessageStorage, MdbxMessageStorageError, MessageStorage};

    #[derive(Clone, PartialEq, prost::Message)]
    pub struct Transfer {
        #[prost(string, tag = "1")]
        pub sender: String,
        #[prost(string, tag = "2")]
        pub receiver: String,
    }

    /// Builds a [`Transfer`] from borrowed sender/receiver names.
    fn transfer(sender: &str, receiver: &str) -> Transfer {
        Transfer {
            sender: sender.to_string(),
            receiver: receiver.to_string(),
        }
    }

    /// Collects every message from `start` onwards, panicking on any storage error.
    fn messages_from(
        storage: &MdbxMessageStorage<NoWriteMap, Transfer>,
        start: u64,
    ) -> Vec<Transfer> {
        storage
            .iter_from(&Sequence::from_u64(start))
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    #[test]
    pub fn test_message_storage() {
        let path = tempdir().unwrap();
        let db = Environment::<NoWriteMap>::open(path.path()).unwrap();
        let storage = MdbxMessageStorage::<_, Transfer>::new(Arc::new(db)).unwrap();

        // first message must have index 0, and the error reports why
        let t0 = transfer("ABC", "XYZ");
        let err = storage.insert(&Sequence::from_u64(1), &t0).unwrap_err();
        assert!(matches!(
            err,
            MdbxMessageStorageError::InvalidMessageSequence {
                expected: 0,
                actual: 1
            }
        ));
        storage.insert(&Sequence::from_u64(0), &t0).unwrap();

        // next message must have index 1
        let t1 = transfer("AOE", "TNS");
        assert!(storage.insert(&Sequence::from_u64(0), &t1).is_err());
        let err = storage.insert(&Sequence::from_u64(2), &t1).unwrap_err();
        assert!(matches!(
            err,
            MdbxMessageStorageError::InvalidMessageSequence {
                expected: 1,
                actual: 2
            }
        ));
        storage.insert(&Sequence::from_u64(1), &t1).unwrap();

        assert_eq!(messages_from(&storage, 0), vec![t0.clone(), t1.clone()]);

        // point lookups only see stored sequence numbers
        assert!(storage.get(&Sequence::from_u64(1)).unwrap().is_some());
        assert!(storage.get(&Sequence::from_u64(2)).unwrap().is_none());

        // invalidate latest message; second time is a noop
        assert_eq!(storage.invalidate(&Sequence::from_u64(1)).unwrap(), 1);
        assert_eq!(storage.invalidate(&Sequence::from_u64(1)).unwrap(), 0);
        assert_eq!(messages_from(&storage, 0), vec![t0]);

        // insert value again
        assert!(storage.insert(&Sequence::from_u64(0), &t1).is_err());
        assert!(storage.insert(&Sequence::from_u64(2), &t1).is_err());
        storage.insert(&Sequence::from_u64(1), &t1).unwrap();

        assert_eq!(messages_from(&storage, 1), vec![t1]);
    }
}