// tiered_multi.rs
1 use std::collections::btree_map::Entry; 2 use std::collections::BTreeMap; 3 use std::marker::PhantomData; 4 5 use fedimint_core::encoding::{Decodable, DecodeError, Encodable}; 6 use serde::{Deserialize, Serialize}; 7 8 use crate::module::registry::ModuleDecoderRegistry; 9 use crate::tiered::InvalidAmountTierError; 10 use crate::{Amount, Tiered}; 11 12 /// Represents notes of different denominations. 13 /// 14 /// **Attention:** care has to be taken when constructing this to avoid overflow 15 /// when calculating the total amount represented. As it is prudent to limit 16 /// both the maximum note amount and maximum note count per transaction this 17 /// shouldn't be a problem in practice though. 18 #[derive(Debug, Clone, Eq, PartialEq, Hash, Deserialize, Serialize)] 19 pub struct TieredMulti<T>(BTreeMap<Amount, Vec<T>>); 20 21 impl<T> TieredMulti<T> { 22 /// Returns a new `TieredMulti` with the given `BTreeMap` map 23 pub fn new(map: BTreeMap<Amount, Vec<T>>) -> Self { 24 TieredMulti(map.into_iter().filter(|(_, v)| !v.is_empty()).collect()) 25 } 26 27 /// Returns the total value of all notes in msat as `Amount` 28 pub fn total_amount(&self) -> Amount { 29 let milli_sat = self 30 .0 31 .iter() 32 .map(|(tier, notes)| tier.msats * (notes.len() as u64)) 33 .sum(); 34 Amount::from_msats(milli_sat) 35 } 36 37 /// Returns the number of items in all vectors 38 pub fn count_items(&self) -> usize { 39 self.0.values().map(|notes| notes.len()).sum() 40 } 41 42 /// Returns the number of tiers 43 pub fn count_tiers(&self) -> usize { 44 self.0.len() 45 } 46 47 /// Returns an iterator over the keys 48 pub fn iter_tiers(&self) -> impl Iterator<Item = &Amount> { 49 self.0.keys() 50 } 51 52 /// Returns the summary of number of items in each tier 53 pub fn summary(&self) -> TieredCounts { 54 TieredCounts(Tiered::from_iter( 55 self.iter().map(|(amount, values)| (*amount, values.len())), 56 )) 57 } 58 59 /// Verifies whether all vectors in all tiers are empty 60 pub fn is_empty(&self) 
-> bool { 61 self.assert_invariants(); 62 self.count_items() == 0 63 } 64 65 /// Verifies whether the structure of `self` and `other` is identical 66 pub fn structural_eq<O>(&self, other: &TieredMulti<O>) -> bool { 67 let tier_eq = self.0.keys().eq(other.0.keys()); 68 let per_tier_eq = self 69 .0 70 .values() 71 .zip(other.0.values()) 72 .all(|(c1, c2)| c1.len() == c2.len()); 73 tier_eq && per_tier_eq 74 } 75 76 /// Returns an borrowing iterator 77 pub fn iter(&self) -> impl Iterator<Item = (&Amount, &Vec<T>)> { 78 self.0.iter() 79 } 80 81 /// Returns an iterator over every `(Amount, &T)` 82 /// 83 /// Note: The order of the elements is important: 84 /// from the lowest tier to the highest, then in order of elements in the 85 /// Vec 86 pub fn iter_items(&self) -> impl DoubleEndedIterator<Item = (Amount, &T)> { 87 // Note: If you change the method implementation, make sure that the returned 88 // order of the elements stays consistent. 89 self.0 90 .iter() 91 .flat_map(|(amt, notes)| notes.iter().map(move |c| (*amt, c))) 92 } 93 94 /// Returns an consuming iterator over every `(Amount, T)` 95 /// 96 /// Note: The order of the elements is important: 97 /// from the lowest tier to the highest, then in order of elements in the 98 /// Vec 99 pub fn into_iter_items(self) -> impl DoubleEndedIterator<Item = (Amount, T)> { 100 // Note: If you change the method implementation, make sure that the returned 101 // order of the elements stays consistent. 
102 self.0 103 .into_iter() 104 .flat_map(|(amt, notes)| notes.into_iter().map(move |c| (amt, c))) 105 } 106 107 /// Returns the length of the longest vector of all tiers, ignoring the 108 /// `except` tier 109 pub fn longest_tier_except(&self, except: &Amount) -> usize { 110 self.0 111 .iter() 112 .filter_map(|(amt, notes)| { 113 if amt != except { 114 Some(notes.len()) 115 } else { 116 None 117 } 118 }) 119 .max() 120 .unwrap_or_default() 121 } 122 123 /// Verifies that all keys in `self` are present in the keys of the given 124 /// parameter `Tiered` 125 pub fn all_tiers_exist_in<K>(&self, keys: &Tiered<K>) -> Result<(), InvalidAmountTierError> { 126 match self.0.keys().find(|&amt| keys.get(*amt).is_none()) { 127 Some(amt) => Err(InvalidAmountTierError(*amt)), 128 None => Ok(()), 129 } 130 } 131 132 /// Returns an `Option` with a reference to the vector of the given `Amount` 133 pub fn get(&self, amt: Amount) -> Option<&Vec<T>> { 134 self.assert_invariants(); 135 self.0.get(&amt) 136 } 137 138 pub fn push(&mut self, amt: Amount, val: T) { 139 self.0.entry(amt).or_default().push(val) 140 } 141 142 fn assert_invariants(&self) { 143 // Just for compactness and determinism, we don't want entries with 0 items 144 #[cfg(debug_assertions)] 145 self.iter().for_each(|(_, v)| debug_assert!(!v.is_empty())) 146 } 147 } 148 149 impl<C> FromIterator<(Amount, C)> for TieredMulti<C> { 150 fn from_iter<T: IntoIterator<Item = (Amount, C)>>(iter: T) -> Self { 151 let mut res = TieredMulti::default(); 152 res.extend(iter); 153 res.assert_invariants(); 154 res 155 } 156 } 157 158 impl<C> IntoIterator for TieredMulti<C> 159 where 160 C: 'static + Send, 161 { 162 type Item = (Amount, C); 163 type IntoIter = Box<dyn Iterator<Item = (Amount, C)> + Send>; 164 165 fn into_iter(self) -> Self::IntoIter { 166 Box::new( 167 self.0 168 .into_iter() 169 .flat_map(|(amt, notes)| notes.into_iter().map(move |c| (amt, c))), 170 ) 171 } 172 } 173 174 impl<C> Default for TieredMulti<C> { 175 fn 
default() -> Self { 176 TieredMulti(BTreeMap::default()) 177 } 178 } 179 180 impl<C> Extend<(Amount, C)> for TieredMulti<C> { 181 fn extend<T: IntoIterator<Item = (Amount, C)>>(&mut self, iter: T) { 182 for (amount, note) in iter { 183 self.0.entry(amount).or_default().push(note) 184 } 185 } 186 } 187 188 impl<C> Encodable for TieredMulti<C> 189 where 190 C: Encodable + 'static, 191 { 192 fn consensus_encode<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, std::io::Error> { 193 self.0.consensus_encode(writer) 194 } 195 } 196 197 impl<C> Decodable for TieredMulti<C> 198 where 199 C: Decodable + 'static, 200 { 201 fn consensus_decode_from_finite_reader<D: std::io::Read>( 202 d: &mut D, 203 modules: &ModuleDecoderRegistry, 204 ) -> Result<Self, DecodeError> { 205 Ok(TieredMulti(BTreeMap::consensus_decode_from_finite_reader( 206 d, modules, 207 )?)) 208 } 209 } 210 211 pub struct TieredMultiZip<'a, I, T> 212 where 213 I: 'a, 214 { 215 iters: Vec<I>, 216 _pd: PhantomData<&'a T>, 217 } 218 219 impl<'a, I, C> TieredMultiZip<'a, I, C> { 220 /// Creates a new MultiZip Iterator from `Notes` iterators. These have to be 221 /// checked for structural equality! There also has to be at least one 222 /// iterator in the `iter` vector. 
223 pub fn new(iters: Vec<I>) -> Self { 224 assert!(!iters.is_empty()); 225 226 TieredMultiZip { 227 iters, 228 _pd: Default::default(), 229 } 230 } 231 } 232 233 impl<'a, I, C> Iterator for TieredMultiZip<'a, I, C> 234 where 235 I: Iterator<Item = (Amount, C)>, 236 { 237 type Item = (Amount, Vec<C>); 238 239 fn next(&mut self) -> Option<Self::Item> { 240 let mut notes = Vec::with_capacity(self.iters.len()); 241 let mut amount = None; 242 for iter in self.iters.iter_mut() { 243 match iter.next() { 244 Some((amt, note)) => { 245 if let Some(amount) = amount { 246 // This may fail if notes weren't tested for structural equality 247 assert_eq!(amount, amt); 248 } else { 249 amount = Some(amt); 250 } 251 notes.push(note); 252 } 253 None => return None, 254 } 255 } 256 257 // This should always hold as long as this impl is correct 258 assert_eq!(notes.len(), self.iters.len()); 259 260 Some(( 261 amount.expect("The multi zip must contain at least one iterator"), 262 notes, 263 )) 264 } 265 } 266 267 #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Clone)] 268 pub struct TieredCounts(Tiered<usize>); 269 270 impl TieredCounts { 271 pub fn inc(&mut self, tier: Amount, n: usize) { 272 if 0 < n { 273 *self.0.get_mut_or_default(tier) += n; 274 } 275 } 276 277 pub fn dec(&mut self, tier: Amount) { 278 match self.0.entry(tier) { 279 Entry::Vacant(_) => panic!("Trying to decrement an empty tier"), 280 Entry::Occupied(mut c) => { 281 assert!(*c.get() != 0); 282 if *c.get() == 1 { 283 c.remove_entry(); 284 } else { 285 *c.get_mut() -= 1; 286 } 287 } 288 } 289 self.assert_invariants(); 290 } 291 292 pub fn iter(&self) -> impl Iterator<Item = (Amount, usize)> + '_ { 293 self.0.iter().map(|(k, v)| (k, *v)) 294 } 295 296 pub fn total_amount(&self) -> Amount { 297 self.0.iter().map(|(k, v)| k * (*v as u64)).sum::<Amount>() 298 } 299 300 pub fn count_items(&self) -> usize { 301 self.0.iter().map(|(_, v)| *v).sum() 302 } 303 304 pub fn count_tiers(&self) -> usize { 305 
self.0.count_tiers() 306 } 307 308 pub fn is_empty(&self) -> bool { 309 self.count_items() == 0 310 } 311 312 pub fn get(&self, tier: Amount) -> usize { 313 self.assert_invariants(); 314 self.0.get(tier).copied().unwrap_or_default() 315 } 316 317 fn assert_invariants(&self) { 318 // Just for compactness and determinism, we don't want entries with 0 count 319 #[cfg(debug_assertions)] 320 self.iter().for_each(|(_, count)| debug_assert!(0 < count)) 321 } 322 } 323 324 impl FromIterator<(Amount, usize)> for TieredCounts { 325 fn from_iter<I: IntoIterator<Item = (Amount, usize)>>(iter: I) -> Self { 326 TieredCounts(iter.into_iter().filter(|(_, count)| *count != 0).collect()) 327 } 328 } 329 330 #[cfg(test)] 331 mod test { 332 333 use super::*; 334 335 #[test] 336 fn summary_works() { 337 let notes = TieredMulti::from_iter(vec![ 338 (Amount::from_sats(1), ()), 339 (Amount::from_sats(2), ()), 340 (Amount::from_sats(3), ()), 341 (Amount::from_sats(3), ()), 342 (Amount::from_sats(2), ()), 343 (Amount::from_sats(2), ()), 344 ]); 345 let summary = notes.summary(); 346 assert_eq!( 347 summary.iter().collect::<Vec<_>>(), 348 vec![ 349 (Amount::from_sats(1), 1), 350 (Amount::from_sats(2), 3), 351 (Amount::from_sats(3), 2), 352 ] 353 ); 354 assert_eq!(summary.total_amount(), notes.total_amount()); 355 assert_eq!(summary.count_items(), notes.count_items()); 356 assert_eq!(summary.count_tiers(), notes.count_tiers()); 357 } 358 }