/ circuit / algorithms / src / poseidon / hash_many.rs
hash_many.rs
  1  // Copyright (c) 2019-2025 Alpha-Delta Network Inc.
  2  // This file is part of the alphavm library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use super::*;
 17  
 18  impl<E: Environment, const RATE: usize> HashMany for Poseidon<E, RATE> {
 19      type Input = Field<E>;
 20      type Output = Field<E>;
 21  
 22      #[inline]
 23      fn hash_many(&self, input: &[Self::Input], num_outputs: u16) -> Vec<Self::Output> {
 24          // Construct the preimage: [ DOMAIN || LENGTH(INPUT) || [0; RATE-2] || INPUT ].
 25          let mut preimage = Vec::with_capacity(RATE + input.len());
 26          preimage.push(self.domain.clone());
 27          preimage.push(Field::constant(console::Field::from_u128(input.len() as u128)));
 28          preimage.resize(RATE, Field::zero()); // Pad up to RATE.
 29          preimage.extend_from_slice(input);
 30  
 31          // Initialize a new sponge.
 32          let mut state = vec![Field::zero(); RATE + CAPACITY];
 33          let mut mode = DuplexSpongeMode::Absorbing { next_absorb_index: 0 };
 34  
 35          // Absorb the input and squeeze the output.
 36          self.absorb(&mut state, &mut mode, &preimage);
 37          self.squeeze(&mut state, &mut mode, num_outputs)
 38      }
 39  }
 40  
 41  #[allow(clippy::needless_borrow)]
 42  impl<E: Environment, const RATE: usize> Poseidon<E, RATE> {
 43      /// Absorbs the input elements into state.
 44      #[inline]
 45      fn absorb(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, input: &[Field<E>]) {
 46          if !input.is_empty() {
 47              // Determine the absorb index.
 48              let (mut absorb_index, should_permute) = match *mode {
 49                  DuplexSpongeMode::Absorbing { next_absorb_index } => match next_absorb_index == RATE {
 50                      true => (0, true),
 51                      false => (next_absorb_index, false),
 52                  },
 53                  DuplexSpongeMode::Squeezing { .. } => (0, true),
 54              };
 55  
 56              // Proceed to permute the state, if necessary.
 57              if should_permute {
 58                  self.permute(state);
 59              }
 60  
 61              let mut remaining = input;
 62              loop {
 63                  // Compute the starting index.
 64                  let start = CAPACITY + absorb_index;
 65  
 66                  // Check if we can exit the loop.
 67                  if absorb_index + remaining.len() <= RATE {
 68                      // Absorb the state elements into the input.
 69                      remaining.iter().enumerate().for_each(|(i, element)| state[start + i] += element);
 70                      // Update the sponge mode.
 71                      *mode = DuplexSpongeMode::Absorbing { next_absorb_index: absorb_index + remaining.len() };
 72                      return;
 73                  }
 74  
 75                  // Otherwise, proceed to absorb `(rate - absorb_index)` elements.
 76                  let num_absorbed = RATE - absorb_index;
 77                  remaining.iter().enumerate().take(num_absorbed).for_each(|(i, element)| state[start + i] += element);
 78  
 79                  // Permute the state.
 80                  self.permute(state);
 81  
 82                  // Repeat with the updated input slice and absorb index.
 83                  remaining = &remaining[num_absorbed..];
 84                  absorb_index = 0;
 85              }
 86          }
 87      }
 88  
 89      /// Squeeze the specified number of state elements into the output.
 90      #[inline]
 91      fn squeeze(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, num_outputs: u16) -> Vec<Field<E>> {
 92          let mut output = vec![Field::zero(); num_outputs as usize];
 93          if num_outputs != 0 {
 94              self.squeeze_internal(state, mode, &mut output);
 95          }
 96          output
 97      }
 98  
 99      /// Squeeze the state elements into the output.
100      #[inline]
101      fn squeeze_internal(&self, state: &mut [Field<E>], mode: &mut DuplexSpongeMode, output: &mut [Field<E>]) {
102          // Determine the squeeze index.
103          let (mut squeeze_index, should_permute) = match *mode {
104              DuplexSpongeMode::Absorbing { .. } => (0, true),
105              DuplexSpongeMode::Squeezing { next_squeeze_index } => match next_squeeze_index == RATE {
106                  true => (0, true),
107                  false => (next_squeeze_index, false),
108              },
109          };
110  
111          // Proceed to permute the state, if necessary.
112          if should_permute {
113              self.permute(state);
114          }
115  
116          let mut remaining = output;
117          loop {
118              // Compute the starting index.
119              let start = CAPACITY + squeeze_index;
120  
121              // Check if we can exit the loop.
122              if squeeze_index + remaining.len() <= RATE {
123                  // Store the state elements into the output.
124                  remaining.clone_from_slice(&state[start..(start + remaining.len())]);
125                  // Update the sponge mode.
126                  *mode = DuplexSpongeMode::Squeezing { next_squeeze_index: squeeze_index + remaining.len() };
127                  return;
128              }
129  
130              // Otherwise, proceed to squeeze `(rate - squeeze_index)` elements.
131              let num_squeezed = RATE - squeeze_index;
132              remaining[..num_squeezed].clone_from_slice(&state[start..(start + num_squeezed)]);
133  
134              // Permute.
135              self.permute(state);
136  
137              // Repeat with the updated output slice and squeeze index.
138              remaining = &mut remaining[num_squeezed..];
139              squeeze_index = 0;
140          }
141      }
142  
143      /// Apply the additive round keys in-place.
144      #[inline]
145      fn apply_ark(&self, state: &mut [Field<E>], round: usize) {
146          for (i, element) in state.iter_mut().enumerate() {
147              *element += &self.ark[round][i];
148          }
149      }
150  
151      /// Apply the S-Box based on whether it is a full round or partial round.
152      #[inline]
153      fn apply_s_box(&self, state: &mut [Field<E>], is_full_round: bool) {
154          if is_full_round {
155              // Full rounds apply the S Box (x^alpha) to every element of state
156              for element in state.iter_mut() {
157                  *element = (&*element).pow(&self.alpha);
158              }
159          } else {
160              // Partial rounds apply the S Box (x^alpha) to just the first element of state
161              state[0] = (&state[0]).pow(&self.alpha);
162          }
163      }
164  
165      /// Apply the Maximally Distance Separating (MDS) matrix in-place.
166      #[inline]
167      fn apply_mds(&self, state: &mut [Field<E>], new_state: &mut Vec<Field<E>>) {
168          new_state.clear();
169          for i in 0..state.len() {
170              let mut accumulator = Field::zero();
171              for (j, element) in state.iter().enumerate() {
172                  accumulator += element * &self.mds[i][j];
173              }
174              new_state.push(accumulator);
175          }
176          state.swap_with_slice(new_state);
177      }
178  
179      /// Apply the permutation for all rounds in-place.
180      #[inline]
181      fn permute(&self, state: &mut [Field<E>]) {
182          // Determine the partial rounds range bound.
183          let full_rounds_over_2 = self.full_rounds / 2;
184          let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds);
185  
186          // Iterate through all rounds to permute.
187          let mut new_state = Vec::with_capacity(state.len());
188          for i in 0..(self.partial_rounds + self.full_rounds) {
189              let is_full_round = !partial_round_range.contains(&i);
190              self.apply_ark(state, i);
191              self.apply_s_box(state, is_full_round);
192              self.apply_mds(state, &mut new_state);
193          }
194      }
195  }
196  
#[cfg(test)]
mod tests {
    use super::*;
    use alphavm_circuit_types::environment::Circuit;

