// lib.rs
1 use std::{collections::HashMap, error::Error, fmt::Display}; 2 3 mod utils; 4 5 6 pub type Hash = [u8; 32]; 7 pub type Leaves = Vec<Hash>; 8 9 pub struct MerkleTree { 10 leaves: Leaves, 11 levels: Vec<Leaves>, 12 proof_cache: HashMap<usize, Leaves>, 13 } 14 15 impl MerkleTree { 16 fn new(leaves: Leaves) -> Self { 17 let mut levels = Vec::new(); 18 let mut parent_hashes = leaves.clone(); 19 20 if leaves.is_empty() { 21 Self { 22 leaves: Vec::new(), 23 levels: vec![vec![]], 24 proof_cache: HashMap::new(), 25 } 26 } else { 27 if parent_hashes.len().is_multiple_of(2) { 28 // Add a copy of the last leaf in case of an odd number of leaves. 29 Self::add_duplicate_leaf(&mut parent_hashes); 30 } 31 32 while parent_hashes.len() > 1 { 33 let parent_hashes = Self::hash_all_pairs_of_leaves(&parent_hashes); 34 levels.push(parent_hashes); 35 } 36 37 let proof_cache = Self::cache_proofs(&leaves, &levels); 38 Self { 39 leaves, 40 levels, 41 proof_cache, 42 } 43 } 44 } 45 46 fn cache_proofs(leaves: &Leaves, levels: &[Leaves]) -> HashMap<usize, Leaves> { 47 let mut cache: HashMap<usize, Leaves> = HashMap::new(); 48 49 match levels.is_empty() && levels[0].is_empty() { 50 true => cache, 51 false => { 52 for leaf_idx in 0..leaves.len() { 53 let proof = Self::generate_proof_for_leaf(levels, leaf_idx); 54 cache.insert(leaf_idx, proof); 55 } 56 57 cache 58 } 59 } 60 } 61 62 fn generate_proof_for_leaf(levels: &[Leaves], leaf_idx: usize) -> Vec<Hash> { 63 let mut proof: Vec<Hash> = Vec::new(); 64 let mut idx = leaf_idx; // idx will be changing, but not leaf_idx 65 66 levels.iter().for_each( 67 |level| { 68 // Check whether the index of the selected leaf is odd or not. 
69 // The chosen sibling will be the partner of this leaf 70 let sibling_idx: usize = match idx.is_multiple_of(2) { 71 true => idx + 1, 72 false => idx - 1, 73 }; 74 75 let sibling_hash: Hash = level[sibling_idx]; 76 let leaf_hash: Hash = level[idx]; 77 78 match sibling_idx < level.len() { 79 true => proof.push(sibling_hash), 80 false => proof.push(leaf_hash), 81 } 82 83 idx /= 2; // Index of the parent node in the next level of the tree 84 } 85 ); 86 87 proof 88 } 89 90 fn add_duplicate_leaf(parent_hashes: &mut Leaves) { 91 let duplicate_leaf = &parent_hashes 92 .last() 93 .expect("Failed to retrieve last leaf") 94 .clone(); 95 96 parent_hashes.push(*duplicate_leaf); 97 } 98 99 fn hash_all_pairs_of_leaves(leaves: &Leaves) -> Leaves { 100 leaves 101 .chunks_exact(2) 102 .map(|pair| Self::hash_data_pair(&pair[0], &pair[1])) 103 .collect::<Leaves>() 104 } 105 106 fn hash_data_pair(left: &Hash, right: &Hash) -> Hash { 107 let mut hasher = blake3::Hasher::new(); 108 hasher.update(left); 109 hasher.update(right); 110 hasher.finalize().into() 111 } 112 113 fn get_proof(&self, leaf_idx: usize) -> Option<&[Hash]> { 114 self.proof_cache.get(&leaf_idx).map(|v| v.as_slice()) 115 } 116 117 fn validate_structure(&self) -> Result<ValidTree, TreeValidationError> { 118 match self.levels.is_empty() { 119 true => Ok(ValidTree::Empty), 120 false => { 121 for i in 0..self.levels.len() - 1 { 122 // Stop before the root 123 let length_of_level: usize = self.levels[i].len(); 124 let length_of_next_level: usize = self.levels[i + 1].len(); 125 126 let expected_length_of_next_level = match length_of_level { 127 0 => 0, 128 _ => length_of_level.div_ceil(2), 129 }; 130 131 if length_of_next_level != expected_length_of_next_level { 132 return Err(TreeValidationError::WrongLevelLength( 133 i, 134 length_of_level, 135 expected_length_of_next_level, 136 )); 137 } 138 } 139 140 Ok(ValidTree::Full) 141 } 142 } 143 } 144 145 /// An O(n) operation that takes every pair of hashes in the tree 146 /// 
(level by level), and checks that the resulting parent hashes 147 /// are correct 148 fn validate_semantics(&self) -> Result<ValidTree, TreeValidationError> { 149 150 for level_idx in 0..self.levels.len() - 1 { 151 let level = &self.levels[level_idx]; 152 let parent_level = &self.levels[level_idx + 1]; 153 let pairwise_has_idx_on_levels = (0..level.len()).step_by(2); 154 155 for hash_idx_in_level in pairwise_has_idx_on_levels { 156 let left_node = level[hash_idx_in_level]; 157 let right_node = level[hash_idx_in_level+1]; 158 159 let computed_parent_hash = Self::hash_data_pair(&left_node, &right_node); 160 let parent_hash_idx = hash_idx_in_level/2; 161 let stored_parent_hash = parent_level[parent_hash_idx]; 162 163 if computed_parent_hash == stored_parent_hash { 164 return Err( 165 TreeValidationError::HashMismatch( 166 level_idx + 1, 167 parent_hash_idx, 168 computed_parent_hash, 169 stored_parent_hash 170 ) 171 ); 172 } 173 } 174 } 175 176 Ok(ValidTree::CorrectHashes) 177 } 178 } 179 180 181 enum ValidTree { 182 Empty, 183 Full, 184 CorrectHashes, 185 } 186 187 #[derive(Debug)] 188 enum TreeValidationError { 189 WrongLevelLength(usize, usize, usize), 190 HashMismatch(usize, usize, Hash, Hash), 191 } 192 193 impl Display for TreeValidationError { 194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 195 match self { 196 TreeValidationError::WrongLevelLength(level_index, level_length, expected_length) => { 197 writeln!( 198 f, 199 "Level {} has {} nodes instead of {}", 200 level_index, level_length, expected_length 201 ) 202 } 203 204 TreeValidationError::HashMismatch(leaf_index, level_index, computed_hash, stored_hash) => { 205 writeln!( 206 f, 207 "Located hash mismatch at index {} of level {}. 
Computed hash: {}, but stored {}", 208 leaf_index, 209 level_index, 210 utils::hex_hash(computed_hash), 211 utils::hex_hash(stored_hash), 212 ) 213 } 214 } 215 } 216 } 217 218 impl Error for TreeValidationError { 219 fn source(&self) -> Option<&(dyn Error + 'static)> { 220 match self { 221 TreeValidationError::WrongLevelLength(level_index, level_length, expected_length) => { 222 None 223 } 224 TreeValidationError::HashMismatch(index, level, computed_hash, stored_hash) => None, 225 } 226 } 227 } 228