/ algorithms / src / crypto_hash / poseidon.rs
poseidon.rs
  1  // Copyright (c) 2025-2026 ACDC Network
  2  // This file is part of the alphavm library.
  3  //
  4  // Alpha Chain | Delta Chain Protocol
  5  // International Monetary Graphite.
  6  //
  7  // Derived from Aleo (https://aleo.org) and ProvableHQ (https://provable.com).
  8  // They built world-class ZK infrastructure. We installed the EASY button.
  9  // Their cryptography: elegant. Our modifications: bureaucracy-compatible.
 10  // Original brilliance: theirs. Robert's Rules: ours. Bugs: definitely ours.
 11  //
 12  // Original Aleo/ProvableHQ code subject to Apache 2.0 https://www.apache.org/licenses/LICENSE-2.0
 13  // All modifications and new work: CC0 1.0 Universal Public Domain Dedication.
 14  // No rights reserved. No permission required. No warranty. No refunds.
 15  //
 16  // https://creativecommons.org/publicdomain/zero/1.0/
 17  // SPDX-License-Identifier: CC0-1.0
 18  
 19  use crate::{nonnative_params::*, AlgebraicSponge, DuplexSpongeMode};
 20  use alphavm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
 21  use alphavm_utilities::{BigInteger, FromBits, ToBits};
 22  
 23  use smallvec::SmallVec;
 24  use std::{
 25      iter::Peekable,
 26      ops::{Index, IndexMut},
 27      sync::Arc,
 28  };
 29  
 30  #[derive(Copy, Clone, Debug)]
 31  pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
 32      capacity_state: [F; CAPACITY],
 33      rate_state: [F; RATE],
 34  }
 35  
 36  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
 37      fn default() -> Self {
 38          Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
 39      }
 40  }
 41  
 42  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
 43      /// Returns an immutable iterator over the state.
 44      pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
 45          self.capacity_state.iter().chain(self.rate_state.iter())
 46      }
 47  
 48      /// Returns a mutable iterator over the state.
 49      pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> {
 50          self.capacity_state.iter_mut().chain(self.rate_state.iter_mut())
 51      }
 52  }
 53  
 54  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> {
 55      type Output = F;
 56  
 57      fn index(&self, index: usize) -> &Self::Output {
 58          assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
 59          if index < CAPACITY {
 60              &self.capacity_state[index]
 61          } else {
 62              &self.rate_state[index - CAPACITY]
 63          }
 64      }
 65  }
 66  
 67  impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> {
 68      fn index_mut(&mut self, index: usize) -> &mut Self::Output {
 69          assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY);
 70          if index < CAPACITY {
 71              &mut self.capacity_state[index]
 72          } else {
 73              &mut self.rate_state[index - CAPACITY]
 74          }
 75      }
 76  }
 77  
 78  #[derive(Clone, Debug, PartialEq, Eq)]
 79  pub struct Poseidon<F: PrimeField, const RATE: usize> {
 80      parameters: Arc<PoseidonParameters<F, RATE, 1>>,
 81  }
 82  
 83  impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> {
 84      /// Initializes a new instance of the cryptographic hash function.
 85      pub fn setup() -> Self {
 86          Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) }
 87      }
 88  
 89      /// Evaluate the cryptographic hash function over a list of field elements
 90      /// as input.
 91      pub fn evaluate(&self, input: &[F]) -> F {
 92          self.evaluate_many(input, 1)[0]
 93      }
 94  
 95      /// Evaluate the cryptographic hash function over a list of field elements
 96      /// as input, and returns the specified number of field elements as
 97      /// output.
 98      pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> {
 99          let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters);
