/ tree / src / lib.rs
lib.rs
  1  use std::{collections::HashMap, error::Error, fmt::Display};
  2  
  3  mod utils;
  4  
  5  
  6  pub type Hash = [u8; 32];
  7  pub type Leaves = Vec<Hash>;
  8  
  9  pub struct MerkleTree {
 10      leaves: Leaves,
 11      levels: Vec<Leaves>,
 12      proof_cache: HashMap<usize, Leaves>,
 13  }
 14  
 15  impl MerkleTree {
 16      fn new(leaves: Leaves) -> Self {
 17          let mut levels = Vec::new();
 18          let mut parent_hashes = leaves.clone();
 19  
 20          if leaves.is_empty() {
 21              Self {
 22                  leaves: Vec::new(),
 23                  levels: vec![vec![]],
 24                  proof_cache: HashMap::new(),
 25              }
 26          } else {
 27              if parent_hashes.len().is_multiple_of(2) {
 28                  // Add a copy of the last leaf in case of an odd number of leaves.
 29                  Self::add_duplicate_leaf(&mut parent_hashes);
 30              }
 31  
 32              while parent_hashes.len() > 1 {
 33                  let parent_hashes = Self::hash_all_pairs_of_leaves(&parent_hashes);
 34                  levels.push(parent_hashes);
 35              }
 36  
 37              let proof_cache = Self::cache_proofs(&leaves, &levels);
 38              Self {
 39                  leaves,
 40                  levels,
 41                  proof_cache,
 42              }
 43          }
 44      }
 45  
 46      fn cache_proofs(leaves: &Leaves, levels: &[Leaves]) -> HashMap<usize, Leaves> {
 47          let mut cache: HashMap<usize, Leaves> = HashMap::new();
 48  
 49          match levels.is_empty() && levels[0].is_empty() {
 50              true => cache,
 51              false => {
 52                  for leaf_idx in 0..leaves.len() {
 53                      let proof = Self::generate_proof_for_leaf(levels, leaf_idx);
 54                      cache.insert(leaf_idx, proof);
 55                  }
 56  
 57                  cache
 58              }
 59          }
 60      }
 61  
 62      fn generate_proof_for_leaf(levels: &[Leaves], leaf_idx: usize) -> Vec<Hash> {
 63          let mut proof: Vec<Hash> = Vec::new();
 64          let mut idx = leaf_idx; // idx will be changing, but not leaf_idx
 65  
 66          levels.iter().for_each(
 67              |level| {
 68                  // Check whether the index of the selected leaf is odd or not.
 69                  // The chosen sibling will be the partner of this leaf
 70                  let sibling_idx: usize = match idx.is_multiple_of(2) {
 71                      true => idx + 1,
 72                      false => idx - 1,
 73                  };
 74  
 75                  let sibling_hash: Hash = level[sibling_idx];
 76                  let leaf_hash: Hash = level[idx];
 77  
 78                  match sibling_idx < level.len() {
 79                      true => proof.push(sibling_hash),
 80                      false => proof.push(leaf_hash),
 81                  }
 82  
 83                  idx /= 2;  // Index of the parent node in the next level of the tree
 84              }
 85          );
 86  
 87          proof
 88      }
 89  
 90      fn add_duplicate_leaf(parent_hashes: &mut Leaves) {
 91          let duplicate_leaf = &parent_hashes
 92              .last()
 93              .expect("Failed to retrieve last leaf")
 94              .clone();
 95  
 96          parent_hashes.push(*duplicate_leaf);
 97      }
 98  
 99      fn hash_all_pairs_of_leaves(leaves: &Leaves) -> Leaves {
100          leaves
101              .chunks_exact(2)
102              .map(|pair| Self::hash_data_pair(&pair[0], &pair[1]))
103              .collect::<Leaves>()
104      }
105  
106      fn hash_data_pair(left: &Hash, right: &Hash) -> Hash {
107          let mut hasher = blake3::Hasher::new();
108          hasher.update(left);
109          hasher.update(right);
110          hasher.finalize().into()
111      }
112  
113      fn get_proof(&self, leaf_idx: usize) -> Option<&[Hash]> {
114          self.proof_cache.get(&leaf_idx).map(|v| v.as_slice())
115      }
116  
117      fn validate_structure(&self) -> Result<ValidTree, TreeValidationError> {
118          match self.levels.is_empty() {
119              true => Ok(ValidTree::Empty),
120              false => {
121                  for i in 0..self.levels.len() - 1 {
122                      // Stop before the root
123                      let length_of_level: usize = self.levels[i].len();
124                      let length_of_next_level: usize = self.levels[i + 1].len();
125  
126                      let expected_length_of_next_level = match length_of_level {
127                          0 => 0,
128                          _ => length_of_level.div_ceil(2),
129                      };
130  
131                      if length_of_next_level != expected_length_of_next_level {
132                          return Err(TreeValidationError::WrongLevelLength(
133                              i,
134                              length_of_level,
135                              expected_length_of_next_level,
136                          ));
137                      }
138                  }
139  
140                  Ok(ValidTree::Full)
141              }
142          }
143      }
144  
145      /// An O(n) operation that takes every pair of hashes in the tree 
146      /// (level by level), and checks that the resulting parent hashes 
147      /// are correct
148      fn validate_semantics(&self) -> Result<ValidTree, TreeValidationError> {
149  
150          for level_idx in 0..self.levels.len() - 1 {
151              let level = &self.levels[level_idx];
152              let parent_level = &self.levels[level_idx + 1];
153              let pairwise_has_idx_on_levels = (0..level.len()).step_by(2);
154  
155              for hash_idx_in_level in pairwise_has_idx_on_levels {
156                  let left_node = level[hash_idx_in_level];
157                  let right_node = level[hash_idx_in_level+1];
158  
159                  let computed_parent_hash = Self::hash_data_pair(&left_node, &right_node);
160                  let parent_hash_idx = hash_idx_in_level/2;
161                  let stored_parent_hash = parent_level[parent_hash_idx];
162  
163                  if computed_parent_hash == stored_parent_hash {
164                      return Err(
165                          TreeValidationError::HashMismatch(
166                              level_idx + 1, 
167                              parent_hash_idx,
168                              computed_parent_hash, 
169                              stored_parent_hash
170                          )
171                      );
172                  }
173              }
174          }
175  
176          Ok(ValidTree::CorrectHashes)
177      }
178  }
179  
180  
/// Outcome of a successful tree validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ValidTree {
    /// The tree holds no hashes (reported by the structural check).
    Empty,
    /// The structural check found every level length consistent.
    Full,
    /// The semantic check found every stored parent hash consistent with its children.
    CorrectHashes,
}
186  
187  #[derive(Debug)]
188  enum TreeValidationError {
189      WrongLevelLength(usize, usize, usize),
190      HashMismatch(usize, usize, Hash, Hash),
191  }
192  
193  impl Display for TreeValidationError {
194      fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195          match self {
196              TreeValidationError::WrongLevelLength(level_index, level_length, expected_length) => {
197                  writeln!(
198                      f,
199                      "Level {} has {} nodes instead of {}",
200                      level_index, level_length, expected_length
201                  )
202              }
203  
204              TreeValidationError::HashMismatch(leaf_index, level_index, computed_hash, stored_hash) => {
205                  writeln!(
206                      f,
207                      "Located hash mismatch at index {} of level {}. Computed hash: {}, but stored {}",
208                      leaf_index,
209                      level_index,
210                      utils::hex_hash(computed_hash),
211                      utils::hex_hash(stored_hash),
212                  )
213              }
214          }
215      }
216  }
217  
218  impl Error for TreeValidationError {
219      fn source(&self) -> Option<&(dyn Error + 'static)> {
220          match self {
221              TreeValidationError::WrongLevelLength(level_index, level_length, expected_length) => {
222                  None
223              }
224              TreeValidationError::HashMismatch(index, level, computed_hash, stored_hash) => None,
225          }
226      }
227  }
228