/ console / algorithms / src / poseidon / helpers / sponge.rs
sponge.rs
  1  // Copyright (c) 2025-2026 ACDC Network
  2  // This file is part of the alphavm library.
  3  //
  4  // Alpha Chain | Delta Chain Protocol
  5  // International Monetary Graphite.
  6  //
  7  // Derived from Aleo (https://aleo.org) and ProvableHQ (https://provable.com).
  8  // They built world-class ZK infrastructure. We installed the EASY button.
  9  // Their cryptography: elegant. Our modifications: bureaucracy-compatible.
 10  // Original brilliance: theirs. Robert's Rules: ours. Bugs: definitely ours.
 11  //
 12  // Original Aleo/ProvableHQ code subject to Apache 2.0 https://www.apache.org/licenses/LICENSE-2.0
 13  // All modifications and new work: CC0 1.0 Universal Public Domain Dedication.
 14  // No rights reserved. No permission required. No warranty. No refunds.
 15  //
 16  // https://creativecommons.org/publicdomain/zero/1.0/
 17  // SPDX-License-Identifier: CC0-1.0
 18  
 19  use crate::poseidon::{
 20      helpers::{AlgebraicSponge, DuplexSpongeMode},
 21      State,
 22  };
 23  use alphavm_console_types::{prelude::*, Field};
 24  use alphavm_fields::PoseidonParameters;
 25  
 26  use smallvec::SmallVec;
 27  use std::{ops::DerefMut, sync::Arc};
 28  
 29  /// A duplex sponge based using the Poseidon permutation.
 30  ///
 31  /// This implementation of Poseidon is entirely from Fractal's implementation in [COS20][cos]
 32  /// with small syntax changes.
 33  ///
 34  /// [cos]: https://eprint.iacr.org/2019/1076
 35  #[derive(Clone, Debug)]
 36  pub struct PoseidonSponge<E: Environment, const RATE: usize, const CAPACITY: usize> {
 37      /// Sponge Parameters
 38      parameters: Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>,
 39      /// Current sponge's state (current elements in the permutation block)
 40      state: State<E, RATE, CAPACITY>,
 41      /// Current mode (whether its absorbing or squeezing)
 42      pub(in crate::poseidon) mode: DuplexSpongeMode,
 43  }
 44  
 45  impl<E: Environment, const RATE: usize, const CAPACITY: usize> AlgebraicSponge<E, RATE, CAPACITY>
 46      for PoseidonSponge<E, RATE, CAPACITY>
 47  {
 48      type Parameters = Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>;
 49  
 50      fn new(parameters: &Self::Parameters) -> Self {
 51          Self {
 52              parameters: parameters.clone(),
 53              state: State::default(),
 54              mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
 55          }
 56      }
 57  
 58      fn absorb(&mut self, input: &[Field<E>]) {
 59          if !input.is_empty() {
 60              match self.mode {
 61                  DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
 62                      if next_absorb_index == RATE {
 63                          self.permute();
 64                          next_absorb_index = 0;
 65                      }
 66                      self.absorb_internal(next_absorb_index, input);
 67                  }
 68                  DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
 69                      self.permute();
 70                      self.absorb_internal(0, input);
 71                  }
 72              }
 73          }
 74      }
 75  
 76      fn squeeze(&mut self, num_elements: u16) -> SmallVec<[Field<E>; 10]> {
 77          if num_elements == 0 {
 78              return SmallVec::new();
 79          }
 80          let mut output = if num_elements <= 10 {
 81              smallvec::smallvec_inline![Field::<E>::zero(); 10]
 82          } else {
 83              smallvec::smallvec![Field::<E>::zero(); num_elements as usize]
 84          };
 85  
 86          match self.mode {
 87              DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
 88                  self.permute();
 89                  self.squeeze_internal(0, &mut output[..num_elements as usize]);
 90              }
 91              DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
 92                  if next_squeeze_index == RATE {
 93                      self.permute();
 94                      next_squeeze_index = 0;
 95                  }
 96                  self.squeeze_internal(next_squeeze_index, &mut output[..num_elements as usize]);
 97              }
 98          }
 99  
