/ node / bft / storage-service / src / memory.rs
memory.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the AlphaOS library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use crate::StorageService;
 17  use alphavm::{
 18      ledger::narwhal::{BatchHeader, Transmission, TransmissionID},
 19      prelude::{Field, Network, Result, bail},
 20  };
 21  
 22  use indexmap::{IndexMap, IndexSet, indexset, map::Entry};
 23  #[cfg(feature = "locktick")]
 24  use locktick::parking_lot::RwLock;
 25  #[cfg(not(feature = "locktick"))]
 26  use parking_lot::RwLock;
 27  use std::collections::{HashMap, HashSet};
 28  use tracing::error;
 29  
/// A BFT in-memory storage service.
///
/// The service keeps two independently locked maps, both keyed by transmission ID:
/// one for transmissions that are stored together with their payload, and one for
/// aborted transmission IDs, which are tracked without a payload. Every entry in
/// either map carries the set of certificate IDs that were inserted for it.
#[derive(Debug)]
pub struct BFTMemoryService<N: Network> {
    /// The map of `transmission ID` to `(transmission, certificate IDs)` entries.
    transmissions: RwLock<IndexMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)>>,
    /// The map of `aborted transmission ID` to `certificate IDs` entries.
    aborted_transmission_ids: RwLock<IndexMap<TransmissionID<N>, IndexSet<Field<N>>>>,
}
 38  
 39  impl<N: Network> Default for BFTMemoryService<N> {
 40      /// Initializes a new BFT in-memory storage service.
 41      fn default() -> Self {
 42          Self::new()
 43      }
 44  }
 45  
 46  impl<N: Network> BFTMemoryService<N> {
 47      /// Initializes a new BFT in-memory storage service.
 48      pub fn new() -> Self {
 49          Self { transmissions: Default::default(), aborted_transmission_ids: Default::default() }
 50      }
 51  }
 52  
 53  impl<N: Network> StorageService<N> for BFTMemoryService<N> {
 54      /// Returns `true` if the storage contains the specified `transmission ID`.
 55      fn contains_transmission(&self, transmission_id: TransmissionID<N>) -> bool {
 56          // Check if the transmission ID exists in storage.
 57          self.transmissions.read().contains_key(&transmission_id)
 58              || self.aborted_transmission_ids.read().contains_key(&transmission_id)
 59      }
 60  
 61      /// Returns the transmission for the given `transmission ID`.
 62      /// If the transmission does not exist in storage, `None` is returned.
 63      fn get_transmission(&self, transmission_id: TransmissionID<N>) -> Option<Transmission<N>> {
 64          // Get the transmission.
 65          self.transmissions.read().get(&transmission_id).map(|(transmission, _)| transmission).cloned()
 66      }
 67  
 68      /// Returns the missing transmissions in storage from the given transmissions.
 69      fn find_missing_transmissions(
 70          &self,
 71          batch_header: &BatchHeader<N>,
 72          mut transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
 73          aborted_transmissions: HashSet<TransmissionID<N>>,
 74      ) -> Result<HashMap<TransmissionID<N>, Transmission<N>>> {
 75          // Initialize a list for the missing transmissions from storage.
 76          let mut missing_transmissions = HashMap::new();
 77          // Lock the existing transmissions.
 78          let known_transmissions = self.transmissions.read();
 79          // Ensure the declared transmission IDs are all present in storage or the given transmissions map.
 80          for transmission_id in batch_header.transmission_ids() {
 81              // If the transmission ID does not exist, ensure it was provided by the caller or aborted.
 82              if !known_transmissions.contains_key(transmission_id) {
 83                  // Retrieve the transmission.
 84                  match transmissions.remove(transmission_id) {
 85                      // Append the transmission if it exists.
 86                      Some(transmission) => {
 87                          missing_transmissions.insert(*transmission_id, transmission);
 88                      }
 89                      // If the transmission does not exist, check if it was aborted.
 90                      None => {
 91                          if !aborted_transmissions.contains(transmission_id) {
 92                              bail!("Failed to provide a transmission");
 93                          }
 94                      }
 95                  }
 96              }
 97          }
 98          Ok(missing_transmissions)
 99      }
100  
101      /// Inserts the given certificate ID for each of the transmission IDs, using the missing transmissions map, into storage.
102      fn insert_transmissions(
103          &self,
104          certificate_id: Field<N>,
105          transmission_ids: IndexSet<TransmissionID<N>>,
106          aborted_transmission_ids: HashSet<TransmissionID<N>>,
107          mut missing_transmissions: HashMap<TransmissionID<N>, Transmission<N>>,
108      ) {
109          // Acquire the transmissions write lock.
110          let mut transmissions = self.transmissions.write();
111          // Acquire the aborted transmission IDs write lock.
112          let mut aborted_transmission_ids_lock = self.aborted_transmission_ids.write();
113          // Inserts the following:
114          //   - Inserts **only the missing** transmissions from storage.
115          //   - Inserts the certificate ID into the corresponding set for **all** transmissions.
116          'outer: for transmission_id in transmission_ids {
117              // Retrieve the transmission entry.
118              match transmissions.entry(transmission_id) {
119                  Entry::Occupied(mut occupied_entry) => {
120                      let (_, certificate_ids) = occupied_entry.get_mut();
121                      // Insert the certificate ID into the set.
122                      certificate_ids.insert(certificate_id);
123                  }
124                  Entry::Vacant(vacant_entry) => {
125                      // Retrieve the missing transmission.
126                      let Some(transmission) = missing_transmissions.remove(&transmission_id) else {
127                          if !aborted_transmission_ids.contains(&transmission_id)
128                              && !self.contains_transmission(transmission_id)
129                          {
130                              error!("Failed to provide a missing transmission {transmission_id}");
131                          }
132                          continue 'outer;
133                      };
134                      // Prepare the set of certificate IDs.
135                      let certificate_ids = indexset! { certificate_id };
136                      // Insert the transmission and a new set with the certificate ID.
137                      vacant_entry.insert((transmission, certificate_ids));
138                  }
139              }
140          }
141          // Inserts the aborted transmission IDs.
142          for aborted_transmission_id in aborted_transmission_ids {
143              // Retrieve the transmission entry.
144              match aborted_transmission_ids_lock.entry(aborted_transmission_id) {
145                  Entry::Occupied(mut occupied_entry) => {
146                      let certificate_ids = occupied_entry.get_mut();
147                      // Insert the certificate ID into the set.
148                      certificate_ids.insert(certificate_id);
149                  }
150                  Entry::Vacant(vacant_entry) => {
151                      // Prepare the set of certificate IDs.
152                      let certificate_ids = indexset! { certificate_id };
153                      // Insert the transmission and a new set with the certificate ID.
154                      vacant_entry.insert(certificate_ids);
155                  }
156              }
157          }
158      }
159  
160      /// Removes the certificate ID for the transmissions from storage.
161      ///
162      /// If the transmission no longer references any certificate IDs, the entry is removed from storage.
163      fn remove_transmissions(&self, certificate_id: &Field<N>, transmission_ids: &IndexSet<TransmissionID<N>>) {
164          // Acquire the transmissions write lock.
165          let mut transmissions = self.transmissions.write();
166          // Acquire the aborted transmission IDs write lock.
167          let mut aborted_transmission_ids = self.aborted_transmission_ids.write();
168          // If this is the last certificate ID for the transmission ID, remove the transmission.
169          for transmission_id in transmission_ids {
170              // Remove the certificate ID for the transmission ID, and determine if there are any more certificate IDs.
171              match transmissions.entry(*transmission_id) {
172                  Entry::Occupied(mut occupied_entry) => {
173                      let (_, certificate_ids) = occupied_entry.get_mut();
174                      // Remove the certificate ID for the transmission ID.
175                      certificate_ids.swap_remove(certificate_id);
176                      // If there are no more certificate IDs for the transmission ID, remove the transmission.
177                      if certificate_ids.is_empty() {
178                          // Remove the entry for the transmission ID.
179                          occupied_entry.shift_remove();
180                      }
181                  }
182                  Entry::Vacant(_) => {}
183              }
184              // Remove the certificate ID for the aborted transmission ID, and determine if there are any more certificate IDs.
185              match aborted_transmission_ids.entry(*transmission_id) {
186                  Entry::Occupied(mut occupied_entry) => {
187                      let certificate_ids = occupied_entry.get_mut();
188                      // Remove the certificate ID for the transmission ID.
189                      certificate_ids.swap_remove(certificate_id);
190                      // If there are no more certificate IDs for the transmission ID, remove the transmission.
191                      if certificate_ids.is_empty() {
192                          // Remove the entry for the transmission ID.
193                          occupied_entry.shift_remove();
194                      }
195                  }
196                  Entry::Vacant(_) => {}
197              }
198          }
199      }
200  
201      /// Returns a HashMap over the `(transmission ID, (transmission, certificate IDs))` entries.
202      #[cfg(any(test, feature = "test"))]
203      fn as_hashmap(&self) -> HashMap<TransmissionID<N>, (Transmission<N>, IndexSet<Field<N>>)> {
204          self.transmissions.read().clone().into_iter().collect()
205      }
206  }