// message_storage.rs
1 //! Store messages in mdbx 2 3 use std::{marker::PhantomData, sync::Arc}; 4 5 use apibara_core::stream::{MessageData, RawMessageData, Sequence}; 6 use libmdbx::{Environment, EnvironmentKind, Error as MdbxError, Transaction, RO, RW}; 7 8 use crate::db::{tables, MdbxRWTransactionExt, MdbxTransactionExt, TableCursor}; 9 10 pub trait MessageStorage<M: MessageData> { 11 type Error: std::error::Error + Send + Sync + 'static; 12 13 /// Retrieves a message with the given sequencer number, if any. 14 fn get( 15 &self, 16 sequence: &Sequence, 17 ) -> std::result::Result<Option<RawMessageData<M>>, Self::Error>; 18 } 19 20 /// Store messages in mdbx. 21 pub struct MdbxMessageStorage<E: EnvironmentKind, M: MessageData> { 22 db: Arc<Environment<E>>, 23 phantom: PhantomData<M>, 24 } 25 26 /// [MessageStorage]-related error. 27 #[derive(Debug, thiserror::Error)] 28 pub enum MdbxMessageStorageError { 29 #[error("message has the wrong sequence number")] 30 InvalidMessageSequence { expected: u64, actual: u64 }, 31 #[error("error originating from database")] 32 Database(#[from] MdbxError), 33 } 34 35 pub type Result<T> = std::result::Result<T, MdbxMessageStorageError>; 36 37 pub struct MessageIterator<'txn, E: EnvironmentKind, M: MessageData> { 38 _txn: Transaction<'txn, RO, E>, 39 current: Option<Result<M>>, 40 cursor: TableCursor<'txn, tables::MessageTable<M>, RO>, 41 } 42 43 impl<E, M> MdbxMessageStorage<E, M> 44 where 45 E: EnvironmentKind, 46 M: MessageData, 47 { 48 /// Create a new message store, persisting data to the given mdbx environment. 49 pub fn new(db: Arc<Environment<E>>) -> Result<Self> { 50 let txn = db.begin_rw_txn()?; 51 txn.ensure_table::<tables::MessageTable<M>>(None)?; 52 txn.commit()?; 53 Ok(MdbxMessageStorage { 54 db, 55 phantom: PhantomData, 56 }) 57 } 58 59 /// Insert the given `message` in the store. 60 /// 61 /// Expect `sequence` to be the successor of the current highest sequence number. 
62 pub fn insert(&self, sequence: &Sequence, message: &M) -> Result<()> { 63 let txn = self.db.begin_rw_txn()?; 64 self.insert_with_txn(sequence, message, &txn)?; 65 txn.commit()?; 66 Ok(()) 67 } 68 69 /// Same as `insert` but using the given [Transaction]. 70 pub fn insert_with_txn( 71 &self, 72 sequence: &Sequence, 73 message: &M, 74 txn: &Transaction<RW, E>, 75 ) -> Result<()> { 76 let table = txn.open_table::<tables::MessageTable<M>>()?; 77 let mut cursor = table.cursor()?; 78 79 match cursor.last()? { 80 None => { 81 // First element, assert sequence is 0 82 if sequence.as_u64() != 0 { 83 return Err(MdbxMessageStorageError::InvalidMessageSequence { 84 expected: 0, 85 actual: sequence.as_u64(), 86 }); 87 } 88 cursor.put(sequence, message)?; 89 Ok(()) 90 } 91 Some((prev_sequence, _)) => { 92 if sequence.as_u64() != prev_sequence.as_u64() + 1 { 93 return Err(MdbxMessageStorageError::InvalidMessageSequence { 94 expected: prev_sequence.as_u64() + 1, 95 actual: sequence.as_u64(), 96 }); 97 } 98 cursor.put(sequence, message)?; 99 Ok(()) 100 } 101 } 102 } 103 104 /// Delete all messages with sequence number greater than or equal the given `sequence`. 105 /// 106 /// Returns the number of messages deleted. 107 pub fn invalidate(&self, sequence: &Sequence) -> Result<usize> { 108 let txn = self.db.begin_rw_txn()?; 109 let invalidated = self.invalidate_with_txn(sequence, &txn)?; 110 txn.commit()?; 111 Ok(invalidated) 112 } 113 114 /// Same as `invalidate` but using the given [Transaction]. 115 pub fn invalidate_with_txn( 116 &self, 117 sequence: &Sequence, 118 txn: &Transaction<RW, E>, 119 ) -> Result<usize> { 120 let table = txn.open_table::<tables::MessageTable<M>>()?; 121 let mut cursor = table.cursor()?; 122 123 let mut count = 0; 124 loop { 125 match cursor.last()? 
{ 126 None => break, 127 Some((key, _)) => { 128 if key.as_u64() < sequence.as_u64() { 129 break; 130 } 131 cursor.del()?; 132 count += 1; 133 } 134 } 135 } 136 Ok(count) 137 } 138 139 /// Returns an iterator over all messages, starting at the given `start` index. 140 pub fn iter_from(&self, start: &Sequence) -> Result<MessageIterator<'_, E, M>> { 141 let txn = self.db.begin_ro_txn()?; 142 let table = txn.open_table::<tables::MessageTable<M>>()?; 143 let mut cursor = table.cursor()?; 144 let current = cursor.seek_exact(start)?.map(|v| Ok(v.1)); 145 Ok(MessageIterator { 146 cursor, 147 _txn: txn, 148 current, 149 }) 150 } 151 } 152 153 impl<E, M> MessageStorage<M> for MdbxMessageStorage<E, M> 154 where 155 E: EnvironmentKind, 156 M: MessageData, 157 { 158 type Error = MdbxMessageStorageError; 159 160 fn get(&self, sequence: &Sequence) -> Result<Option<RawMessageData<M>>> { 161 let txn = self.db.begin_rw_txn()?; 162 let table = txn.open_table::<tables::MessageTable<M>>()?; 163 let mut cursor = table.cursor()?; 164 let data = cursor.seek_exact_raw(sequence)?.map(|t| t.1); 165 Ok(data) 166 } 167 } 168 169 impl<E, M> MessageStorage<M> for Arc<MdbxMessageStorage<E, M>> 170 where 171 E: EnvironmentKind, 172 M: MessageData, 173 { 174 type Error = MdbxMessageStorageError; 175 176 fn get(&self, sequence: &Sequence) -> Result<Option<RawMessageData<M>>> { 177 let txn = self.db.begin_rw_txn()?; 178 let table = txn.open_table::<tables::MessageTable<M>>()?; 179 let mut cursor = table.cursor()?; 180 let data = cursor.seek_exact_raw(sequence)?.map(|t| t.1); 181 Ok(data) 182 } 183 } 184 185 impl<'txn, E, M> Iterator for MessageIterator<'txn, E, M> 186 where 187 E: EnvironmentKind, 188 M: MessageData, 189 { 190 type Item = Result<M>; 191 192 fn next(&mut self) -> Option<Self::Item> { 193 match self.current.take() { 194 None => None, 195 Some(value) => { 196 self.current = match self.cursor.next() { 197 Err(err) => Some(Err(err.into())), 198 Ok(None) => None, 199 Ok(Some(value)) => 
Some(Ok(value.1)), 200 }; 201 Some(value) 202 } 203 } 204 } 205 } 206 207 #[cfg(test)] 208 mod tests { 209 use std::sync::Arc; 210 211 use apibara_core::stream::Sequence; 212 use libmdbx::{Environment, NoWriteMap}; 213 use tempfile::tempdir; 214 215 use crate::db::MdbxEnvironmentExt; 216 217 use super::MdbxMessageStorage; 218 219 #[derive(Clone, PartialEq, prost::Message)] 220 pub struct Transfer { 221 #[prost(string, tag = "1")] 222 pub sender: String, 223 #[prost(string, tag = "2")] 224 pub receiver: String, 225 } 226 227 #[test] 228 pub fn test_message_storage() { 229 let path = tempdir().unwrap(); 230 let db = Environment::<NoWriteMap>::open(path.path()).unwrap(); 231 let storage = MdbxMessageStorage::<_, Transfer>::new(Arc::new(db)).unwrap(); 232 233 // first message must have index 0 234 let t0_bad_sequence = Transfer { 235 sender: "ABC".to_string(), 236 receiver: "XYZ".to_string(), 237 }; 238 assert!(storage 239 .insert(&Sequence::from_u64(1), &t0_bad_sequence) 240 .is_err()); 241 242 let t0 = Transfer { 243 sender: "ABC".to_string(), 244 receiver: "XYZ".to_string(), 245 }; 246 storage.insert(&Sequence::from_u64(0), &t0).unwrap(); 247 248 // next message must have index 1 249 let t1 = Transfer { 250 sender: "AOE".to_string(), 251 receiver: "TNS".to_string(), 252 }; 253 assert!(storage.insert(&Sequence::from_u64(0), &t1).is_err()); 254 assert!(storage.insert(&Sequence::from_u64(2), &t1).is_err()); 255 storage.insert(&Sequence::from_u64(1), &t1).unwrap(); 256 257 let all_messages = storage 258 .iter_from(&Sequence::from_u64(0)) 259 .unwrap() 260 .collect::<Result<Vec<_>, _>>() 261 .unwrap(); 262 263 assert!(all_messages.len() == 2); 264 assert!(all_messages[0] == t0); 265 assert!(all_messages[1] == t1); 266 267 // invalidate latest message 268 let count = storage.invalidate(&Sequence::from_u64(1)).unwrap(); 269 assert!(count == 1); 270 // second time is a noop 271 let count = storage.invalidate(&Sequence::from_u64(1)).unwrap(); 272 assert!(count == 0); 273 274 
let all_messages = storage 275 .iter_from(&Sequence::from_u64(0)) 276 .unwrap() 277 .collect::<Result<Vec<_>, _>>() 278 .unwrap(); 279 assert!(all_messages.len() == 1); 280 assert!(all_messages[0] == t0); 281 282 // insert value again 283 assert!(storage.insert(&Sequence::from_u64(0), &t1).is_err()); 284 assert!(storage.insert(&Sequence::from_u64(2), &t1).is_err()); 285 storage.insert(&Sequence::from_u64(1), &t1).unwrap(); 286 287 let all_messages = storage 288 .iter_from(&Sequence::from_u64(1)) 289 .unwrap() 290 .collect::<Result<Vec<_>, _>>() 291 .unwrap(); 292 assert!(all_messages.len() == 1); 293 assert!(all_messages[0] == t1); 294 } 295 }