100          output.truncate(num_elements as usize);
101          output
102      }
103  }
104  
105  impl<E: Environment, const RATE: usize, const CAPACITY: usize> PoseidonSponge<E, RATE, CAPACITY> {
106      #[inline]
107      fn apply_ark(&mut self, round_number: usize) {
108          for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
109              *state_elem += Field::<E>::new(*ark_elem);
110          }
111      }
112  
113      #[inline]
114      fn apply_s_box(&mut self, is_full_round: bool) {
115          // Full rounds apply the S Box (x^alpha) to every element of state
116          if is_full_round {
117              for elem in self.state.iter_mut() {
118                  let e = elem.deref_mut();
119                  *e = e.pow([self.parameters.alpha]);
120              }
121          }
122          // Partial rounds apply the S Box (x^alpha) to just the first element of state
123          else {
124              let e = self.state[0].deref_mut();
125              *e = e.pow([self.parameters.alpha]);
126          }
127      }
128  
129      #[inline]
130      fn apply_mds(&mut self) {
131          let mut new_state = State::default();
132          new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
133              *new_elem = Field::new(E::Field::sum_of_products(self.state.iter().map(|e| e.deref()), mds_row.iter()));
134          });
135          self.state = new_state;
136      }
137  
138      #[inline]
139      fn permute(&mut self) {
140          // Determine the partial rounds range bound.
141          let partial_rounds = self.parameters.partial_rounds;
142          let full_rounds = self.parameters.full_rounds;
143          let full_rounds_over_2 = full_rounds / 2;
144          let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);
145  
146          // Iterate through all rounds to permute.
147          for i in 0..(partial_rounds + full_rounds) {
148              let is_full_round = !partial_round_range.contains(&i);
149              self.apply_ark(i);
150              self.apply_s_box(is_full_round);
151              self.apply_mds();
152          }
153      }
154  
155      /// Absorbs everything in elements, this does not end in an absorption.
156      #[inline]
157      fn absorb_internal(&mut self, mut rate_start: usize, input: &[Field<E>]) {
158          if !input.is_empty() {
159              let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
160              let num_elements_remaining = input.len() - first_chunk_size;
161              let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
162              let rest_chunks = rest_chunk.chunks(RATE);
163              // The total number of chunks is `elements[num_elements_remaining..].len() / RATE`, plus 1
164              // for the remainder.
165              let total_num_chunks = 1 + // 1 for the first chunk
166                  // We add all the chunks that are perfectly divisible by `RATE`
167                  (num_elements_remaining / RATE) +
168                  // And also add 1 if the last chunk is non-empty
169                  // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
170                  usize::from(!num_elements_remaining.is_multiple_of(RATE));
171  
172              // Absorb the input elements, `RATE` elements at a time, except for the first chunk, which
173              // is of size `RATE - rate_start`.
174              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
175                  for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state_mut()[rate_start..]) {
176                      *state_elem += element;
177                  }
178                  // Are we in the last chunk?
179                  // If so, let's wrap up.
180                  if i == total_num_chunks - 1 {
181                      self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
182                      return;
183                  } else {
184                      self.permute();
185                  }
186                  rate_start = 0;
187              }
188          }
189      }
190  
191      /// Squeeze |output| many elements. This does not end in a squeeze
192      #[inline]
193      fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [Field<E>]) {
194          let output_size = output.len();
195          if output_size != 0 {
196              let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
197              let num_output_remaining = output.len() - first_chunk_size;
198              let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
199              assert_eq!(rest_chunk.len(), num_output_remaining);
200              let rest_chunks = rest_chunk.chunks_mut(RATE);
201              // The total number of chunks is `output[num_output_remaining..].len() / RATE`, plus 1
202              // for the remainder.
203              let total_num_chunks = 1 + // 1 for the first chunk
204                  // We add all the chunks that are perfectly divisible by `RATE`
205                  (num_output_remaining / RATE) +
206                  // And also add 1 if the last chunk is non-empty
207                  // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
208                  usize::from(!num_output_remaining.is_multiple_of(RATE));
209  
210              // Absorb the input output, `RATE` output at a time, except for the first chunk, which
211              // is of size `RATE - rate_start`.
212              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
213                  let range = rate_start..(rate_start + chunk.len());
214                  debug_assert_eq!(
215                      chunk.len(),
216                      self.state.rate_state(range.clone()).len(),
217                      "Failed to squeeze {output_size} at rate {RATE} & rate_start {rate_start}"
218                  );
219                  chunk.copy_from_slice(self.state.rate_state(range));
220                  // Are we in the last chunk?
221                  // If so, let's wrap up.
222                  if i == total_num_chunks - 1 {
223                      self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
224                      return;
225                  } else {
226                      self.permute();
227                  }
228                  rate_start = 0;
229              }
230          }
231      }
232  }