/ fields / src / traits / poseidon_grain_lfsr.rs
poseidon_grain_lfsr.rs
  1  // Copyright (c) 2019-2025 Alpha-Delta Network Inc.
  2  // This file is part of the deltavm library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  #![allow(dead_code)]
 17  
 18  use crate::{FieldParameters, PrimeField};
 19  use deltavm_utilities::FromBits;
 20  
 21  use anyhow::{Result, bail};
 22  
 23  pub struct PoseidonGrainLFSR {
 24      pub field_size_in_bits: u64,
 25      pub state: [bool; 80],
 26      pub head: usize,
 27  }
 28  
 29  impl PoseidonGrainLFSR {
 30      pub fn new(
 31          is_sbox_an_inverse: bool,
 32          field_size_in_bits: u64,
 33          state_len: u64,
 34          num_full_rounds: u64,
 35          num_partial_rounds: u64,
 36      ) -> Self {
 37          let mut state = [false; 80];
 38  
 39          // b0, b1 describes the field
 40          state[1] = true;
 41  
 42          // b2, ..., b5 describes the S-BOX
 43          state[5] = is_sbox_an_inverse;
 44  
 45          // b6, ..., b17 are the binary representation of n (prime_num_bits)
 46          {
 47              let mut cur = field_size_in_bits;
 48              for i in (6..=17).rev() {
 49                  state[i] = cur & 1 == 1;
 50                  cur >>= 1;
 51              }
 52          }
 53  
 54          // b18, ..., b29 are the binary representation of t (state_len, rate + capacity)
 55          {
 56              let mut cur = state_len;
 57              for i in (18..=29).rev() {
 58                  state[i] = cur & 1 == 1;
 59                  cur >>= 1;
 60              }
 61          }
 62  
 63          // b30, ..., b39 are the binary representation of R_F (the number of full rounds)
 64          {
 65              let mut cur = num_full_rounds;
 66              for i in (30..=39).rev() {
 67                  state[i] = cur & 1 == 1;
 68                  cur >>= 1;
 69              }
 70          }
 71  
 72          // b40, ..., b49 are the binary representation of R_P (the number of partial rounds)
 73          {
 74              let mut cur = num_partial_rounds;
 75              for i in (40..=49).rev() {
 76                  state[i] = cur & 1 == 1;
 77                  cur >>= 1;
 78              }
 79          }
 80  
 81          // b50, ..., b79 are set to 1
 82          state[50..=79].copy_from_slice(&[true; 30]);
 83  
 84          // Initialize.
 85          let mut res = Self { field_size_in_bits, state, head: 0 };
 86          for _ in 0..160 {
 87              res.next_bit();
 88          }
 89          res
 90      }
 91  
 92      pub fn get_field_elements_rejection_sampling<F: PrimeField>(&mut self, num_elements: usize) -> Result<Vec<F>> {
 93          // Ensure the number of bits matches the modulus.
 94          if self.field_size_in_bits != F::Parameters::MODULUS_BITS as u64 {
 95              bail!("The number of bits in the field must match the modulus");
 96          }
 97  
 98          let mut output = Vec::with_capacity(num_elements);
 99          let mut bits = Vec::with_capacity(self.field_size_in_bits as usize);
100          for _ in 0..num_elements {
101              // Perform rejection sampling.
102              loop {
103                  // Obtain `n` bits and make it most-significant-bit first.
104                  bits.extend(self.get_bits(self.field_size_in_bits as usize));
105                  bits.reverse();
106                  // Construct the number.
107                  let bigint = F::BigInteger::from_bits_le(&bits)?;
108                  bits.clear();
109                  // Ensure the number is in the field.
110                  if let Some(element) = F::from_bigint(bigint) {
111                      output.push(element);
112                      break;
113                  }
114              }
115          }
116          Ok(output)
117      }
118  
119      pub fn get_field_elements_mod_p<F: PrimeField>(&mut self, num_elems: usize) -> Result<Vec<F>> {
120          // Ensure the number of bits matches the modulus.
121          let num_bits = self.field_size_in_bits;
122          if num_bits != F::Parameters::MODULUS_BITS as u64 {
123              bail!("The number of bits in the field must match the modulus");
124          }
125  
126          // Prepare reusable vectors for the intermediate bits and bytes.
127          let mut bits = Vec::with_capacity(num_bits as usize);
128          let mut bytes = Vec::with_capacity((num_bits as usize).div_ceil(8));
129  
130          let mut output = Vec::with_capacity(num_elems);
131          for _ in 0..num_elems {
132              // Obtain `n` bits and make it most-significant-bit first.
133              let bits_iter = self.get_bits(num_bits as usize);
134              for bit in bits_iter {
135                  bits.push(bit);
136              }
137              bits.reverse();
138  
139              for byte in bits
140                  .chunks(8)
141                  .map(|chunk| {
142                      let mut sum = chunk[0] as u8;
143                      let mut cur = 1;
144                      for i in chunk.iter().skip(1) {
145                          cur *= 2;
146                          sum += cur * (*i as u8);
147                      }
148                      sum
149                  })
150                  .rev()
151              {
152                  bytes.push(byte);
153              }
154  
155              output.push(F::from_bytes_be_mod_order(&bytes));
156  
157              // Clear the vectors of bits and bytes so they can be reused
158              // in the next iteration.
159              bits.clear();
160              bytes.clear();
161          }
162          Ok(output)
163      }
164  }
165  
166  impl PoseidonGrainLFSR {
167      #[inline]
168      fn get_bits(&mut self, num_bits: usize) -> LFSRIter<'_> {
169          LFSRIter { lfsr: self, num_bits, current_bit: 0 }
170      }
171  
172      #[inline]
173      fn next_bit(&mut self) -> bool {
174          let next_bit = self.state[(self.head + 62) % 80]
175              ^ self.state[(self.head + 51) % 80]
176              ^ self.state[(self.head + 38) % 80]
177              ^ self.state[(self.head + 23) % 80]
178              ^ self.state[(self.head + 13) % 80]
179              ^ self.state[self.head];
180          self.state[self.head] = next_bit;
181          self.head += 1;
182          self.head %= 80;
183  
184          next_bit
185      }
186  }
187  
188  pub struct LFSRIter<'a> {
189      lfsr: &'a mut PoseidonGrainLFSR,
190      num_bits: usize,
191      current_bit: usize,
192  }
193  
194  impl Iterator for LFSRIter<'_> {
195      type Item = bool;
196  
197      fn next(&mut self) -> Option<Self::Item> {
198          if self.current_bit < self.num_bits {
199              // Obtain the first bit
200              let mut new_bit = self.lfsr.next_bit();
201  
202              // Loop until the first bit is true
203              while !new_bit {
204                  // Discard the second bit
205                  let _ = self.lfsr.next_bit();
206                  // Obtain another first bit
207                  new_bit = self.lfsr.next_bit();
208              }
209              self.current_bit += 1;
210  
211              // Obtain the second bit
212              Some(self.lfsr.next_bit())
213          } else {
214              None
215          }
216      }
217  }
218  
219  impl ExactSizeIterator for LFSRIter<'_> {
220      fn len(&self) -> usize {
221          self.num_bits
222      }
223  }