/ fedimint-core / src / tiered.rs
tiered.rs
  1  use std::collections::BTreeMap;
  2  
  3  use fedimint_core::Amount;
  4  use serde::{Deserialize, Serialize};
  5  
  6  use crate::encoding::{Decodable, DecodeError, Encodable};
  7  use crate::module::registry::ModuleDecoderRegistry;
  8  
/// Error signalling that an [`Amount`] does not correspond to any known tier.
///
/// The tuple field carries the amount that was looked up.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize)]
pub struct InvalidAmountTierError(pub Amount);
 11  
 12  impl std::fmt::Display for InvalidAmountTierError {
 13      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 14          write!(f, "Amount tier unknown to mint: {}", self.0)
 15      }
 16  }
 17  
/// A collection of values of type `T`, one per [`Amount`] tier, stored in a
/// [`BTreeMap`] and therefore ordered by amount.
///
/// `#[serde(transparent)]` makes it (de)serialize exactly like the inner map.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)]
#[serde(transparent)]
pub struct Tiered<T>(BTreeMap<Amount, T>);
 21  
 22  impl<T> Default for Tiered<T> {
 23      fn default() -> Self {
 24          Self(Default::default())
 25      }
 26  }
 27  
 28  impl<T> Tiered<T> {
 29      /// Returns the highest tier amount
 30      pub fn max_tier(&self) -> &Amount {
 31          self.0.keys().max().expect("has tiers")
 32      }
 33  
 34      pub fn structural_eq<O>(&self, other: &Tiered<O>) -> bool {
 35          self.0.keys().eq(other.0.keys())
 36      }
 37  
 38      /// Returns a reference to the key of the specified tier
 39      pub fn tier(&self, amount: &Amount) -> Result<&T, InvalidAmountTierError> {
 40          self.0.get(amount).ok_or(InvalidAmountTierError(*amount))
 41      }
 42  
 43      pub fn count_tiers(&self) -> usize {
 44          self.0.len()
 45      }
 46  
 47      pub fn tiers(&self) -> impl DoubleEndedIterator<Item = &Amount> {
 48          self.0.keys()
 49      }
 50  
 51      pub fn iter(&self) -> impl Iterator<Item = (Amount, &T)> {
 52          self.0.iter().map(|(amt, key)| (*amt, key))
 53      }
 54  
 55      pub fn get(&self, amt: Amount) -> Option<&T> {
 56          self.0.get(&amt)
 57      }
 58  
 59      pub fn get_mut(&mut self, amt: Amount) -> Option<&mut T> {
 60          self.0.get_mut(&amt)
 61      }
 62  
 63      pub fn insert(&mut self, amt: Amount, v: T) -> Option<T> {
 64          self.0.insert(amt, v)
 65      }
 66  
 67      pub fn get_mut_or_default(&mut self, amt: Amount) -> &mut T
 68      where
 69          T: Default,
 70      {
 71          self.0.entry(amt).or_default()
 72      }
 73  
 74      pub fn entry(&mut self, amt: Amount) -> std::collections::btree_map::Entry<'_, Amount, T>
 75      where
 76          T: Default,
 77      {
 78          self.0.entry(amt)
 79      }
 80  
 81      pub fn as_map(&self) -> &BTreeMap<Amount, T> {
 82          &self.0
 83      }
 84  }
 85  
 86  impl Tiered<()> {
 87      /// Generates denominations of a given base up to and including `max`
 88      pub fn gen_denominations(denomination_base: u16, max: Amount) -> Tiered<()> {
 89          let mut amounts = vec![];
 90  
 91          let mut denomination = Amount::from_msats(1);
 92          while denomination <= max {
 93              amounts.push((denomination, ()));
 94              denomination = denomination * denomination_base.into();
 95          }
 96  
 97          amounts.into_iter().collect()
 98      }
 99  }
100  
101  impl<T> FromIterator<(Amount, T)> for Tiered<T> {
102      fn from_iter<I: IntoIterator<Item = (Amount, T)>>(iter: I) -> Self {
103          Tiered(iter.into_iter().collect())
104      }
105  }
106  
107  impl<C> Encodable for Tiered<C>
108  where
109      C: Encodable,
110  {
111      fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> {
112          self.0.consensus_encode(writer)
113      }
114  }
115  
116  impl<C> Decodable for Tiered<C>
117  where
118      C: Decodable,
119  {
120      fn consensus_decode<D: std::io::Read>(
121          d: &mut D,
122          modules: &ModuleDecoderRegistry,
123      ) -> Result<Self, DecodeError> {
124          Ok(Tiered(BTreeMap::consensus_decode(d, modules)?))
125      }
126  }
127  
#[cfg(test)]
mod tests {
    use fedimint_core::Amount;

    use super::Tiered;

    /// Converts raw msat values into the `Amount`s the tiers are keyed by.
    fn amounts(msats: &[u64]) -> Vec<Amount> {
        msats.iter().copied().map(Amount::from_msats).collect()
    }

    /// Collects the generated tier amounts in ascending order.
    fn tier_amounts(denominations: &Tiered<()>) -> Vec<Amount> {
        denominations.tiers().copied().collect()
    }

    #[test]
    fn tier_generation_including_max_amount() {
        let max_amount = Amount::from_msats(16);
        let denominations = Tiered::gen_denominations(2, max_amount);

        // `max_amount` is itself a power of the base and must be included
        assert_eq!(tier_amounts(&denominations), amounts(&[1, 2, 4, 8, 16]));
    }

    #[test]
    fn tier_generation_excluding_amounts_above_max() {
        let max_amount = Amount::from_msats(15);
        let denominations = Tiered::gen_denominations(2, max_amount);

        // 16 exceeds `max_amount` and must not be generated
        assert_eq!(tier_amounts(&denominations), amounts(&[1, 2, 4, 8]));
    }

    #[test]
    fn tier_generation_base_10() {
        let max_amount = Amount::from_msats(10000);
        let denominations = Tiered::gen_denominations(10, max_amount);

        assert_eq!(
            tier_amounts(&denominations),
            amounts(&[1, 10, 100, 1000, 10_000])
        );
    }
}