sponge.rs
1 // Copyright (c) 2019-2025 Alpha-Delta Network Inc. 2 // This file is part of the deltavm library. 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at: 7 8 // http://www.apache.org/licenses/LICENSE-2.0 9 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 use crate::poseidon::{ 17 State, 18 helpers::{AlgebraicSponge, DuplexSpongeMode}, 19 }; 20 use deltavm_console_types::{Field, prelude::*}; 21 use deltavm_fields::PoseidonParameters; 22 23 use smallvec::SmallVec; 24 use std::{ops::DerefMut, sync::Arc}; 25 26 /// A duplex sponge based using the Poseidon permutation. 27 /// 28 /// This implementation of Poseidon is entirely from Fractal's implementation in [COS20][cos] 29 /// with small syntax changes. 30 /// 31 /// [cos]: https://eprint.iacr.org/2019/1076 32 #[derive(Clone, Debug)] 33 pub struct PoseidonSponge<E: Environment, const RATE: usize, const CAPACITY: usize> { 34 /// Sponge Parameters 35 parameters: Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>, 36 /// Current sponge's state (current elements in the permutation block) 37 state: State<E, RATE, CAPACITY>, 38 /// Current mode (whether its absorbing or squeezing) 39 pub(in crate::poseidon) mode: DuplexSpongeMode, 40 } 41 42 impl<E: Environment, const RATE: usize, const CAPACITY: usize> AlgebraicSponge<E, RATE, CAPACITY> 43 for PoseidonSponge<E, RATE, CAPACITY> 44 { 45 type Parameters = Arc<PoseidonParameters<E::Field, RATE, CAPACITY>>; 46 47 fn new(parameters: &Self::Parameters) -> Self { 48 Self { 49 parameters: parameters.clone(), 50 state: State::default(), 51 mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 }, 52 } 53 } 54 55 fn absorb(&mut self, input: &[Field<E>]) { 56 if !input.is_empty() { 57 match self.mode { 58 DuplexSpongeMode::Absorbing { mut next_absorb_index } => { 59 if next_absorb_index == RATE { 60 self.permute(); 61 next_absorb_index = 0; 62 } 63 self.absorb_internal(next_absorb_index, input); 64 } 65 DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => { 66 self.permute(); 67 self.absorb_internal(0, input); 68 } 69 } 70 } 71 } 72 73 fn squeeze(&mut self, num_elements: u16) -> SmallVec<[Field<E>; 10]> { 74 if num_elements == 0 { 75 return SmallVec::new(); 76 } 77 let mut output = if num_elements <= 10 { 78 smallvec::smallvec_inline![Field::<E>::zero(); 10] 79 } else { 80 smallvec::smallvec![Field::<E>::zero(); num_elements as usize] 81 }; 82 83 match self.mode { 84 DuplexSpongeMode::Absorbing { next_absorb_index: _ } => { 85 self.permute(); 86 self.squeeze_internal(0, &mut output[..num_elements as usize]); 87 } 88 DuplexSpongeMode::Squeezing { mut next_squeeze_index } => { 89 if next_squeeze_index == RATE { 90 self.permute(); 91 next_squeeze_index = 0; 92 } 93 self.squeeze_internal(next_squeeze_index, &mut output[..num_elements as usize]); 94 } 95 } 96 97 output.truncate(num_elements as usize); 98 output 99 } 100 } 101 102 impl<E: Environment, const RATE: usize, const CAPACITY: usize> PoseidonSponge<E, RATE, CAPACITY> { 103 #[inline] 104 fn apply_ark(&mut self, round_number: usize) { 105 for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) { 106 *state_elem += Field::<E>::new(*ark_elem); 107 } 108 } 109 110 #[inline] 111 fn apply_s_box(&mut self, is_full_round: bool) { 112 // Full rounds apply the S Box (x^alpha) to every element of state 113 if is_full_round { 114 for elem in self.state.iter_mut() { 115 let e = elem.deref_mut(); 116 *e = e.pow([self.parameters.alpha]); 117 } 118 } 119 // Partial rounds apply the S Box (x^alpha) to just the first element of state 120 else { 121 let e = self.state[0].deref_mut(); 122 *e = e.pow([self.parameters.alpha]); 123 } 124 } 125 126 #[inline] 127 fn apply_mds(&mut self) { 128 let mut new_state = State::default(); 129 new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| { 130 *new_elem = Field::new(E::Field::sum_of_products(self.state.iter().map(|e| e.deref()), mds_row.iter())); 131 }); 132 self.state = new_state; 133 } 134 135 #[inline] 136 fn permute(&mut self) { 137 // Determine the partial rounds range bound. 138 let partial_rounds = self.parameters.partial_rounds; 139 let full_rounds = self.parameters.full_rounds; 140 let full_rounds_over_2 = full_rounds / 2; 141 let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds); 142 143 // Iterate through all rounds to permute. 144 for i in 0..(partial_rounds + full_rounds) { 145 let is_full_round = !partial_round_range.contains(&i); 146 self.apply_ark(i); 147 self.apply_s_box(is_full_round); 148 self.apply_mds(); 149 } 150 } 151 152 /// Absorbs everything in elements, this does not end in an absorption. 153 #[inline] 154 fn absorb_internal(&mut self, mut rate_start: usize, input: &[Field<E>]) { 155 if !input.is_empty() { 156 let first_chunk_size = std::cmp::min(RATE - rate_start, input.len()); 157 let num_elements_remaining = input.len() - first_chunk_size; 158 let (first_chunk, rest_chunk) = input.split_at(first_chunk_size); 159 let rest_chunks = rest_chunk.chunks(RATE); 160 // The total number of chunks is `elements[num_elements_remaining..].len() / RATE`, plus 1 161 // for the remainder. 162 let total_num_chunks = 1 + // 1 for the first chunk 163 // We add all the chunks that are perfectly divisible by `RATE` 164 (num_elements_remaining / RATE) + 165 // And also add 1 if the last chunk is non-empty 166 // (i.e. if `num_elements_remaining` is not a multiple of `RATE`) 167 usize::from((num_elements_remaining % RATE) != 0); 168 169 // Absorb the input elements, `RATE` elements at a time, except for the first chunk, which 170 // is of size `RATE - rate_start`. 171 for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() { 172 for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state_mut()[rate_start..]) { 173 *state_elem += element; 174 } 175 // Are we in the last chunk? 176 // If so, let's wrap up. 177 if i == total_num_chunks - 1 { 178 self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() }; 179 return; 180 } else { 181 self.permute(); 182 } 183 rate_start = 0; 184 } 185 } 186 } 187 188 /// Squeeze |output| many elements. This does not end in a squeeze 189 #[inline] 190 fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [Field<E>]) { 191 let output_size = output.len(); 192 if output_size != 0 { 193 let first_chunk_size = std::cmp::min(RATE - rate_start, output.len()); 194 let num_output_remaining = output.len() - first_chunk_size; 195 let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size); 196 assert_eq!(rest_chunk.len(), num_output_remaining); 197 let rest_chunks = rest_chunk.chunks_mut(RATE); 198 // The total number of chunks is `output[num_output_remaining..].len() / RATE`, plus 1 199 // for the remainder. 200 let total_num_chunks = 1 + // 1 for the first chunk 201 // We add all the chunks that are perfectly divisible by `RATE` 202 (num_output_remaining / RATE) + 203 // And also add 1 if the last chunk is non-empty 204 // (i.e. if `num_output_remaining` is not a multiple of `RATE`) 205 usize::from((num_output_remaining % RATE) != 0); 206 207 // Absorb the input output, `RATE` output at a time, except for the first chunk, which 208 // is of size `RATE - rate_start`. 209 for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() { 210 let range = rate_start..(rate_start + chunk.len()); 211 debug_assert_eq!( 212 chunk.len(), 213 self.state.rate_state(range.clone()).len(), 214 "Failed to squeeze {output_size} at rate {RATE} & rate_start {rate_start}" 215 ); 216 chunk.copy_from_slice(self.state.rate_state(range)); 217 // Are we in the last chunk? 218 // If so, let's wrap up. 219 if i == total_num_chunks - 1 { 220 self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) }; 221 return; 222 } else { 223 self.permute(); 224 } 225 rate_start = 0; 226 } 227 } 228 } 229 }