100          sponge.absorb_native_field_elements(input);
101          sponge.squeeze_native_field_elements(num_outputs).to_vec()
102      }
103  
104      /// Evaluate the cryptographic hash function over a non-fixed-length vector,
105      /// in which the length also needs to be hashed.
106      pub fn evaluate_with_len(&self, input: &[F]) -> F {
107          self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat())
108      }
109  
110      pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> {
111          &self.parameters
112      }
113  }
114  
115  /// A duplex sponge based using the Poseidon permutation.
116  ///
117  /// This implementation of Poseidon is entirely from Fractal's implementation in
118  /// [COS20][cos] with small syntax changes.
119  ///
120  /// [cos]: https://eprint.iacr.org/2019/1076
121  #[derive(Clone, Debug)]
122  pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
123      /// Sponge Parameters
124      parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>,
125      /// Current sponge's state (current elements in the permutation block)
126      state: State<F, RATE, CAPACITY>,
127      /// Current mode (whether its absorbing or squeezing)
128      pub mode: DuplexSpongeMode,
129      /// A persistent lookup table used when compressing elements.
130      adjustment_factor_lookup_table: Arc<[F]>,
131  }
132  
133  impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> {
134      type Parameters = Arc<PoseidonParameters<F, RATE, 1>>;
135  
136      fn sample_parameters() -> Self::Parameters {
137          Arc::new(F::default_poseidon_parameters::<RATE>().unwrap())
138      }
139  
140      fn new_with_parameters(parameters: &Self::Parameters) -> Self {
141          Self {
142              parameters: parameters.clone(),
143              state: State::default(),
144              mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
145              adjustment_factor_lookup_table: {
146                  let capacity = F::size_in_bits() - 1;
147                  let mut table = Vec::<F>::with_capacity(capacity);
148  
149                  let mut cur = F::one();
150                  for _ in 0..capacity {
151                      table.push(cur);
152                      cur.double_in_place();
153                  }
154  
155                  table.into()
156              },
157          }
158      }
159  
160      /// Takes in field elements.
161      fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) {
162          let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>();
163          if !input.is_empty() {
164              match self.mode {
165                  DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
166                      if next_absorb_index == RATE {
167                          self.permute();
168                          next_absorb_index = 0;
169                      }
170                      self.absorb_internal(next_absorb_index, &input);
171                  }
172                  DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
173                      self.permute();
174                      self.absorb_internal(0, &input);
175                  }
176              }
177          }
178      }
179  
180      /// Takes in field elements.
181      fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) {
182          Self::push_elements_to_sponge(self, elements, OptimizationType::Weight);
183      }
184  
185      fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
186          self.get_fe(num, false)
187      }
188  
189      fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> {
190          if num_elements == 0 {
191              return SmallVec::<[F; 10]>::new();
192          }
193          let mut output = if num_elements <= 10 {
194              smallvec::smallvec_inline![F::zero(); 10]
195          } else {
196              smallvec::smallvec![F::zero(); num_elements]
197          };
198  
199          match self.mode {
200              DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
201                  self.permute();
202                  self.squeeze_internal(0, &mut output[..num_elements]);
203              }
204              DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
205                  if next_squeeze_index == RATE {
206                      self.permute();
207                      next_squeeze_index = 0;
208                  }
209                  self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]);
210              }
211          }
212  
213          output.truncate(num_elements);
214          output
215      }
216  
217      /// Takes out field elements of 168 bits.
218      fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> {
219          self.get_fe(num, true)
220      }
221  }
222  
223  impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> {
224      #[inline]
225      fn apply_ark(&mut self, round_number: usize) {
226          for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
227              *state_elem += ark_elem;
228          }
229      }
230  
231      #[inline]
232      fn apply_s_box(&mut self, is_full_round: bool) {
233          if is_full_round {
234              // Full rounds apply the S Box (x^alpha) to every element of state
235              for elem in self.state.iter_mut() {
236                  *elem = elem.pow([self.parameters.alpha]);
237              }
238          } else {
239              // Partial rounds apply the S Box (x^alpha) to just the first element of state
240              self.state[0] = self.state[0].pow([self.parameters.alpha]);
241          }
242      }
243  
244      #[inline]
245      fn apply_mds(&mut self) {
246          let mut new_state = State::default();
247          new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
248              *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter());
249          });
250          self.state = new_state;
251      }
252  
253      #[inline]
254      fn permute(&mut self) {
255          // Determine the partial rounds range bound.
256          let partial_rounds = self.parameters.partial_rounds;
257          let full_rounds = self.parameters.full_rounds;
258          let full_rounds_over_2 = full_rounds / 2;
259          let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);
260  
261          // Iterate through all rounds to permute.
262          for i in 0..(partial_rounds + full_rounds) {
263              let is_full_round = !partial_round_range.contains(&i);
264              self.apply_ark(i);
265              self.apply_s_box(is_full_round);
266              self.apply_mds();
267          }
268      }
269  
270      /// Absorbs everything in elements, this does not end in an absorption.
271      #[inline]
272      fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
273          if !input.is_empty() {
274              let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
275              let num_elements_remaining = input.len() - first_chunk_size;
276              let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
277              let rest_chunks = rest_chunk.chunks(RATE);
278              // The total number of chunks is `elements[num_elements_remaining..].len() /
279              // RATE`, plus 1 for the remainder.
280              let total_num_chunks = 1 + // 1 for the first chunk
281                  // We add all the chunks that are perfectly divisible by `RATE`
282                  (num_elements_remaining / RATE) +
283                  // And also add 1 if the last chunk is non-empty
284                  // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
285                  usize::from(!num_elements_remaining.is_multiple_of(RATE));
286  
287              // Absorb the input elements, `RATE` elements at a time, except for the first
288              // chunk, which is of size `RATE - rate_start`.
289              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
290                  for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
291                      *state_elem += element;
292                  }
293                  // Are we in the last chunk?
294                  // If so, let's wrap up.
295                  if i == total_num_chunks - 1 {
296                      self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
297                      return;
298                  } else {
299                      self.permute();
300                  }
301                  rate_start = 0;
302              }
303          }
304      }
305  
306      /// Squeeze |output| many elements. This does not end in a squeeze
307      #[inline]
308      fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
309          let output_size = output.len();
310          if output_size != 0 {
311              let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
312              let num_output_remaining = output.len() - first_chunk_size;
313              let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
314              assert_eq!(rest_chunk.len(), num_output_remaining);
315              let rest_chunks = rest_chunk.chunks_mut(RATE);
316              // The total number of chunks is `output[num_output_remaining..].len() / RATE`,
317              // plus 1 for the remainder.
318              let total_num_chunks = 1 + // 1 for the first chunk
319                  // We add all the chunks that are perfectly divisible by `RATE`
320                  (num_output_remaining / RATE) +
321                  // And also add 1 if the last chunk is non-empty
322                  // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
323                  usize::from(!num_output_remaining.is_multiple_of(RATE));
324  
325              // Absorb the input output, `RATE` output at a time, except for the first chunk,
326              // which is of size `RATE - rate_start`.
327              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
328                  let range = rate_start..(rate_start + chunk.len());
329                  debug_assert_eq!(
330                      chunk.len(),
331                      self.state.rate_state[range.clone()].len(),
332                      "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
333                  );
334                  chunk.copy_from_slice(&self.state.rate_state[range]);
335                  // Are we in the last chunk?
336                  // If so, let's wrap up.
337                  if i == total_num_chunks - 1 {
338                      self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
339                      return;
340                  } else {
341                      self.permute();
342                  }
343                  rate_start = 0;
344              }
345          }
346      }
347  
348      /// Compress every two elements if possible.
349      /// Provides a vector of (limb, num_of_additions), both of which are F.
350      pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
351          &self,
352          mut src_limbs: Peekable<I>,
353          ty: OptimizationType,
354      ) -> Vec<F> {
355          let capacity = F::size_in_bits() - 1;
356          let mut dest_limbs = Vec::<F>::new();
357  
358          let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);
359  
360          // Prepare a reusable vector to be used in overhead calculation.
361          let mut num_bits = Vec::new();
362  
363          while let Some(first) = src_limbs.next() {
364              let second = src_limbs.peek();
365  
366              let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
367              let second_max_bits_per_limb = if let Some(second) = second {
368                  params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
369              } else {
370                  0
371              };
372  
373              if let Some(second) = second {
374                  if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
375                      let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];
376  
377                      dest_limbs.push(first.0 * adjustment_factor + second.0);
378                      src_limbs.next();
379                  } else {
380                      dest_limbs.push(first.0);
381                  }
382              } else {
383                  dest_limbs.push(first.0);
384              }
385          }
386  
387          dest_limbs
388      }
389  
390      /// Convert a `TargetField` element into limbs (not constraints)
391      /// This is an internal function that would be reused by a number of other
392      /// functions
393      pub fn get_limbs_representations<TargetField: PrimeField>(
394          elem: &TargetField,
395          optimization_type: OptimizationType,
396      ) -> SmallVec<[F; 10]> {
397          Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
398      }
399  
400      /// Obtain the limbs directly from a big int
401      pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
402          elem: &<TargetField as PrimeField>::BigInteger,
403          optimization_type: OptimizationType,
404      ) -> SmallVec<[F; 10]> {
405          let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);
406  
407          // Prepare a reusable vector for the BE bits.
408          let mut cur_bits = Vec::new();
409          // Push the lower limbs first
410          let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
411          let mut cur = *elem;
412          for _ in 0..params.num_limbs {
413              cur.write_bits_be(&mut cur_bits); // `write_bits_be` is big endian
414              let cur_mod_r =
415                  <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
416                      .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
417              limbs.push(F::from_bigint(cur_mod_r).unwrap());
418              cur.divn(params.bits_per_limb as u32);
419              // Clear the vector after every iteration so its allocation can be reused.
420              cur_bits.clear();
421          }
422  
423          // then we reverse, so that the limbs are ``big limb first''
424          limbs.reverse();
425  
426          limbs
427      }
428  
429      /// Push elements to sponge, treated in the non-native field
430      /// representations.
431      pub fn push_elements_to_sponge<TargetField: PrimeField>(
432          &mut self,
433          src: impl IntoIterator<Item = TargetField>,
434          ty: OptimizationType,
435      ) {
436          let src_limbs = src
437              .into_iter()
438              .flat_map(|elem| {
439                  let limbs = Self::get_limbs_representations(&elem, ty);
440                  limbs.into_iter().map(|limb| (limb, F::one()))
441                  // specifically set to one, since most gadgets in the constraint
442                  // world would not have zero noise (due to the relatively weak
443                  // normal form testing in `alloc`)
444              })
445              .peekable();
446  
447          let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
448          self.absorb_native_field_elements(&dest_limbs);
449      }
450  
451      /// obtain random bits from hashchain.
452      /// not guaranteed to be uniformly distributed, should only be used in
453      /// certain situations.
454      pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> {
455          let bits_per_element = F::size_in_bits() - 1;
456          let num_elements = num_bits.div_ceil(bits_per_element);
457  
458          let src_elements = self.squeeze_native_field_elements(num_elements);
459          let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element);
460  
461          let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize;
462          for elem in src_elements.iter() {
463              // discard the highest bit
464              let elem_bits = elem.to_bigint().to_bits_be();
465              dest_bits.extend_from_slice(&elem_bits[skip..]);
466          }
467          dest_bits.truncate(num_bits);
468  
469          dest_bits
470      }
471  
472      /// obtain random field elements from hashchain.
473      /// not guaranteed to be uniformly distributed, should only be used in
474      /// certain situations.
475      pub fn get_fe<TargetField: PrimeField>(
476          &mut self,
477          num_elements: usize,
478          outputs_short_elements: bool,
479      ) -> SmallVec<[TargetField; 10]> {
480          let num_bits_per_nonnative = if outputs_short_elements {
481              168
482          } else {
483              TargetField::size_in_bits() - 1 // also omit the highest bit
484          };
485          let bits = self.get_bits(num_bits_per_nonnative * num_elements);
486  
487          let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative);
488          let mut cur = TargetField::one();
489          for _ in 0..num_bits_per_nonnative {
490              lookup_table.push(cur);
491              cur.double_in_place();
492          }
493  
494          let dest_elements = bits
495              .chunks_exact(num_bits_per_nonnative)
496              .map(|per_nonnative_bits| {
497                  // technically, this can be done via BigInteger::from_bits; here, we use this
498                  // method for consistency with the gadget counterpart
499                  let mut res = TargetField::zero();
500  
501                  for (i, bit) in per_nonnative_bits.iter().rev().enumerate() {
502                      if *bit {
503                          res += &lookup_table[i];
504                      }
505                  }
506                  res
507              })
508              .collect::<SmallVec<_>>();
509          debug_assert_eq!(dest_elements.len(), num_elements);
510  
511          dest_elements
512      }
513  }