poseidon.rs
1 // Copyright (c) 2019-2025 Alpha-Delta Network Inc. 2 // This file is part of the deltavm library. 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at: 7 8 // http://www.apache.org/licenses/LICENSE-2.0 9 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 use crate::{AlgebraicSponge, DuplexSpongeMode, nonnative_params::*}; 17 use deltavm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField}; 18 use deltavm_utilities::{BigInteger, FromBits, ToBits}; 19 20 use smallvec::SmallVec; 21 use std::{ 22 iter::Peekable, 23 ops::{Index, IndexMut}, 24 sync::Arc, 25 }; 26 27 #[derive(Copy, Clone, Debug)] 28 pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> { 29 capacity_state: [F; CAPACITY], 30 rate_state: [F; RATE], 31 } 32 33 impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> { 34 fn default() -> Self { 35 Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] } 36 } 37 } 38 39 impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> { 40 /// Returns an immutable iterator over the state. 41 pub fn iter(&self) -> impl Iterator<Item = &F> + Clone { 42 self.capacity_state.iter().chain(self.rate_state.iter()) 43 } 44 45 /// Returns a mutable iterator over the state. 46 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> { 47 self.capacity_state.iter_mut().chain(self.rate_state.iter_mut()) 48 } 49 } 50 51 impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> { 52 type Output = F; 53 54 fn index(&self, index: usize) -> &Self::Output { 55 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY); 56 if index < CAPACITY { &self.capacity_state[index] } else { &self.rate_state[index - CAPACITY] } 57 } 58 } 59 60 impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> { 61 fn index_mut(&mut self, index: usize) -> &mut Self::Output { 62 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY); 63 if index < CAPACITY { &mut self.capacity_state[index] } else { &mut self.rate_state[index - CAPACITY] } 64 } 65 } 66 67 #[derive(Clone, Debug, PartialEq, Eq)] 68 pub struct Poseidon<F: PrimeField, const RATE: usize> { 69 parameters: Arc<PoseidonParameters<F, RATE, 1>>, 70 } 71 72 impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> { 73 /// Initializes a new instance of the cryptographic hash function. 74 pub fn setup() -> Self { 75 Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) } 76 } 77 78 /// Evaluate the cryptographic hash function over a list of field elements 79 /// as input. 80 pub fn evaluate(&self, input: &[F]) -> F { 81 self.evaluate_many(input, 1)[0] 82 } 83 84 /// Evaluate the cryptographic hash function over a list of field elements 85 /// as input, and returns the specified number of field elements as 86 /// output. 87 pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> { 88 let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters); 89 sponge.absorb_native_field_elements(input); 90 sponge.squeeze_native_field_elements(num_outputs).to_vec() 91 } 92 93 /// Evaluate the cryptographic hash function over a non-fixed-length vector, 94 /// in which the length also needs to be hashed. 95 pub fn evaluate_with_len(&self, input: &[F]) -> F { 96 self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat()) 97 } 98 99 pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> { 100 &self.parameters 101 } 102 } 103 104 /// A duplex sponge based using the Poseidon permutation. 105 /// 106 /// This implementation of Poseidon is entirely from Fractal's implementation in 107 /// [COS20][cos] with small syntax changes. 108 /// 109 /// [cos]: https://eprint.iacr.org/2019/1076 110 #[derive(Clone, Debug)] 111 pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> { 112 /// Sponge Parameters 113 parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>, 114 /// Current sponge's state (current elements in the permutation block) 115 state: State<F, RATE, CAPACITY>, 116 /// Current mode (whether its absorbing or squeezing) 117 pub mode: DuplexSpongeMode, 118 /// A persistent lookup table used when compressing elements. 119 adjustment_factor_lookup_table: Arc<[F]>, 120 } 121 122 impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> { 123 type Parameters = Arc<PoseidonParameters<F, RATE, 1>>; 124 125 fn sample_parameters() -> Self::Parameters { 126 Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) 127 } 128 129 fn new_with_parameters(parameters: &Self::Parameters) -> Self { 130 Self { 131 parameters: parameters.clone(), 132 state: State::default(), 133 mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 }, 134 adjustment_factor_lookup_table: { 135 let capacity = F::size_in_bits() - 1; 136 let mut table = Vec::<F>::with_capacity(capacity); 137 138 let mut cur = F::one(); 139 for _ in 0..capacity { 140 table.push(cur); 141 cur.double_in_place(); 142 } 143 144 table.into() 145 }, 146 } 147 } 148 149 /// Takes in field elements. 150 fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) { 151 let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>(); 152 if !input.is_empty() { 153 match self.mode { 154 DuplexSpongeMode::Absorbing { mut next_absorb_index } => { 155 if next_absorb_index == RATE { 156 self.permute(); 157 next_absorb_index = 0; 158 } 159 self.absorb_internal(next_absorb_index, &input); 160 } 161 DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => { 162 self.permute(); 163 self.absorb_internal(0, &input); 164 } 165 } 166 } 167 } 168 169 /// Takes in field elements. 170 fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) { 171 Self::push_elements_to_sponge(self, elements, OptimizationType::Weight); 172 } 173 174 fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> { 175 self.get_fe(num, false) 176 } 177 178 fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> { 179 if num_elements == 0 { 180 return SmallVec::<[F; 10]>::new(); 181 } 182 let mut output = if num_elements <= 10 { 183 smallvec::smallvec_inline![F::zero(); 10] 184 } else { 185 smallvec::smallvec![F::zero(); num_elements] 186 }; 187 188 match self.mode { 189 DuplexSpongeMode::Absorbing { next_absorb_index: _ } => { 190 self.permute(); 191 self.squeeze_internal(0, &mut output[..num_elements]); 192 } 193 DuplexSpongeMode::Squeezing { mut next_squeeze_index } => { 194 if next_squeeze_index == RATE { 195 self.permute(); 196 next_squeeze_index = 0; 197 } 198 self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]); 199 } 200 } 201 202 output.truncate(num_elements); 203 output 204 } 205 206 /// Takes out field elements of 168 bits. 207 fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> { 208 self.get_fe(num, true) 209 } 210 } 211 212 impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> { 213 #[inline] 214 fn apply_ark(&mut self, round_number: usize) { 215 for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) { 216 *state_elem += ark_elem; 217 } 218 } 219 220 #[inline] 221 fn apply_s_box(&mut self, is_full_round: bool) { 222 if is_full_round { 223 // Full rounds apply the S Box (x^alpha) to every element of state 224 for elem in self.state.iter_mut() { 225 *elem = elem.pow([self.parameters.alpha]); 226 } 227 } else { 228 // Partial rounds apply the S Box (x^alpha) to just the first element of state 229 self.state[0] = self.state[0].pow([self.parameters.alpha]); 230 } 231 } 232 233 #[inline] 234 fn apply_mds(&mut self) { 235 let mut new_state = State::default(); 236 new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| { 237 *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter()); 238 }); 239 self.state = new_state; 240 } 241 242 #[inline] 243 fn permute(&mut self) { 244 // Determine the partial rounds range bound. 245 let partial_rounds = self.parameters.partial_rounds; 246 let full_rounds = self.parameters.full_rounds; 247 let full_rounds_over_2 = full_rounds / 2; 248 let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds); 249 250 // Iterate through all rounds to permute. 251 for i in 0..(partial_rounds + full_rounds) { 252 let is_full_round = !partial_round_range.contains(&i); 253 self.apply_ark(i); 254 self.apply_s_box(is_full_round); 255 self.apply_mds(); 256 } 257 } 258 259 /// Absorbs everything in elements, this does not end in an absorption. 260 #[inline] 261 fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) { 262 if !input.is_empty() { 263 let first_chunk_size = std::cmp::min(RATE - rate_start, input.len()); 264 let num_elements_remaining = input.len() - first_chunk_size; 265 let (first_chunk, rest_chunk) = input.split_at(first_chunk_size); 266 let rest_chunks = rest_chunk.chunks(RATE); 267 // The total number of chunks is `elements[num_elements_remaining..].len() / 268 // RATE`, plus 1 for the remainder. 269 let total_num_chunks = 1 + // 1 for the first chunk 270 // We add all the chunks that are perfectly divisible by `RATE` 271 (num_elements_remaining / RATE) + 272 // And also add 1 if the last chunk is non-empty 273 // (i.e. if `num_elements_remaining` is not a multiple of `RATE`) 274 usize::from((num_elements_remaining % RATE) != 0); 275 276 // Absorb the input elements, `RATE` elements at a time, except for the first 277 // chunk, which is of size `RATE - rate_start`. 278 for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() { 279 for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) { 280 *state_elem += element; 281 } 282 // Are we in the last chunk? 283 // If so, let's wrap up. 284 if i == total_num_chunks - 1 { 285 self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() }; 286 return; 287 } else { 288 self.permute(); 289 } 290 rate_start = 0; 291 } 292 } 293 } 294 295 /// Squeeze |output| many elements. This does not end in a squeeze 296 #[inline] 297 fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) { 298 let output_size = output.len(); 299 if output_size != 0 { 300 let first_chunk_size = std::cmp::min(RATE - rate_start, output.len()); 301 let num_output_remaining = output.len() - first_chunk_size; 302 let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size); 303 assert_eq!(rest_chunk.len(), num_output_remaining); 304 let rest_chunks = rest_chunk.chunks_mut(RATE); 305 // The total number of chunks is `output[num_output_remaining..].len() / RATE`, 306 // plus 1 for the remainder. 307 let total_num_chunks = 1 + // 1 for the first chunk 308 // We add all the chunks that are perfectly divisible by `RATE` 309 (num_output_remaining / RATE) + 310 // And also add 1 if the last chunk is non-empty 311 // (i.e. if `num_output_remaining` is not a multiple of `RATE`) 312 usize::from((num_output_remaining % RATE) != 0); 313 314 // Absorb the input output, `RATE` output at a time, except for the first chunk, 315 // which is of size `RATE - rate_start`. 316 for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() { 317 let range = rate_start..(rate_start + chunk.len()); 318 debug_assert_eq!( 319 chunk.len(), 320 self.state.rate_state[range.clone()].len(), 321 "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}" 322 ); 323 chunk.copy_from_slice(&self.state.rate_state[range]); 324 // Are we in the last chunk? 325 // If so, let's wrap up. 326 if i == total_num_chunks - 1 { 327 self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) }; 328 return; 329 } else { 330 self.permute(); 331 } 332 rate_start = 0; 333 } 334 } 335 } 336 337 /// Compress every two elements if possible. 338 /// Provides a vector of (limb, num_of_additions), both of which are F. 339 pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>( 340 &self, 341 mut src_limbs: Peekable<I>, 342 ty: OptimizationType, 343 ) -> Vec<F> { 344 let capacity = F::size_in_bits() - 1; 345 let mut dest_limbs = Vec::<F>::new(); 346 347 let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty); 348 349 // Prepare a reusable vector to be used in overhead calculation. 350 let mut num_bits = Vec::new(); 351 352 while let Some(first) = src_limbs.next() { 353 let second = src_limbs.peek(); 354 355 let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits); 356 let second_max_bits_per_limb = if let Some(second) = second { 357 params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits) 358 } else { 359 0 360 }; 361 362 if let Some(second) = second { 363 if first_max_bits_per_limb + second_max_bits_per_limb <= capacity { 364 let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb]; 365 366 dest_limbs.push(first.0 * adjustment_factor + second.0); 367 src_limbs.next(); 368 } else { 369 dest_limbs.push(first.0); 370 } 371 } else { 372 dest_limbs.push(first.0); 373 } 374 } 375 376 dest_limbs 377 } 378 379 /// Convert a `TargetField` element into limbs (not constraints) 380 /// This is an internal function that would be reused by a number of other 381 /// functions 382 pub fn get_limbs_representations<TargetField: PrimeField>( 383 elem: &TargetField, 384 optimization_type: OptimizationType, 385 ) -> SmallVec<[F; 10]> { 386 Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type) 387 } 388 389 /// Obtain the limbs directly from a big int 390 pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>( 391 elem: &<TargetField as PrimeField>::BigInteger, 392 optimization_type: OptimizationType, 393 ) -> SmallVec<[F; 10]> { 394 let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type); 395 396 // Prepare a reusable vector for the BE bits. 397 let mut cur_bits = Vec::new(); 398 // Push the lower limbs first 399 let mut limbs: SmallVec<[F; 10]> = SmallVec::new(); 400 let mut cur = *elem; 401 for _ in 0..params.num_limbs { 402 cur.write_bits_be(&mut cur_bits); // `write_bits_be` is big endian 403 let cur_mod_r = 404 <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..]) 405 .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want. 406 limbs.push(F::from_bigint(cur_mod_r).unwrap()); 407 cur.divn(params.bits_per_limb as u32); 408 // Clear the vector after every iteration so its allocation can be reused. 409 cur_bits.clear(); 410 } 411 412 // then we reverse, so that the limbs are ``big limb first'' 413 limbs.reverse(); 414 415 limbs 416 } 417 418 /// Push elements to sponge, treated in the non-native field 419 /// representations. 420 pub fn push_elements_to_sponge<TargetField: PrimeField>( 421 &mut self, 422 src: impl IntoIterator<Item = TargetField>, 423 ty: OptimizationType, 424 ) { 425 let src_limbs = src 426 .into_iter() 427 .flat_map(|elem| { 428 let limbs = Self::get_limbs_representations(&elem, ty); 429 limbs.into_iter().map(|limb| (limb, F::one())) 430 // specifically set to one, since most gadgets in the constraint 431 // world would not have zero noise (due to the relatively weak 432 // normal form testing in `alloc`) 433 }) 434 .peekable(); 435 436 let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty); 437 self.absorb_native_field_elements(&dest_limbs); 438 } 439 440 /// obtain random bits from hashchain. 441 /// not guaranteed to be uniformly distributed, should only be used in 442 /// certain situations. 443 pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> { 444 let bits_per_element = F::size_in_bits() - 1; 445 let num_elements = num_bits.div_ceil(bits_per_element); 446 447 let src_elements = self.squeeze_native_field_elements(num_elements); 448 let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element); 449 450 let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize; 451 for elem in src_elements.iter() { 452 // discard the highest bit 453 let elem_bits = elem.to_bigint().to_bits_be(); 454 dest_bits.extend_from_slice(&elem_bits[skip..]); 455 } 456 dest_bits.truncate(num_bits); 457 458 dest_bits 459 } 460 461 /// obtain random field elements from hashchain. 462 /// not guaranteed to be uniformly distributed, should only be used in 463 /// certain situations. 464 pub fn get_fe<TargetField: PrimeField>( 465 &mut self, 466 num_elements: usize, 467 outputs_short_elements: bool, 468 ) -> SmallVec<[TargetField; 10]> { 469 let num_bits_per_nonnative = if outputs_short_elements { 470 168 471 } else { 472 TargetField::size_in_bits() - 1 // also omit the highest bit 473 }; 474 let bits = self.get_bits(num_bits_per_nonnative * num_elements); 475 476 let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative); 477 let mut cur = TargetField::one(); 478 for _ in 0..num_bits_per_nonnative { 479 lookup_table.push(cur); 480 cur.double_in_place(); 481 } 482 483 let dest_elements = bits 484 .chunks_exact(num_bits_per_nonnative) 485 .map(|per_nonnative_bits| { 486 // technically, this can be done via BigInteger::from_bits; here, we use this 487 // method for consistency with the gadget counterpart 488 let mut res = TargetField::zero(); 489 490 for (i, bit) in per_nonnative_bits.iter().rev().enumerate() { 491 if *bit { 492 res += &lookup_table[i]; 493 } 494 } 495 res 496 }) 497 .collect::<SmallVec<_>>(); 498 debug_assert_eq!(dest_elements.len(), num_elements); 499 500 dest_elements 501 } 502 }