/ fedimint-core / src / tiered_multi.rs
tiered_multi.rs
  1  use std::collections::btree_map::Entry;
  2  use std::collections::BTreeMap;
  3  use std::marker::PhantomData;
  4  
  5  use fedimint_core::encoding::{Decodable, DecodeError, Encodable};
  6  use serde::{Deserialize, Serialize};
  7  
  8  use crate::module::registry::ModuleDecoderRegistry;
  9  use crate::tiered::InvalidAmountTierError;
 10  use crate::{Amount, Tiered};
 11  
 12  /// Represents notes of different denominations.
 13  ///
 14  /// **Attention:** care has to be taken when constructing this to avoid overflow
 15  /// when calculating the total amount represented. As it is prudent to limit
 16  /// both the maximum note amount and maximum note count per transaction this
 17  /// shouldn't be a problem in practice though.
 18  #[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)]
 19  pub struct TieredMulti<T>(BTreeMap<Amount, Vec<T>>);
 20  
 21  impl<T> TieredMulti<T> {
 22      /// Returns a new `TieredMulti` with the given `BTreeMap` map
 23      pub fn new(map: BTreeMap<Amount, Vec<T>>) -> Self {
 24          TieredMulti(map.into_iter().filter(|(_, v)| !v.is_empty()).collect())
 25      }
 26  
 27      /// Returns the total value of all notes in msat as `Amount`
 28      pub fn total_amount(&self) -> Amount {
 29          let milli_sat = self
 30              .0
 31              .iter()
 32              .map(|(tier, notes)| tier.msats * (notes.len() as u64))
 33              .sum();
 34          Amount::from_msats(milli_sat)
 35      }
 36  
 37      /// Returns the number of items in all vectors
 38      pub fn count_items(&self) -> usize {
 39          self.0.values().map(|notes| notes.len()).sum()
 40      }
 41  
 42      /// Returns the number of tiers
 43      pub fn count_tiers(&self) -> usize {
 44          self.0.len()
 45      }
 46  
 47      /// Returns an iterator over the keys
 48      pub fn iter_tiers(&self) -> impl Iterator<Item = &Amount> {
 49          self.0.keys()
 50      }
 51  
 52      /// Returns the summary of number of items in each tier
 53      pub fn summary(&self) -> TieredCounts {
 54          TieredCounts(Tiered::from_iter(
 55              self.iter().map(|(amount, values)| (*amount, values.len())),
 56          ))
 57      }
 58  
 59      /// Verifies whether all vectors in all tiers are empty
 60      pub fn is_empty(&self) -> bool {
 61          self.assert_invariants();
 62          self.count_items() == 0
 63      }
 64  
 65      /// Verifies whether the structure of `self` and `other` is identical
 66      pub fn structural_eq<O>(&self, other: &TieredMulti<O>) -> bool {
 67          let tier_eq = self.0.keys().eq(other.0.keys());
 68          let per_tier_eq = self
 69              .0
 70              .values()
 71              .zip(other.0.values())
 72              .all(|(c1, c2)| c1.len() == c2.len());
 73          tier_eq && per_tier_eq
 74      }
 75  
 76      /// Returns an borrowing iterator
 77      pub fn iter(&self) -> impl Iterator<Item = (&Amount, &Vec<T>)> {
 78          self.0.iter()
 79      }
 80  
 81      /// Returns an iterator over every `(Amount, &T)`
 82      ///
 83      /// Note: The order of the elements is important:
 84      /// from the lowest tier to the highest, then in order of elements in the
 85      /// Vec
 86      pub fn iter_items(&self) -> impl DoubleEndedIterator<Item = (Amount, &T)> {
 87          // Note: If you change the method implementation, make sure that the returned
 88          // order of the elements stays consistent.
 89          self.0
 90              .iter()
 91              .flat_map(|(amt, notes)| notes.iter().map(move |c| (*amt, c)))
 92      }
 93  
 94      /// Returns an consuming iterator over every `(Amount, T)`
 95      ///
 96      /// Note: The order of the elements is important:
 97      /// from the lowest tier to the highest, then in order of elements in the
 98      /// Vec
 99      pub fn into_iter_items(self) -> impl DoubleEndedIterator<Item = (Amount, T)> {
100          // Note: If you change the method implementation, make sure that the returned
101          // order of the elements stays consistent.
102          self.0
103              .into_iter()
104              .flat_map(|(amt, notes)| notes.into_iter().map(move |c| (amt, c)))
105      }
106  
107      /// Returns the length of the longest vector of all tiers, ignoring the
108      /// `except` tier
109      pub fn longest_tier_except(&self, except: &Amount) -> usize {
110          self.0
111              .iter()
112              .filter_map(|(amt, notes)| {
113                  if amt != except {
114                      Some(notes.len())
115                  } else {
116                      None
117                  }
118              })
119              .max()
120              .unwrap_or_default()
121      }
122  
123      /// Verifies that all keys in `self` are present in the keys of the given
124      /// parameter `Tiered`
125      pub fn all_tiers_exist_in<K>(&self, keys: &Tiered<K>) -> Result<(), InvalidAmountTierError> {
126          match self.0.keys().find(|&amt| keys.get(*amt).is_none()) {
127              Some(amt) => Err(InvalidAmountTierError(*amt)),
128              None => Ok(()),
129          }
130      }
131  
132      /// Returns an `Option` with a reference to the vector of the given `Amount`
133      pub fn get(&self, amt: Amount) -> Option<&Vec<T>> {
134          self.assert_invariants();
135          self.0.get(&amt)
136      }
137  
138      pub fn push(&mut self, amt: Amount, val: T) {
139          self.0.entry(amt).or_default().push(val)
140      }
141  
142      fn assert_invariants(&self) {
143          // Just for compactness and determinism, we don't want entries with 0 items
144          #[cfg(debug_assertions)]
145          self.iter().for_each(|(_, v)| debug_assert!(!v.is_empty()))
146      }
147  }
148  
149  impl<C> FromIterator<(Amount, C)> for TieredMulti<C> {
150      fn from_iter<T: IntoIterator<Item = (Amount, C)>>(iter: T) -> Self {
151          let mut res = TieredMulti::default();
152          res.extend(iter);
153          res.assert_invariants();
154          res
155      }
156  }
157  
158  impl<C> IntoIterator for TieredMulti<C>
159  where
160      C: 'static + Send,
161  {
162      type Item = (Amount, C);
163      type IntoIter = Box<dyn Iterator<Item = (Amount, C)> + Send>;
164  
165      fn into_iter(self) -> Self::IntoIter {
166          Box::new(
167              self.0
168                  .into_iter()
169                  .flat_map(|(amt, notes)| notes.into_iter().map(move |c| (amt, c))),
170          )
171      }
172  }
173  
174  impl<C> Default for TieredMulti<C> {
175      fn default() -> Self {
176          TieredMulti(BTreeMap::default())
177      }
178  }
179  
180  impl<C> Extend<(Amount, C)> for TieredMulti<C> {
181      fn extend<T: IntoIterator<Item = (Amount, C)>>(&mut self, iter: T) {
182          for (amount, note) in iter {
183              self.0.entry(amount).or_default().push(note)
184          }
185      }
186  }
187  
188  impl<C> Encodable for TieredMulti<C>
189  where
190      C: Encodable + 'static,
191  {
192      fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
193          self.0.consensus_encode(writer)
194      }
195  }
196  
197  impl<C> Decodable for TieredMulti<C>
198  where
199      C: Decodable + 'static,
200  {
201      fn consensus_decode_from_finite_reader<D: std::io::Read>(
202          d: &mut D,
203          modules: &ModuleDecoderRegistry,
204      ) -> Result<Self, DecodeError> {
205          Ok(TieredMulti(BTreeMap::consensus_decode_from_finite_reader(
206              d, modules,
207          )?))
208      }
209  }
210  
211  pub struct TieredMultiZip<'a, I, T>
212  where
213      I: 'a,
214  {
215      iters: Vec<I>,
216      _pd: PhantomData<&'a T>,
217  }
218  
219  impl<'a, I, C> TieredMultiZip<'a, I, C> {
220      /// Creates a new MultiZip Iterator from `Notes` iterators. These have to be
221      /// checked for structural equality! There also has to be at least one
222      /// iterator in the `iter` vector.
223      pub fn new(iters: Vec<I>) -> Self {
224          assert!(!iters.is_empty());
225  
226          TieredMultiZip {
227              iters,
228              _pd: Default::default(),
229          }
230      }
231  }
232  
233  impl<'a, I, C> Iterator for TieredMultiZip<'a, I, C>
234  where
235      I: Iterator<Item = (Amount, C)>,
236  {
237      type Item = (Amount, Vec<C>);
238  
239      fn next(&mut self) -> Option<Self::Item> {
240          let mut notes = Vec::with_capacity(self.iters.len());
241          let mut amount = None;
242          for iter in self.iters.iter_mut() {
243              match iter.next() {
244                  Some((amt, note)) => {
245                      if let Some(amount) = amount {
246                          // This may fail if notes weren't tested for structural equality
247                          assert_eq!(amount, amt);
248                      } else {
249                          amount = Some(amt);
250                      }
251                      notes.push(note);
252                  }
253                  None => return None,
254              }
255          }
256  
257          // This should always hold as long as this impl is correct
258          assert_eq!(notes.len(), self.iters.len());
259  
260          Some((
261              amount.expect("The multi zip must contain at least one iterator"),
262              notes,
263          ))
264      }
265  }
266  
267  #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)]
268  pub struct TieredCounts(Tiered<usize>);
269  
270  impl TieredCounts {
271      pub fn inc(&mut self, tier: Amount, n: usize) {
272          if 0 < n {
273              *self.0.get_mut_or_default(tier) += n;
274          }
275      }
276  
277      pub fn dec(&mut self, tier: Amount) {
278          match self.0.entry(tier) {
279              Entry::Vacant(_) => panic!("Trying to decrement an empty tier"),
280              Entry::Occupied(mut c) => {
281                  assert!(*c.get() != 0);
282                  if *c.get() == 1 {
283                      c.remove_entry();
284                  } else {
285                      *c.get_mut() -= 1;
286                  }
287              }
288          }
289          self.assert_invariants();
290      }
291  
292      pub fn iter(&self) -> impl Iterator<Item = (Amount, usize)> + '_ {
293          self.0.iter().map(|(k, v)| (k, *v))
294      }
295  
296      pub fn total_amount(&self) -> Amount {
297          self.0.iter().map(|(k, v)| k * (*v as u64)).sum::<Amount>()
298      }
299  
300      pub fn count_items(&self) -> usize {
301          self.0.iter().map(|(_, v)| *v).sum()
302      }
303  
304      pub fn count_tiers(&self) -> usize {
305          self.0.count_tiers()
306      }
307  
308      pub fn is_empty(&self) -> bool {
309          self.count_items() == 0
310      }
311  
312      pub fn get(&self, tier: Amount) -> usize {
313          self.assert_invariants();
314          self.0.get(tier).copied().unwrap_or_default()
315      }
316  
317      fn assert_invariants(&self) {
318          // Just for compactness and determinism, we don't want entries with 0 count
319          #[cfg(debug_assertions)]
320          self.iter().for_each(|(_, count)| debug_assert!(0 < count))
321      }
322  }
323  
324  impl FromIterator<(Amount, usize)> for TieredCounts {
325      fn from_iter<I: IntoIterator<Item = (Amount, usize)>>(iter: I) -> Self {
326          TieredCounts(iter.into_iter().filter(|(_, count)| *count != 0).collect())
327      }
328  }
329  
#[cfg(test)]
mod test {

    use super::*;

    #[test]
    fn summary_works() {
        // Tier 1 sat once, tier 2 sats three times, tier 3 sats twice.
        let notes: TieredMulti<()> = vec![1, 2, 3, 3, 2, 2]
            .into_iter()
            .map(|sats| (Amount::from_sats(sats), ()))
            .collect();
        let summary = notes.summary();

        let expected: Vec<(Amount, usize)> = vec![
            (Amount::from_sats(1), 1),
            (Amount::from_sats(2), 3),
            (Amount::from_sats(3), 2),
        ];
        assert_eq!(summary.iter().collect::<Vec<_>>(), expected);

        // The summary must agree with the multi on every aggregate.
        assert_eq!(summary.total_amount(), notes.total_amount());
        assert_eq!(summary.count_items(), notes.count_items());
        assert_eq!(summary.count_tiers(), notes.count_tiers());
    }
}