/ console / algorithms / src / poseidon / helpers / sponge.rs
sponge.rs
  1  // Copyright (c) 2019-2025 Alpha-Delta Network Inc.
  2  // This file is part of the deltavm library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use crate::poseidon::{
 17      State,
 18      helpers::{AlgebraicSponge, DuplexSpongeMode},
 19  };
 20  use deltavm_console_types::{Field, prelude::*};
 21  use deltavm_fields::PoseidonParameters;
 22  
 23  use smallvec::SmallVec;
 24  use std::{ops::DerefMut, sync::Arc};
 25  
 26  /// A duplex sponge based using the Poseidon permutation.
 27  ///
 28  /// This implementation of Poseidon is entirely from Fractal's implementation in [COS20][cos]
 29  /// with small syntax changes.
 30  ///
 31  /// [cos]: https://eprint.iacr.org/2019/1076
 32  #[derive(Clone, Debug)]
 33  pub struct PoseidonSponge<E: Environment, const RATE: usize, const CAPACITY: usize> {
 34      /// Sponge Parameters
 35      parameters: Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>,
 36      /// Current sponge's state (current elements in the permutation block)
 37      state: State<E, RATE, CAPACITY>,
 38      /// Current mode (whether its absorbing or squeezing)
 39      pub(in crate::poseidon) mode: DuplexSpongeMode,
 40  }
 41  
 42  impl<E: Environment, const RATE: usize, const CAPACITY: usize> AlgebraicSponge<E, RATE, CAPACITY>
 43      for PoseidonSponge<E, RATE, CAPACITY>
 44  {
 45      type Parameters = Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>;
 46  
 47      fn new(parameters: &Self::Parameters) -> Self {
 48          Self {
 49              parameters: parameters.clone(),
 50              state: State::default(),
 51              mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 },
 52          }
 53      }
 54  
 55      fn absorb(&mut self, input: &[Field<E>]) {
 56          if !input.is_empty() {
 57              match self.mode {
 58                  DuplexSpongeMode::Absorbing { mut next_absorb_index } => {
 59                      if next_absorb_index == RATE {
 60                          self.permute();
 61                          next_absorb_index = 0;
 62                      }
 63                      self.absorb_internal(next_absorb_index, input);
 64                  }
 65                  DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => {
 66                      self.permute();
 67                      self.absorb_internal(0, input);
 68                  }
 69              }
 70          }
 71      }
 72  
 73      fn squeeze(&mut self, num_elements: u16) -> SmallVec<[Field<E>; 10]> {
 74          if num_elements == 0 {
 75              return SmallVec::new();
 76          }
 77          let mut output = if num_elements <= 10 {
 78              smallvec::smallvec_inline![Field::<E>::zero(); 10]
 79          } else {
 80              smallvec::smallvec![Field::<E>::zero(); num_elements as usize]
 81          };
 82  
 83          match self.mode {
 84              DuplexSpongeMode::Absorbing { next_absorb_index: _ } => {
 85                  self.permute();
 86                  self.squeeze_internal(0, &mut output[..num_elements as usize]);
 87              }
 88              DuplexSpongeMode::Squeezing { mut next_squeeze_index } => {
 89                  if next_squeeze_index == RATE {
 90                      self.permute();
 91                      next_squeeze_index = 0;
 92                  }
 93                  self.squeeze_internal(next_squeeze_index, &mut output[..num_elements as usize]);
 94              }
 95          }
 96  
 97          output.truncate(num_elements as usize);
 98          output
 99      }
