tiered.rs
1 use std::collections::BTreeMap; 2 3 use fedimint_core::Amount; 4 use serde::{Deserialize, Serialize}; 5 6 use crate::encoding::{Decodable, DecodeError, Encodable}; 7 use crate::module::registry::ModuleDecoderRegistry; 8 9 #[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash, Deserialize, Serialize)] 10 pub struct InvalidAmountTierError(pub Amount); 11 12 impl std::fmt::Display for InvalidAmountTierError { 13 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 14 write!(f, "Amount tier unknown to mint: {}", self.0) 15 } 16 } 17 18 #[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)] 19 #[serde(transparent)] 20 pub struct Tiered<T>(BTreeMap<Amount, T>); 21 22 impl<T> Default for Tiered<T> { 23 fn default() -> Self { 24 Self(Default::default()) 25 } 26 } 27 28 impl<T> Tiered<T> { 29 /// Returns the highest tier amount 30 pub fn max_tier(&self) -> &Amount { 31 self.0.keys().max().expect("has tiers") 32 } 33 34 pub fn structural_eq<O>(&self, other: &Tiered<O>) -> bool { 35 self.0.keys().eq(other.0.keys()) 36 } 37 38 /// Returns a reference to the key of the specified tier 39 pub fn tier(&self, amount: &Amount) -> Result<&T, InvalidAmountTierError> { 40 self.0.get(amount).ok_or(InvalidAmountTierError(*amount)) 41 } 42 43 pub fn count_tiers(&self) -> usize { 44 self.0.len() 45 } 46 47 pub fn tiers(&self) -> impl DoubleEndedIterator<Item = &Amount> { 48 self.0.keys() 49 } 50 51 pub fn iter(&self) -> impl Iterator<Item = (Amount, &T)> { 52 self.0.iter().map(|(amt, key)| (*amt, key)) 53 } 54 55 pub fn get(&self, amt: Amount) -> Option<&T> { 56 self.0.get(&amt) 57 } 58 59 pub fn get_mut(&mut self, amt: Amount) -> Option<&mut T> { 60 self.0.get_mut(&amt) 61 } 62 63 pub fn insert(&mut self, amt: Amount, v: T) -> Option<T> { 64 self.0.insert(amt, v) 65 } 66 67 pub fn get_mut_or_default(&mut self, amt: Amount) -> &mut T 68 where 69 T: Default, 70 { 71 self.0.entry(amt).or_default() 72 } 73 74 pub fn entry(&mut self, amt: Amount) -> std::collections::btree_map::Entry<'_, Amount, T> 75 where 76 T: Default, 77 { 78 self.0.entry(amt) 79 } 80 81 pub fn as_map(&self) -> &BTreeMap<Amount, T> { 82 &self.0 83 } 84 } 85 86 impl Tiered<()> { 87 /// Generates denominations of a given base up to and including `max` 88 pub fn gen_denominations(denomination_base: u16, max: Amount) -> Tiered<()> { 89 let mut amounts = vec![]; 90 91 let mut denomination = Amount::from_msats(1); 92 while denomination <= max { 93 amounts.push((denomination, ())); 94 denomination = denomination * denomination_base.into(); 95 } 96 97 amounts.into_iter().collect() 98 } 99 } 100 101 impl<T> FromIterator<(Amount, T)> for Tiered<T> { 102 fn from_iter<I: IntoIterator<Item = (Amount, T)>>(iter: I) -> Self { 103 Tiered(iter.into_iter().collect()) 104 } 105 } 106 107 impl<C> Encodable for Tiered<C> 108 where 109 C: Encodable, 110 { 111 fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> { 112 self.0.consensus_encode(writer) 113 } 114 } 115 116 impl<C> Decodable for Tiered<C> 117 where 118 C: Decodable, 119 { 120 fn consensus_decode<D: std::io::Read>( 121 d: &mut D, 122 modules: &ModuleDecoderRegistry, 123 ) -> Result<Self, DecodeError> { 124 Ok(Tiered(BTreeMap::consensus_decode(d, modules)?)) 125 } 126 } 127 128 #[cfg(test)] 129 mod tests { 130 use fedimint_core::Amount; 131 132 use super::Tiered; 133 134 #[test] 135 fn tier_generation_including_max_amount() { 136 let max_amount = Amount::from_msats(16); 137 let denominations = Tiered::gen_denominations(2, max_amount); 138 139 // should produce [1, 2, 4, 8, 16] 140 assert_eq!(denominations.tiers().collect::<Vec<&Amount>>().len(), 5); 141 } 142 143 #[test] 144 fn tier_generation_base_10() { 145 let max_amount = Amount::from_msats(10000); 146 let denominations = Tiered::gen_denominations(10, max_amount); 147 148 // should produce [1, 10, 100, 1000, 10_000] 149 assert_eq!(denominations.tiers().collect::<Vec<&Amount>>().len(), 5); 150 } 151 }