// poseidon.rs
// Copyright (c) 2025-2026 ACDC Network
// This file is part of the alphavm library.
//
// Alpha Chain | Delta Chain Protocol
// International Monetary Graphite.
//
// Derived from Aleo (https://aleo.org) and ProvableHQ (https://provable.com).
// They built world-class ZK infrastructure. We installed the EASY button.
// Their cryptography: elegant. Our modifications: bureaucracy-compatible.
// Original brilliance: theirs. Robert's Rules: ours. Bugs: definitely ours.
//
// Original Aleo/ProvableHQ code subject to Apache 2.0 https://www.apache.org/licenses/LICENSE-2.0
// All modifications and new work: CC0 1.0 Universal Public Domain Dedication.
// No rights reserved. No permission required. No warranty. No refunds.
//
// https://creativecommons.org/publicdomain/zero/1.0/
// SPDX-License-Identifier: CC0-1.0

use crate::{nonnative_params::*, AlgebraicSponge, DuplexSpongeMode};
use alphavm_fields::{FieldParameters, PoseidonParameters, PrimeField, ToConstraintField};
use alphavm_utilities::{BigInteger, FromBits, ToBits};

use smallvec::SmallVec;
use std::{
    iter::Peekable,
    ops::{Index, IndexMut},
    sync::Arc,
};

/// The full Poseidon permutation state: `CAPACITY` hidden elements followed by
/// `RATE` elements that absorb input and emit output.
#[derive(Copy, Clone, Debug)]
pub struct State<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    // Hidden portion of the state; never directly written by callers.
    capacity_state: [F; CAPACITY],
    // Public portion of the state; input is added here and output is read from here.
    rate_state: [F; RATE],
}

impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Default for State<F, RATE, CAPACITY> {
    // The default state is all zeros in both portions.
    fn default() -> Self {
        Self { capacity_state: [F::zero(); CAPACITY], rate_state: [F::zero(); RATE] }
    }
}

impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> State<F, RATE, CAPACITY> {
    /// Returns an immutable iterator over the state.
    /// Capacity elements are yielded first, then rate elements.
    pub fn iter(&self) -> impl Iterator<Item = &F> + Clone {
        self.capacity_state.iter().chain(self.rate_state.iter())
    }

    /// Returns a mutable iterator over the state.
49 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut F> { 50 self.capacity_state.iter_mut().chain(self.rate_state.iter_mut()) 51 } 52 } 53 54 impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> Index<usize> for State<F, RATE, CAPACITY> { 55 type Output = F; 56 57 fn index(&self, index: usize) -> &Self::Output { 58 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY); 59 if index < CAPACITY { 60 &self.capacity_state[index] 61 } else { 62 &self.rate_state[index - CAPACITY] 63 } 64 } 65 } 66 67 impl<F: PrimeField, const RATE: usize, const CAPACITY: usize> IndexMut<usize> for State<F, RATE, CAPACITY> { 68 fn index_mut(&mut self, index: usize) -> &mut Self::Output { 69 assert!(index < RATE + CAPACITY, "Index out of bounds: index is {} but length is {}", index, RATE + CAPACITY); 70 if index < CAPACITY { 71 &mut self.capacity_state[index] 72 } else { 73 &mut self.rate_state[index - CAPACITY] 74 } 75 } 76 } 77 78 #[derive(Clone, Debug, PartialEq, Eq)] 79 pub struct Poseidon<F: PrimeField, const RATE: usize> { 80 parameters: Arc<PoseidonParameters<F, RATE, 1>>, 81 } 82 83 impl<F: PrimeField, const RATE: usize> Poseidon<F, RATE> { 84 /// Initializes a new instance of the cryptographic hash function. 85 pub fn setup() -> Self { 86 Self { parameters: Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) } 87 } 88 89 /// Evaluate the cryptographic hash function over a list of field elements 90 /// as input. 91 pub fn evaluate(&self, input: &[F]) -> F { 92 self.evaluate_many(input, 1)[0] 93 } 94 95 /// Evaluate the cryptographic hash function over a list of field elements 96 /// as input, and returns the specified number of field elements as 97 /// output. 
98 pub fn evaluate_many(&self, input: &[F], num_outputs: usize) -> Vec<F> { 99 let mut sponge = PoseidonSponge::<F, RATE, 1>::new_with_parameters(&self.parameters); 100 sponge.absorb_native_field_elements(input); 101 sponge.squeeze_native_field_elements(num_outputs).to_vec() 102 } 103 104 /// Evaluate the cryptographic hash function over a non-fixed-length vector, 105 /// in which the length also needs to be hashed. 106 pub fn evaluate_with_len(&self, input: &[F]) -> F { 107 self.evaluate(&[vec![F::from(input.len() as u128)], input.to_vec()].concat()) 108 } 109 110 pub fn parameters(&self) -> &Arc<PoseidonParameters<F, RATE, 1>> { 111 &self.parameters 112 } 113 } 114 115 /// A duplex sponge based using the Poseidon permutation. 116 /// 117 /// This implementation of Poseidon is entirely from Fractal's implementation in 118 /// [COS20][cos] with small syntax changes. 119 /// 120 /// [cos]: https://eprint.iacr.org/2019/1076 121 #[derive(Clone, Debug)] 122 pub struct PoseidonSponge<F: PrimeField, const RATE: usize, const CAPACITY: usize> { 123 /// Sponge Parameters 124 parameters: Arc<PoseidonParameters<F, RATE, CAPACITY>>, 125 /// Current sponge's state (current elements in the permutation block) 126 state: State<F, RATE, CAPACITY>, 127 /// Current mode (whether its absorbing or squeezing) 128 pub mode: DuplexSpongeMode, 129 /// A persistent lookup table used when compressing elements. 
130 adjustment_factor_lookup_table: Arc<[F]>, 131 } 132 133 impl<F: PrimeField, const RATE: usize> AlgebraicSponge<F, RATE> for PoseidonSponge<F, RATE, 1> { 134 type Parameters = Arc<PoseidonParameters<F, RATE, 1>>; 135 136 fn sample_parameters() -> Self::Parameters { 137 Arc::new(F::default_poseidon_parameters::<RATE>().unwrap()) 138 } 139 140 fn new_with_parameters(parameters: &Self::Parameters) -> Self { 141 Self { 142 parameters: parameters.clone(), 143 state: State::default(), 144 mode: DuplexSpongeMode::Absorbing { next_absorb_index: 0 }, 145 adjustment_factor_lookup_table: { 146 let capacity = F::size_in_bits() - 1; 147 let mut table = Vec::<F>::with_capacity(capacity); 148 149 let mut cur = F::one(); 150 for _ in 0..capacity { 151 table.push(cur); 152 cur.double_in_place(); 153 } 154 155 table.into() 156 }, 157 } 158 } 159 160 /// Takes in field elements. 161 fn absorb_native_field_elements<T: ToConstraintField<F>>(&mut self, elements: &[T]) { 162 let input = elements.iter().flat_map(|e| e.to_field_elements().unwrap()).collect::<Vec<_>>(); 163 if !input.is_empty() { 164 match self.mode { 165 DuplexSpongeMode::Absorbing { mut next_absorb_index } => { 166 if next_absorb_index == RATE { 167 self.permute(); 168 next_absorb_index = 0; 169 } 170 self.absorb_internal(next_absorb_index, &input); 171 } 172 DuplexSpongeMode::Squeezing { next_squeeze_index: _ } => { 173 self.permute(); 174 self.absorb_internal(0, &input); 175 } 176 } 177 } 178 } 179 180 /// Takes in field elements. 
181 fn absorb_nonnative_field_elements<Target: PrimeField>(&mut self, elements: impl IntoIterator<Item = Target>) { 182 Self::push_elements_to_sponge(self, elements, OptimizationType::Weight); 183 } 184 185 fn squeeze_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> { 186 self.get_fe(num, false) 187 } 188 189 fn squeeze_native_field_elements(&mut self, num_elements: usize) -> SmallVec<[F; 10]> { 190 if num_elements == 0 { 191 return SmallVec::<[F; 10]>::new(); 192 } 193 let mut output = if num_elements <= 10 { 194 smallvec::smallvec_inline![F::zero(); 10] 195 } else { 196 smallvec::smallvec![F::zero(); num_elements] 197 }; 198 199 match self.mode { 200 DuplexSpongeMode::Absorbing { next_absorb_index: _ } => { 201 self.permute(); 202 self.squeeze_internal(0, &mut output[..num_elements]); 203 } 204 DuplexSpongeMode::Squeezing { mut next_squeeze_index } => { 205 if next_squeeze_index == RATE { 206 self.permute(); 207 next_squeeze_index = 0; 208 } 209 self.squeeze_internal(next_squeeze_index, &mut output[..num_elements]); 210 } 211 } 212 213 output.truncate(num_elements); 214 output 215 } 216 217 /// Takes out field elements of 168 bits. 
218 fn squeeze_short_nonnative_field_elements<Target: PrimeField>(&mut self, num: usize) -> SmallVec<[Target; 10]> { 219 self.get_fe(num, true) 220 } 221 } 222 223 impl<F: PrimeField, const RATE: usize> PoseidonSponge<F, RATE, 1> { 224 #[inline] 225 fn apply_ark(&mut self, round_number: usize) { 226 for (state_elem, ark_elem) in self.state.iter_mut().zip(&self.parameters.ark[round_number]) { 227 *state_elem += ark_elem; 228 } 229 } 230 231 #[inline] 232 fn apply_s_box(&mut self, is_full_round: bool) { 233 if is_full_round { 234 // Full rounds apply the S Box (x^alpha) to every element of state 235 for elem in self.state.iter_mut() { 236 *elem = elem.pow([self.parameters.alpha]); 237 } 238 } else { 239 // Partial rounds apply the S Box (x^alpha) to just the first element of state 240 self.state[0] = self.state[0].pow([self.parameters.alpha]); 241 } 242 } 243 244 #[inline] 245 fn apply_mds(&mut self) { 246 let mut new_state = State::default(); 247 new_state.iter_mut().zip(&self.parameters.mds).for_each(|(new_elem, mds_row)| { 248 *new_elem = F::sum_of_products(self.state.iter(), mds_row.iter()); 249 }); 250 self.state = new_state; 251 } 252 253 #[inline] 254 fn permute(&mut self) { 255 // Determine the partial rounds range bound. 256 let partial_rounds = self.parameters.partial_rounds; 257 let full_rounds = self.parameters.full_rounds; 258 let full_rounds_over_2 = full_rounds / 2; 259 let partial_round_range = full_rounds_over_2..(full_rounds_over_2 + partial_rounds); 260 261 // Iterate through all rounds to permute. 262 for i in 0..(partial_rounds + full_rounds) { 263 let is_full_round = !partial_round_range.contains(&i); 264 self.apply_ark(i); 265 self.apply_s_box(is_full_round); 266 self.apply_mds(); 267 } 268 } 269 270 /// Absorbs everything in elements, this does not end in an absorption. 
    #[inline]
    fn absorb_internal(&mut self, mut rate_start: usize, input: &[F]) {
        if !input.is_empty() {
            // The first chunk fills the remainder of the rate state from `rate_start`.
            let first_chunk_size = std::cmp::min(RATE - rate_start, input.len());
            let num_elements_remaining = input.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = input.split_at(first_chunk_size);
            let rest_chunks = rest_chunk.chunks(RATE);
            // The total number of chunks is `elements[num_elements_remaining..].len() /
            // RATE`, plus 1 for the remainder.
            let total_num_chunks = 1 + // 1 for the first chunk
                // We add all the chunks that are perfectly divisible by `RATE`
                (num_elements_remaining / RATE) +
                // And also add 1 if the last chunk is non-empty
                // (i.e. if `num_elements_remaining` is not a multiple of `RATE`)
                usize::from(!num_elements_remaining.is_multiple_of(RATE));

            // Absorb the input elements, `RATE` elements at a time, except for the first
            // chunk, which is of size `RATE - rate_start`.
            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                // Add (field addition) each input element into the rate state.
                for (element, state_elem) in chunk.iter().zip(&mut self.state.rate_state[rate_start..]) {
                    *state_elem += element;
                }
                // Are we in the last chunk?
                // If so, let's wrap up: record where the next absorption resumes,
                // deliberately WITHOUT a trailing permutation (duplex behavior).
                if i == total_num_chunks - 1 {
                    self.mode = DuplexSpongeMode::Absorbing { next_absorb_index: rate_start + chunk.len() };
                    return;
                } else {
                    self.permute();
                }
                // Every chunk after the first starts at the beginning of the rate state.
                rate_start = 0;
            }
        }
    }

    /// Squeeze |output| many elements. This does not end in a squeeze.
    #[inline]
    fn squeeze_internal(&mut self, mut rate_start: usize, output: &mut [F]) {
        let output_size = output.len();
        if output_size != 0 {
            // The first chunk reads the remainder of the rate state from `rate_start`.
            let first_chunk_size = std::cmp::min(RATE - rate_start, output.len());
            let num_output_remaining = output.len() - first_chunk_size;
            let (first_chunk, rest_chunk) = output.split_at_mut(first_chunk_size);
            assert_eq!(rest_chunk.len(), num_output_remaining);
            let rest_chunks = rest_chunk.chunks_mut(RATE);
            // The total number of chunks is `output[num_output_remaining..].len() / RATE`,
            // plus 1 for the remainder.
            let total_num_chunks = 1 + // 1 for the first chunk
                // We add all the chunks that are perfectly divisible by `RATE`
                (num_output_remaining / RATE) +
                // And also add 1 if the last chunk is non-empty
                // (i.e. if `num_output_remaining` is not a multiple of `RATE`)
                usize::from(!num_output_remaining.is_multiple_of(RATE));

            // Squeeze the output, `RATE` elements at a time, except for the first chunk,
            // which is of size `RATE - rate_start`.
            for (i, chunk) in std::iter::once(first_chunk).chain(rest_chunks).enumerate() {
                let range = rate_start..(rate_start + chunk.len());
                debug_assert_eq!(
                    chunk.len(),
                    self.state.rate_state[range.clone()].len(),
                    "failed with squeeze {output_size} at rate {RATE} and rate_start {rate_start}"
                );
                // Copy the rate state directly into the caller's output buffer.
                chunk.copy_from_slice(&self.state.rate_state[range]);
                // Are we in the last chunk?
                // If so, let's wrap up: record where the next squeeze resumes,
                // deliberately WITHOUT a trailing permutation (duplex behavior).
                if i == total_num_chunks - 1 {
                    self.mode = DuplexSpongeMode::Squeezing { next_squeeze_index: (rate_start + chunk.len()) };
                    return;
                } else {
                    self.permute();
                }
                // Every chunk after the first starts at the beginning of the rate state.
                rate_start = 0;
            }
        }
    }

    /// Compress every two elements if possible.
    /// Provides a vector of (limb, num_of_additions), both of which are F.
    pub fn compress_elements<TargetField: PrimeField, I: Iterator<Item = (F, F)>>(
        &self,
        mut src_limbs: Peekable<I>,
        ty: OptimizationType,
    ) -> Vec<F> {
        // Maximum number of bits that safely fit in one native field element.
        let capacity = F::size_in_bits() - 1;
        let mut dest_limbs = Vec::<F>::new();

        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), ty);

        // Prepare a reusable vector to be used in overhead calculation.
        let mut num_bits = Vec::new();

        while let Some(first) = src_limbs.next() {
            let second = src_limbs.peek();

            // Worst-case bit width of each limb: the base limb width plus the
            // overhead contributed by the tracked additions (`.1` component).
            let first_max_bits_per_limb = params.bits_per_limb + crate::overhead!(first.1 + F::one(), &mut num_bits);
            let second_max_bits_per_limb = if let Some(second) = second {
                params.bits_per_limb + crate::overhead!(second.1 + F::one(), &mut num_bits)
            } else {
                0
            };

            if let Some(second) = second {
                // If both limbs fit in one native element, pack them as
                // `first * 2^second_width + second` and consume both;
                // otherwise emit `first` on its own.
                if first_max_bits_per_limb + second_max_bits_per_limb <= capacity {
                    // 2^second_max_bits_per_limb, from the precomputed power table.
                    let adjustment_factor = &self.adjustment_factor_lookup_table[second_max_bits_per_limb];

                    dest_limbs.push(first.0 * adjustment_factor + second.0);
                    src_limbs.next();
                } else {
                    dest_limbs.push(first.0);
                }
            } else {
                // Odd element out: emit it uncompressed.
                dest_limbs.push(first.0);
            }
        }

        dest_limbs
    }

    /// Convert a `TargetField` element into limbs (not constraints).
    /// This is an internal function that would be reused by a number of other
    /// functions.
    pub fn get_limbs_representations<TargetField: PrimeField>(
        elem: &TargetField,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        Self::get_limbs_representations_from_big_integer::<TargetField>(&elem.to_bigint(), optimization_type)
    }

    /// Obtain the limbs directly from a big int
    pub fn get_limbs_representations_from_big_integer<TargetField: PrimeField>(
        elem: &<TargetField as PrimeField>::BigInteger,
        optimization_type: OptimizationType,
    ) -> SmallVec<[F; 10]> {
        let params = get_params(TargetField::size_in_bits(), F::size_in_bits(), optimization_type);

        // Prepare a reusable vector for the BE bits.
        let mut cur_bits = Vec::new();
        // Push the lower limbs first
        let mut limbs: SmallVec<[F; 10]> = SmallVec::new();
        let mut cur = *elem;
        for _ in 0..params.num_limbs {
            cur.write_bits_be(&mut cur_bits); // `write_bits_be` is big endian
            let cur_mod_r =
                <F as PrimeField>::BigInteger::from_bits_be(&cur_bits[cur_bits.len() - params.bits_per_limb..])
                    .unwrap(); // therefore, the lowest `bits_per_non_top_limb` bits is what we want.
            limbs.push(F::from_bigint(cur_mod_r).unwrap());
            // Shift right to expose the next limb.
            cur.divn(params.bits_per_limb as u32);
            // Clear the vector after every iteration so its allocation can be reused.
            cur_bits.clear();
        }

        // then we reverse, so that the limbs are ``big limb first''
        limbs.reverse();

        limbs
    }

    /// Push elements to sponge, treated in the non-native field
    /// representations.
    pub fn push_elements_to_sponge<TargetField: PrimeField>(
        &mut self,
        src: impl IntoIterator<Item = TargetField>,
        ty: OptimizationType,
    ) {
        // Decompose every element into native-field limbs, tagging each with
        // a "number of additions" of one.
        let src_limbs = src
            .into_iter()
            .flat_map(|elem| {
                let limbs = Self::get_limbs_representations(&elem, ty);
                limbs.into_iter().map(|limb| (limb, F::one()))
                // specifically set to one, since most gadgets in the constraint
                // world would not have zero noise (due to the relatively weak
                // normal form testing in `alloc`)
            })
            .peekable();

        // Pack limbs pairwise where they fit, then absorb natively.
        let dest_limbs = self.compress_elements::<TargetField, _>(src_limbs, ty);
        self.absorb_native_field_elements(&dest_limbs);
    }

    /// Obtain random bits from the hashchain.
    /// Not guaranteed to be uniformly distributed; should only be used in
    /// certain situations.
