/ algorithms / src / crypto_hash / poseidon.rs
poseidon.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the AlphaVM library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use crate::{AlgebraicSponge, DuplexSpongeMode, nonnative_params::*};
 17  use alphavm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
 18  use alphavm_utilities::{BigInteger, FromBits, ToBits};
 19  
 20  use smallvec::SmallVec;
 21  use std::{
 22      iter::Peekable,
 23      ops::{Index, IndexMut},
 24      sync::Arc,
 25  };
 26  
 27  #[derive(Copy, Clone, Debug)]
 28  pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
 29      capacity_state: [F; CAPACITY],
 30      rate_state: [F; RATE],
 31  }
 32  
 33  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
 34      fn default() -> Self {
 35          Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
 36      }
 37  }
 38  
 39  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
 40      /// Returns an immutable iterator over the state.
 41      pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
 42          self.capacity_state.iter().chain(self.rate_state.iter())
 43      }
 44  
 45      /// Returns a mutable iterator over the state.
 46      pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> {
 47          self.capacity_state.iter_mut().chain(self.rate_state.iter_mut())
 48      }
 49  }
 50  
 51  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> {
 52      type Output = F;
 53  
 54      fn index(&self, index: usize) -> &Self::Output {
 55          assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
 56          if index < CAPACITY { &self.capacity_state[index] } else { &self.rate_state[index - CAPACITY] }
 57      }
 58  }
 59  
 60  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> {
 61      fn index_mut(&mut self, index: usize) -> &mut Self::Output {
 62          assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
 63          if index < CAPACITY { &mut self.capacity_state[index] } else { &mut self.rate_state[index - CAPACITY] }
 64      }
 65  }
 66  
 67  #[derive(Clone, Debug, PartialEq, Eq)]
 68  pub struct Poseidon<F: PrimeField, const RATE: usize> {
 69      parameters: Arc<PoseidonParameters<F, RATE, 1>>,
 70  }
 71  
 72  impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> {
 73      /// Initializes a new instance of the cryptographic hash function.
 74      pub fn setup() -> Self {
 75          Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) }
 76      }
 77  
 78      /// Evaluate the cryptographic hash function over a list of field elements
 79      /// as input.
 80      pub fn evaluate(&self, input: &[F]) -> F {
 81          self.evaluate_many(input, 1)[0]
 82      }
 83  
 84      /// Evaluate the cryptographic hash function over a list of field elements
 85      /// as input, and returns the specified number of field elements as
 86      /// output.
 87      pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> {
 88          let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters);
 89          sponge.absorb_native_field_elements(input);
 90          sponge.squeeze_native_field_elements(num_outputs).to_vec()
 91      }
 92  
 93      /// Evaluate the cryptographic hash function over a non-fixed-length vector,
 94      /// in which the length also needs to be hashed.
 95      pub fn evaluate_with_len(&self, input: &[F]) -> F {
 96          self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat())
 97      }
 98  
 99      pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> {
100          &self.parameters
101      }
102  }
103  
104  /// A duplex sponge based using the Poseidon permutation.
105  ///
106  /// This implementation of Poseidon is entirely from Fractal's implementation in
107  /// [COS20][cos] with small syntax changes.
108  ///
109  /// [cos]: https://eprint.iacr.org/2019/1076
110  #[derive(Clone, Debug)]
111  pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
112      /// Sponge Parameters
113      parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>,
114      /// Current sponge's state (current elements in the permutation block)
115      state: State<F, RATE, CAPACITY>,
116      /// Current mode (whether its absorbing or squeezing)
117      pub mode: DuplexSpongeMode,
118      /// A persistent lookup table used when compressing elements.
119      adjustment_factor_lookup_table: Arc<[F]>,
120  }
121  
122  impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> {
123      type Parameters = Arc<PoseidonParameters<F, RATE, 1>>;
124  
125      fn sample_parameters() -> Self::Parameters {
126          Arc::new(F::default_poseidon_parameters::<RATE>().unwrap())
127      }
128  
129      fn new_with_parameters(parameters: &Self::Parameters) -> Self {
130          Self {
131              parameters: parameters.clone(),
132              state: State::default(),
133              mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
134              adjustment_factor_lookup_table: {
135                  let capacity = F::size_in_bits() - 1;
136                  let mut table = Vec::<F>::with_capacity(capacity);
137  
138                  let mut cur = F::one();
139                  for _ in 0..capacity {
140                      table.push(cur);
141                      cur.double_in_place();
142                  }
143  
144                  table.into()
145              },
146          }
147      }
148  
149      /// Takes in field elements.
150      fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) {
151          let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>();
152          if !input.is_empty() {
153              match self.mode {
154                  DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
155                      if next_absorb_index == RATE {
156                          self.permute();
157                          next_absorb_index = 0;
158                      }
159                      self.absorb_internal(next_absorb_index, &input);
160                  }
161                  DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
162                      self.permute();
163                      self.absorb_internal(0, &input);
164                  }
165              }
166          }
167      }
168  
169      /// Takes in field elements.
170      fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) {
171          Self::push_elements_to_sponge(self, elements, OptimizationType::Weight);
172      }
173  
174      fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
175          self.get_fe(num, false)
176      }
177  
178      fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> {
179          if num_elements == 0 {
180              return SmallVec::<[F; 10]>::new();
181          }
182          let mut output = if num_elements <= 10 {
183              smallvec::smallvec_inline![F::zero(); 10]
184          } else {
185              smallvec::smallvec![F::zero(); num_elements]
186          };
187  
188          match self.mode {
189              DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
190                  self.permute();
191                  self.squeeze_internal(0, &mut output[..num_elements]);
192              }
193              DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
194                  if next_squeeze_index == RATE {
195                      self.permute();
196                      next_squeeze_index = 0;
197                  }
198                  self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]);
199              }
200          }
201  
202          output.truncate(num_elements);
203          output
204      }
205  
206      /// Takes out field elements of 168 bits.
207      fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
208          self.get_fe(num, true)
209      }
210  }
211  
212  impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
213      #[inline]
214      fn apply_ark(&mut self, round_number: usize) {
215          for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
216              *state_elem += ark_elem;
217          }
218      }
219  
220      #[inline]
221      fn apply_s_box(&mut self, is_full_round: bool) {
222          if is_full_round {
223              // Full rounds apply the S Box (x^alpha) to every element of state
224              for elem in self.state.iter_mut() {
225                  *elem = elem.pow([self.parameters.alpha]);
226              }
227          } else {
228              // Partial rounds apply the S Box (x^alpha) to just the first element of state
229              self.state[0] = self.state[0].pow([self.parameters.alpha]);
230          }
231      }
232  
233      #[inline]
234      fn apply_mds(&mut self) {
235          let mut new_state = State::default();
236          new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
237              *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter());
238          });
239          self.state = new_state;
240      }
241  
242      #[inline]
243      fn permute(&mut self) {
244          // Determine the partial rounds range bound.
245          let partial_rounds = self.parameters.partial_rounds;
246          let full_rounds = self.parameters.full_rounds;
247          let full_rounds_over_2 = full_rounds / 2;
248          let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);
249  
250          // Iterate through all rounds to permute.
251          for i in 0..(partial_rounds + full_rounds) {
252              let is_full_round = !partial_round_range.contains(&i);
253              self.apply_ark(i);
254              self.apply_s_box(is_full_round);
255              self.apply_mds();
256          }
257      }
258  
259      /// Absorbs everything in elements, this does not end in an absorption.
260      #[inline]
261      fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
262          if !input.is_empty() {
263              let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
264              let num_elements_remaining = input.len() - first_chunk_size;
265              let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
266              let rest_chunks = rest_chunk.chunks(RATE);
267              // The total number of chunks is `elements[num_elements_remaining..].len() /
268              // RATE`, plus 1 for the remainder.
269              let total_num_chunks = 1 + // 1 for the first chunk
270                  // We add all the chunks that are perfectly divisible by `RATE`
271                  (num_elements_remaining / RATE) +
272                  // And also add 1 if the last chunk is non-empty
273                  // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
274                  usize::from((num_elements_remaining % RATE) != 0);
275  
276              // Absorb the input elements, `RATE` elements at a time, except for the first
277              // chunk, which is of size `RATE - rate_start`.
278              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
279                  for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
280                      *state_elem += element;
281                  }
282                  // Are we in the last chunk?
283                  // If so, let's wrap up.
284                  if i == total_num_chunks - 1 {
285                      self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
286                      return;
287                  } else {
288                      self.permute();
289                  }
290                  rate_start = 0;
291              }
292          }
293      }
294  
295      /// Squeeze |output| many elements. This does not end in a squeeze
296      #[inline]
297      fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
298          let output_size = output.len();
299          if output_size != 0 {
300              let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
301              let num_output_remaining = output.len() - first_chunk_size;
302              let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
303              assert_eq!(rest_chunk.len(), num_output_remaining);
304              let rest_chunks = rest_chunk.chunks_mut(RATE);
305              // The total number of chunks is `output[num_output_remaining..].len() / RATE`,
306              // plus 1 for the remainder.
307              let total_num_chunks = 1 + // 1 for the first chunk
308                  // We add all the chunks that are perfectly divisible by `RATE`
309                  (num_output_remaining / RATE) +
310                  // And also add 1 if the last chunk is non-empty
311                  // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
312                  usize::from((num_output_remaining % RATE) != 0);
313  
314              // Absorb the input output, `RATE` output at a time, except for the first chunk,
315              // which is of size `RATE - rate_start`.
316              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
317                  let range = rate_start..(rate_start + chunk.len());
318                  debug_assert_eq!(
319                      chunk.len(),
320                      self.state.rate_state[range.clone()].len(),
321                      "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
322                  );
323                  chunk.copy_from_slice(&self.state.rate_state[range]);
324                  // Are we in the last chunk?
325                  // If so, let's wrap up.
326                  if i == total_num_chunks - 1 {
327                      self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
328                      return;
329                  } else {
330                      self.permute();
331                  }
332                  rate_start = 0;
333              }
334          }
335      }
336  
337      /// Compress every two elements if possible.
338      /// Provides a vector of (limb, num_of_additions), both of which are F.
339      pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
340          &self,
341          mut src_limbs: Peekable<I>,
342          ty: OptimizationType,
343      ) -> Vec<F> {
344          let capacity = F::size_in_bits() - 1;
345          let mut dest_limbs = Vec::<F>::new();
346  
347          let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);
348  
349          // Prepare a reusable vector to be used in overhead calculation.
350          let mut num_bits = Vec::new();
351  
352          while let Some(first) = src_limbs.next() {
353              let second = src_limbs.peek();
354  
355              let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
356              let second_max_bits_per_limb = if let Some(second) = second {
357                  params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
358              } else {
359                  0
360              };
361  
362              if let Some(second) = second {
363                  if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
364                      let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];
365  
366                      dest_limbs.push(first.0 * adjustment_factor + second.0);
367                      src_limbs.next();
368                  } else {
369                      dest_limbs.push(first.0);
370                  }
371              } else {
372                  dest_limbs.push(first.0);
373              }
374          }
375  
376          dest_limbs
377      }
378  
379      /// Convert a `TargetField` element into limbs (not constraints)
380      /// This is an internal function that would be reused by a number of other
381      /// functions
382      pub fn get_limbs_representations<TargetField: PrimeField>(
383          elem: &TargetField,
384          optimization_type: OptimizationType,
385      ) -> SmallVec<[F; 10]> {
386          Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
387      }
388  
389      /// Obtain the limbs directly from a big int
390      pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
391          elem: &<TargetField as PrimeField>::BigInteger,
392          optimization_type: OptimizationType,
393      ) -> SmallVec<[F; 10]> {
394          let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);
395  
396          // Prepare a reusable vector for the BE bits.
397          let mut cur_bits = Vec::new();
398          // Push the lower limbs first
399          let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
400          let mut cur = *elem;
401          for _ in 0..params.num_limbs {
402              cur.write_bits_be(&mut cur_bits); // `write_bits_be` is big endian
403              let cur_mod_r =
404                  <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
405                      .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
406              limbs.push(F::from_bigint(cur_mod_r).unwrap());
407              cur.divn(params.bits_per_limb as u32);
408              // Clear the vector after every iteration so its allocation can be reused.
409              cur_bits.clear();
410          }
411  
412          // then we reverse, so that the limbs are ``big limb first''
413          limbs.reverse();
414  
415          limbs
416      }
417  
418      /// Push elements to sponge, treated in the non-native field
419      /// representations.
420      pub fn push_elements_to_sponge<TargetField: PrimeField>(
421          &mut self,
422          src: impl IntoIterator<Item = TargetField>,
423          ty: OptimizationType,
424      ) {
425          let src_limbs = src
426              .into_iter()
427              .flat_map(|elem| {
428                  let limbs = Self::get_limbs_representations(&elem, ty);
429                  limbs.into_iter().map(|limb| (limb, F::one()))
430                  // specifically set to one, since most gadgets in the constraint
431                  // world would not have zero noise (due to the relatively weak
432                  // normal form testing in `alloc`)
433              })
434              .peekable();
435  
436          let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
437          self.absorb_native_field_elements(&dest_limbs);
438      }
439  
440      /// obtain random bits from hashchain.
441      /// not guaranteed to be uniformly distributed, should only be used in
442      /// certain situations.
443      pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> {
444          let bits_per_element = F::size_in_bits() - 1;
445          let num_elements = num_bits.div_ceil(bits_per_element);
446  
447          let src_elements = self.squeeze_native_field_elements(num_elements);
448          let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element);
449  
450          let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize;
451          for elem in src_elements.iter() {
452              // discard the highest bit
453              let elem_bits = elem.to_bigint().to_bits_be();
454              dest_bits.extend_from_slice(&elem_bits[skip..]);
455          }
456          dest_bits.truncate(num_bits);
457  
458          dest_bits
459      }
460  
461      /// obtain random field elements from hashchain.
462      /// not guaranteed to be uniformly distributed, should only be used in
463      /// certain situations.
464      pub fn get_fe<TargetField: PrimeField>(
465          &mut self,
466          num_elements: usize,
467          outputs_short_elements: bool,
468      ) -> SmallVec<[TargetField; 10]> {
469          let num_bits_per_nonnative = if outputs_short_elements {
470              168
471          } else {
472              TargetField::size_in_bits() - 1 // also omit the highest bit
473          };
474          let bits = self.get_bits(num_bits_per_nonnative * num_elements);
475  
476          let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative);
477          let mut cur = TargetField::one();
478          for _ in 0..num_bits_per_nonnative {
479              lookup_table.push(cur);
480              cur.double_in_place();
481          }
482  
483          let dest_elements = bits
484              .chunks_exact(num_bits_per_nonnative)
485              .map(|per_nonnative_bits| {
486                  // technically, this can be done via BigInteger::from_bits; here, we use this
487                  // method for consistency with the gadget counterpart
488                  let mut res = TargetField::zero();
489  
490                  for (i, bit) in per_nonnative_bits.iter().rev().enumerate() {
491                      if *bit {
492                          res += &lookup_table[i];
493                      }
494                  }
495                  res
496              })
497              .collect::<SmallVec<_>>();
498          debug_assert_eq!(dest_elements.len(), num_elements);
499  
500          dest_elements
501      }
502  }