poseidon_grain_lfsr.rs
1 // Copyright (c) 2019-2025 Alpha-Delta Network Inc. 2 // This file is part of the deltavm library. 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at: 7 8 // http://www.apache.org/licenses/LICENSE-2.0 9 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #![allow(dead_code)] 17 18 use crate::{FieldParameters, PrimeField}; 19 use deltavm_utilities::FromBits; 20 21 use anyhow::{Result, bail}; 22 23 pub struct PoseidonGrainLFSR { 24 pub field_size_in_bits: u64, 25 pub state: [bool; 80], 26 pub head: usize, 27 } 28 29 impl PoseidonGrainLFSR { 30 pub fn new( 31 is_sbox_an_inverse: bool, 32 field_size_in_bits: u64, 33 state_len: u64, 34 num_full_rounds: u64, 35 num_partial_rounds: u64, 36 ) -> Self { 37 let mut state = [false; 80]; 38 39 // b0, b1 describes the field 40 state[1] = true; 41 42 // b2, ..., b5 describes the S-BOX 43 state[5] = is_sbox_an_inverse; 44 45 // b6, ..., b17 are the binary representation of n (prime_num_bits) 46 { 47 let mut cur = field_size_in_bits; 48 for i in (6..=17).rev() { 49 state[i] = cur & 1 == 1; 50 cur >>= 1; 51 } 52 } 53 54 // b18, ..., b29 are the binary representation of t (state_len, rate + capacity) 55 { 56 let mut cur = state_len; 57 for i in (18..=29).rev() { 58 state[i] = cur & 1 == 1; 59 cur >>= 1; 60 } 61 } 62 63 // b30, ..., b39 are the binary representation of R_F (the number of full rounds) 64 { 65 let mut cur = num_full_rounds; 66 for i in (30..=39).rev() { 67 state[i] = cur & 1 == 1; 68 cur >>= 1; 69 } 70 } 71 72 // b40, ..., b49 are the binary representation of R_P (the number of partial rounds) 73 { 74 let mut cur = num_partial_rounds; 75 for i in (40..=49).rev() { 76 state[i] = cur & 1 == 1; 77 cur >>= 1; 78 } 79 } 80 81 // b50, ..., b79 are set to 1 82 state[50..=79].copy_from_slice(&[true; 30]); 83 84 // Initialize. 85 let mut res = Self { field_size_in_bits, state, head: 0 }; 86 for _ in 0..160 { 87 res.next_bit(); 88 } 89 res 90 } 91 92 pub fn get_field_elements_rejection_sampling<F: PrimeField>(&mut self, num_elements: usize) -> Result<Vec<F>> { 93 // Ensure the number of bits matches the modulus. 94 if self.field_size_in_bits != F::Parameters::MODULUS_BITS as u64 { 95 bail!("The number of bits in the field must match the modulus"); 96 } 97 98 let mut output = Vec::with_capacity(num_elements); 99 let mut bits = Vec::with_capacity(self.field_size_in_bits as usize); 100 for _ in 0..num_elements { 101 // Perform rejection sampling. 102 loop { 103 // Obtain `n` bits and make it most-significant-bit first. 104 bits.extend(self.get_bits(self.field_size_in_bits as usize)); 105 bits.reverse(); 106 // Construct the number. 107 let bigint = F::BigInteger::from_bits_le(&bits)?; 108 bits.clear(); 109 // Ensure the number is in the field. 110 if let Some(element) = F::from_bigint(bigint) { 111 output.push(element); 112 break; 113 } 114 } 115 } 116 Ok(output) 117 } 118 119 pub fn get_field_elements_mod_p<F: PrimeField>(&mut self, num_elems: usize) -> Result<Vec<F>> { 120 // Ensure the number of bits matches the modulus. 121 let num_bits = self.field_size_in_bits; 122 if num_bits != F::Parameters::MODULUS_BITS as u64 { 123 bail!("The number of bits in the field must match the modulus"); 124 } 125 126 // Prepare reusable vectors for the intermediate bits and bytes. 127 let mut bits = Vec::with_capacity(num_bits as usize); 128 let mut bytes = Vec::with_capacity((num_bits as usize).div_ceil(8)); 129 130 let mut output = Vec::with_capacity(num_elems); 131 for _ in 0..num_elems { 132 // Obtain `n` bits and make it most-significant-bit first. 133 let bits_iter = self.get_bits(num_bits as usize); 134 for bit in bits_iter { 135 bits.push(bit); 136 } 137 bits.reverse(); 138 139 for byte in bits 140 .chunks(8) 141 .map(|chunk| { 142 let mut sum = chunk[0] as u8; 143 let mut cur = 1; 144 for i in chunk.iter().skip(1) { 145 cur *= 2; 146 sum += cur * (*i as u8); 147 } 148 sum 149 }) 150 .rev() 151 { 152 bytes.push(byte); 153 } 154 155 output.push(F::from_bytes_be_mod_order(&bytes)); 156 157 // Clear the vectors of bits and bytes so they can be reused 158 // in the next iteration. 159 bits.clear(); 160 bytes.clear(); 161 } 162 Ok(output) 163 } 164 } 165 166 impl PoseidonGrainLFSR { 167 #[inline] 168 fn get_bits(&mut self, num_bits: usize) -> LFSRIter<'_> { 169 LFSRIter { lfsr: self, num_bits, current_bit: 0 } 170 } 171 172 #[inline] 173 fn next_bit(&mut self) -> bool { 174 let next_bit = self.state[(self.head + 62) % 80] 175 ^ self.state[(self.head + 51) % 80] 176 ^ self.state[(self.head + 38) % 80] 177 ^ self.state[(self.head + 23) % 80] 178 ^ self.state[(self.head + 13) % 80] 179 ^ self.state[self.head]; 180 self.state[self.head] = next_bit; 181 self.head += 1; 182 self.head %= 80; 183 184 next_bit 185 } 186 } 187 188 pub struct LFSRIter<'a> { 189 lfsr: &'a mut PoseidonGrainLFSR, 190 num_bits: usize, 191 current_bit: usize, 192 } 193 194 impl Iterator for LFSRIter<'_> { 195 type Item = bool; 196 197 fn next(&mut self) -> Option<Self::Item> { 198 if self.current_bit < self.num_bits { 199 // Obtain the first bit 200 let mut new_bit = self.lfsr.next_bit(); 201 202 // Loop until the first bit is true 203 while !new_bit { 204 // Discard the second bit 205 let _ = self.lfsr.next_bit(); 206 // Obtain another first bit 207 new_bit = self.lfsr.next_bit(); 208 } 209 self.current_bit += 1; 210 211 // Obtain the second bit 212 Some(self.lfsr.next_bit()) 213 } else { 214 None 215 } 216 } 217 } 218 219 impl ExactSizeIterator for LFSRIter<'_> { 220 fn len(&self) -> usize { 221 self.num_bits 222 } 223 }