    use anyhow::Result;

    const DOMAIN: &str = "PoseidonCircuit0";
    const ITERATIONS: usize = 10;
    const RATE: u16 = 4;

    /// Expected `(num_inputs, cost)` pairs for non-constant inputs when `num_outputs` is in
    /// `1..=RATE`. `cost` is asserted as both the private-variable and the constraint count.
    const COSTS_UP_TO_RATE_OUTPUTS: [(usize, u64); 6] =
        [(1, 335), (2, 340), (3, 345), (4, 350), (5, 705), (6, 705)];

    /// Expected `(num_inputs, cost)` pairs for non-constant inputs when `num_outputs` is in
    /// `(RATE + 1)..=(RATE * 2)`. `cost` is asserted as both the private-variable and the
    /// constraint count.
    const COSTS_UP_TO_TWICE_RATE_OUTPUTS: [(usize, u64); 6] =
        [(1, 690), (2, 695), (3, 700), (4, 705), (5, 1060), (6, 1060)];

    /// Hashes `ITERATIONS` random preimages of `num_inputs` elements (allocated in `mode`)
    /// in-circuit. Each output is compared against the native Poseidon implementation, and
    /// the circuit scope's constant/public/private/constraint counts are asserted.
    fn check_hash_many(
        mode: Mode,
        num_inputs: usize,
        num_outputs: u16,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) -> Result<()> {
        use console::HashMany as H;

        let native = console::Poseidon::<<Circuit as Environment>::Network, { RATE as usize }>::setup(DOMAIN)?;
        let poseidon = Poseidon::<Circuit, { RATE as usize }>::constant(native.clone());

        for i in 0..ITERATIONS {
            // Sample a random preimage, and lift it into the circuit in the requested mode.
            let native_input: Vec<_> =
                (0..num_inputs).map(|_| console::Field::<<Circuit as Environment>::Network>::rand(rng)).collect();
            let input: Vec<_> = native_input.iter().map(|value| Field::<Circuit>::new(mode, *value)).collect();

            // The native implementation provides the reference digest.
            let expected = native.hash_many(&native_input, num_outputs);

            Circuit::scope(format!("Poseidon {mode} {i} {num_outputs}"), || {
                let candidate = poseidon.hash_many(&input, num_outputs);
                expected.iter().zip_eq(&candidate).for_each(|(expected_element, candidate_element)| {
                    assert_eq!(*expected_element, candidate_element.eject_value())
                });
                let case = format!("(mode = {mode}, num_inputs = {num_inputs}, num_outputs = {num_outputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
        Ok(())
    }

    /// Runs the shared cost expectations for a non-constant input `mode` (public or private).
    fn check_hash_many_costs(mode: Mode, rng: &mut TestRng) -> Result<()> {
        // With no inputs, the expected variable and constraint costs are zero.
        for num_outputs in 0..=RATE {
            check_hash_many(mode, 0, num_outputs, 1, 0, 0, 0, rng)?;
        }
        for num_outputs in 1..=RATE {
            for &(num_inputs, cost) in COSTS_UP_TO_RATE_OUTPUTS.iter() {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, cost, cost, rng)?;
            }
        }
        for num_outputs in (RATE + 1)..=(RATE * 2) {
            for &(num_inputs, cost) in COSTS_UP_TO_TWICE_RATE_OUTPUTS.iter() {
                check_hash_many(mode, num_inputs, num_outputs, 1, 0, cost, cost, rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_constant() -> Result<()> {
        let mut rng = TestRng::default();

        // Constant inputs are expected to cost no public, private, or constraint slots.
        for num_inputs in 0..=RATE {
            for num_outputs in 0..=RATE {
                check_hash_many(Mode::Constant, usize::from(num_inputs), num_outputs, 1, 0, 0, 0, &mut rng)?;
            }
        }
        Ok(())
    }

    #[test]
    fn test_hash_many_public() -> Result<()> {
        check_hash_many_costs(Mode::Public, &mut TestRng::default())
    }

    #[test]
    fn test_hash_many_private() -> Result<()> {
        check_hash_many_costs(Mode::Private, &mut TestRng::default())
    }
}