/ fields / src / traits / poseidon_default.rs
poseidon_default.rs
  1  // Copyright (c) 2019-2025 Alpha-Delta Network Inc.
  2  // This file is part of the alphavm library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use crate::{PoseidonGrainLFSR, PrimeField, serial_batch_inversion_and_mul};
 17  use alphastd::{end_timer, start_timer};
 18  use itertools::Itertools;
 19  
 20  use anyhow::{Result, bail};
 21  
/// The parameters describing one Poseidon permutation over the prime field `F`.
///
/// `RATE` and `CAPACITY` are carried purely as const generics: no field of this struct
/// is sized by them at the type level, so the lengths of `ark` and `mds` are whatever
/// the code that built the value produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoseidonParameters<F: PrimeField, const RATE: usize, const CAPACITY: usize> {
    /// Number of rounds in a full-round operation.
    pub full_rounds: usize,
    /// Number of rounds in a partial-round operation.
    pub partial_rounds: usize,
    /// Exponent used in the S-boxes.
    pub alpha: u64,
    /// Additive round keys. These are added before each MDS matrix application to make it an affine shift.
    /// They are indexed by `ark[round_num][state_element_index]`.
    pub ark: Vec<Vec<F>>,
    /// Maximum Distance Separable (MDS) matrix, stored as a list of rows.
    pub mds: Vec<Vec<F>>,
}
 37  
 38  /// A field with Poseidon parameters associated
 39  pub trait PoseidonDefaultField {
 40      /// Obtain the default Poseidon parameters for this rate and for this prime field,
 41      /// with a specific optimization goal.
 42      fn default_poseidon_parameters<const RATE: usize>() -> Result<PoseidonParameters<Self, RATE, 1>>
 43      where
 44          Self: PrimeField,
 45      {
 46          /// Internal function that computes the ark and mds from the Poseidon Grain LFSR.
 47          #[allow(clippy::type_complexity)]
 48          fn find_poseidon_ark_and_mds<F: PrimeField, const RATE: usize>(
 49              full_rounds: u64,
 50              partial_rounds: u64,
 51              skip_matrices: u64,
 52          ) -> Result<(Vec<Vec<F>>, Vec<Vec<F>>)> {
 53              let lfsr_time = start_timer!(|| "LFSR Init");
 54              let mut lfsr =
 55                  PoseidonGrainLFSR::new(false, F::size_in_bits() as u64, (RATE + 1) as u64, full_rounds, partial_rounds);
 56              end_timer!(lfsr_time);
 57  
 58              let ark_time = start_timer!(|| "Constructing ARK");
 59              let mut ark = Vec::with_capacity((full_rounds + partial_rounds) as usize);
 60              for _ in 0..(full_rounds + partial_rounds) {
 61                  ark.push(lfsr.get_field_elements_rejection_sampling(RATE + 1)?);
 62              }
 63              end_timer!(ark_time);
 64  
 65              let skip_time = start_timer!(|| "Skipping matrices");
 66              for _ in 0..skip_matrices {
 67                  let _ = lfsr.get_field_elements_mod_p::<F>(2 * (RATE + 1))?;
 68              }
 69              end_timer!(skip_time);
 70  
 71              // A qualifying matrix must satisfy the following requirements:
 72              // - There is no duplication among the elements in x or y.
 73              // - There is no i and j such that x[i] + y[j] = p.
 74              // - There resultant MDS passes all three tests.
 75  
 76              let xs = lfsr.get_field_elements_mod_p::<F>(RATE + 1)?;
 77              let ys = lfsr.get_field_elements_mod_p::<F>(RATE + 1)?;
 78  
 79              let mds_time = start_timer!(|| "Construct MDS");
 80              let mut mds_flattened = vec![F::zero(); (RATE + 1) * (RATE + 1)];
 81              for (x, mds_row_i) in xs.iter().take(RATE + 1).zip_eq(mds_flattened.chunks_mut(RATE + 1)) {
 82                  for (y, e) in ys.iter().take(RATE + 1).zip_eq(mds_row_i) {
 83                      *e = *x + y;
 84                  }
 85              }
 86              serial_batch_inversion_and_mul(&mut mds_flattened, &F::one());
 87              let mds = mds_flattened.chunks(RATE + 1).map(|row| row.to_vec()).collect();
 88              end_timer!(mds_time);
 89  
 90              Ok((ark, mds))
 91          }
 92  
 93          match Self::Parameters::PARAMS_OPT_FOR_CONSTRAINTS.iter().find(|entry| entry.rate == RATE) {
 94              Some(entry) => {
 95                  let (ark, mds) = find_poseidon_ark_and_mds::<Self, RATE>(
 96                      entry.full_rounds as u64,
 97                      entry.partial_rounds as u64,
 98                      entry.skip_matrices as u64,
 99                  )?;
100                  Ok(PoseidonParameters {
101                      full_rounds: entry.full_rounds,
102                      partial_rounds: entry.partial_rounds,
103                      alpha: entry.alpha as u64,
104                      ark,
105                      mds,
106                  })
107              }
108              None => bail!("No Poseidon parameters were found for this rate"),
109          }
110      }
111  }
112  
/// A trait for the default Poseidon parameters associated with a prime field.
pub trait PoseidonDefaultParameters {
    /// An array of the parameters optimized for constraints
    /// (rate, alpha, full_rounds, partial_rounds, skip_matrices)
    /// for rate = 2, 3, 4, 5, 6, 7, 8 (one entry per rate, seven in total).
    ///
    /// Here, `skip_matrices` denotes how many matrices to skip before
    /// finding one that satisfies all the requirements.
    const PARAMS_OPT_FOR_CONSTRAINTS: [PoseidonDefaultParametersEntry; 7];
}
123  
/// One entry of the default Poseidon parameter table: the settings used for a single rate.
///
/// The type is plain data (five `usize` values), so it derives the usual value-type traits
/// to allow entries to be printed, compared, and copied freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoseidonDefaultParametersEntry {
    /// The rate (in terms of number of field elements).
    pub rate: usize,
    /// Exponent used in S-boxes.
    pub alpha: usize,
    /// Number of rounds in a full-round operation.
    pub full_rounds: usize,
    /// Number of rounds in a partial-round operation.
    pub partial_rounds: usize,
    /// Number of matrices to skip when generating parameters using the Grain LFSR.
    ///
    /// The matrices being skipped are those that do not satisfy all the desired properties.
    /// See the [reference implementation](https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/generate_parameters_grain.sage) for more detail.
    pub skip_matrices: usize,
}
140  
141  impl PoseidonDefaultParametersEntry {
142      /// Create an entry in PoseidonDefaultParameters.
143      pub const fn new(
144          rate: usize,
145          alpha: usize,
146          full_rounds: usize,
147          partial_rounds: usize,
148          skip_matrices: usize,
149      ) -> Self {
150          Self { rate, alpha, full_rounds, partial_rounds, skip_matrices }
151      }
152  }