454 pub fn get_bits(&mut self, num_bits: usize) -> Vec<bool> { 455 let bits_per_element = F::size_in_bits() - 1; 456 let num_elements = num_bits.div_ceil(bits_per_element); 457 458 let src_elements = self.squeeze_native_field_elements(num_elements); 459 let mut dest_bits = Vec::<bool>::with_capacity(num_elements * bits_per_element); 460 461 let skip = (F::Parameters::REPR_SHAVE_BITS + 1) as usize; 462 for elem in src_elements.iter() { 463 // discard the highest bit 464 let elem_bits = elem.to_bigint().to_bits_be(); 465 dest_bits.extend_from_slice(&elem_bits[skip..]); 466 } 467 dest_bits.truncate(num_bits); 468 469 dest_bits 470 } 471 472 /// obtain random field elements from hashchain. 473 /// not guaranteed to be uniformly distributed, should only be used in 474 /// certain situations. 475 pub fn get_fe<TargetField: PrimeField>( 476 &mut self, 477 num_elements: usize, 478 outputs_short_elements: bool, 479 ) -> SmallVec<[TargetField; 10]> { 480 let num_bits_per_nonnative = if outputs_short_elements { 481 168 482 } else { 483 TargetField::size_in_bits() - 1 // also omit the highest bit 484 }; 485 let bits = self.get_bits(num_bits_per_nonnative * num_elements); 486 487 let mut lookup_table = Vec::<TargetField>::with_capacity(num_bits_per_nonnative); 488 let mut cur = TargetField::one(); 489 for _ in 0..num_bits_per_nonnative { 490 lookup_table.push(cur); 491 cur.double_in_place(); 492 } 493 494 let dest_elements = bits 495 .chunks_exact(num_bits_per_nonnative) 496 .map(|per_nonnative_bits| { 497 // technically, this can be done via BigInteger::from_bits; here, we use this 498 // method for consistency with the gadget counterpart 499 let mut res = TargetField::zero(); 500 501 for (i, bit) in per_nonnative_bits.iter().rev().enumerate() { 502 if *bit { 503 res += &lookup_table[i]; 504 } 505 } 506 res 507 }) 508 .collect::<SmallVec<_>>(); 509 debug_assert_eq!(dest_elements.len(), num_elements); 510 511 dest_elements 512 } 513 }