100  }
101  
102  impl<E: Environment, const RATE: usize, const CAPACITY: usize> PoseidonSponge<E, RATE, CAPACITY> {
103      #[inline]
104      fn apply_ark(&mut self, round_number: usize) {
105          for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) {
106              *state_elem += Field::<E>::new(*ark_elem);
107          }
108      }
109  
110      #[inline]
111      fn apply_s_box(&mut self, is_full_round: bool) {
112          // Full rounds apply the S Box (x^alpha) to every element of state
113          if is_full_round {
114              for elem in self.state.iter_mut() {
115                  let e = elem.deref_mut();
116                  *e = e.pow([self.parameters.alpha]);
117              }
118          }
119          // Partial rounds apply the S Box (x^alpha) to just the first element of state
120          else {
121              let e = self.state[0].deref_mut();
122              *e = e.pow([self.parameters.alpha]);
123          }
124      }
125  
126      #[inline]
127      fn apply_mds(&mut self) {
128          let mut new_state = State::default();
129          new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| {
130              *new_elem = Field::new(E::Field::sum_of_products(self.state.iter().map(|e| e.deref()), mds_row.iter()));
131          });
132          self.state = new_state;
133      }
134  
135      #[inline]
136      fn permute(&mut self) {
137          // Determine the partial rounds range bound.
138          let partial_rounds = self.parameters.partial_rounds;
139          let full_rounds = self.parameters.full_rounds;
140          let full_rounds_over_2 = full_rounds / 2;
141          let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds);
142  
143          // Iterate through all rounds to permute.
144          for i in 0..(partial_rounds + full_rounds) {
145              let is_full_round = !partial_round_range.contains(&i);
146              self.apply_ark(i);
147              self.apply_s_box(is_full_round);
148              self.apply_mds();
149          }
150      }
151  
152      /// Absorbs everything in elements, this does not end in an absorption.
153      #[inline]
154      fn absorb_internal(&mut self, mut rate_start: usize, input: &[Field<E>]) {
155          if !input.is_empty() {
156              let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
157              let num_elements_remaining = input.len() - first_chunk_size;
158              let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
159              let rest_chunks = rest_chunk.chunks(RATE);
160              // The total number of chunks is `elements[num_elements_remaining..].len() / RATE`, plus 1
161              // for the remainder.
162              let total_num_chunks = 1 + // 1 for the first chunk
163                  // We add all the chunks that are perfectly divisible by `RATE`
164                  (num_elements_remaining / RATE) +
165                  // And also add 1 if the last chunk is non-empty
166                  // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
167                  usize::from((num_elements_remaining % RATE) != 0);
168  
169              // Absorb the input elements, `RATE` elements at a time, except for the first chunk, which
170              // is of size `RATE - rate_start`.
171              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
172                  for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state_mut()[rate_start..]) {
173                      *state_elem += element;
174                  }
175                  // Are we in the last chunk?
176                  // If so, let's wrap up.
177                  if i == total_num_chunks - 1 {
178                      self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
179                      return;
180                  } else {
181                      self.permute();
182                  }
183                  rate_start = 0;
184              }
185          }
186      }
187  
188      /// Squeeze |output| many elements. This does not end in a squeeze
189      #[inline]
190      fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [Field<E>]) {
191          let output_size = output.len();
192          if output_size != 0 {
193              let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
194              let num_output_remaining = output.len() - first_chunk_size;
195              let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
196              assert_eq!(rest_chunk.len(), num_output_remaining);
197              let rest_chunks = rest_chunk.chunks_mut(RATE);
198              // The total number of chunks is `output[num_output_remaining..].len() / RATE`, plus 1
199              // for the remainder.
200              let total_num_chunks = 1 + // 1 for the first chunk
201                  // We add all the chunks that are perfectly divisible by `RATE`
202                  (num_output_remaining / RATE) +
203                  // And also add 1 if the last chunk is non-empty
204                  // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
205                  usize::from((num_output_remaining % RATE) != 0);
206  
207              // Absorb the input output, `RATE` output at a time, except for the first chunk, which
208              // is of size `RATE - rate_start`.
209              for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
210                  let range = rate_start..(rate_start + chunk.len());
211                  debug_assert_eq!(
212                      chunk.len(),
213                      self.state.rate_state(range.clone()).len(),
214                      "Failed to squeeze {output_size} at rate {RATE} & rate_start {rate_start}"
215                  );
216                  chunk.copy_from_slice(self.state.rate_state(range));
217                  // Are we in the last chunk?
218                  // If so, let's wrap up.
219                  if i == total_num_chunks - 1 {
220                      self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
221                      return;
222                  } else {
223                      self.permute();
224                  }
225                  rate_start = 0;
226              }
227          }
228      }
229  }