/ circuit / algorithms / src / keccak / hash.rs
hash.rs
  1  // Copyright (c) 2019-2025 Alpha-Delta Network Inc.
  2  // This file is part of the alphavm library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use super::*;
 17  
 18  impl<E: Environment, const TYPE: u8, const VARIANT: usize> Hash for Keccak<E, TYPE, VARIANT> {
 19      type Input = Boolean<E>;
 20      type Output = Vec<Boolean<E>>;
 21  
 22      /// Returns the Keccak hash of the given input as bits.
 23      #[inline]
 24      fn hash(&self, input: &[Self::Input]) -> Self::Output {
 25          // The bitrate `r`.
 26          // The capacity is twice the digest length (i.e. twice the variant, where the variant is in {224, 256, 384, 512}),
 27          // and the bit rate is the width (1600 in our case) minus the capacity.
 28          let bitrate = PERMUTATION_WIDTH - 2 * VARIANT;
 29          debug_assert!(bitrate < PERMUTATION_WIDTH, "The bitrate must be less than the permutation width");
 30          debug_assert!(bitrate % 8 == 0, "The bitrate must be a multiple of 8");
 31  
 32          // Ensure the input is not empty.
 33          if input.is_empty() {
 34              E::halt("The input to the hash function must not be empty")
 35          }
 36  
 37          // The root state `s` is defined as `0^b`.
 38          let mut s = vec![Boolean::constant(false); PERMUTATION_WIDTH];
 39  
 40          // The padded blocks `P`.
 41          let padded_blocks = match TYPE {
 42              0 => Self::pad_keccak(input, bitrate),
 43              1 => Self::pad_sha3(input, bitrate),
 44              2.. => unreachable!("Invalid Keccak type"),
 45          };
 46  
 47          /* The first part of the sponge construction (the absorbing phase):
 48           *
 49           * for i = 0 to |P| − 1 do
 50           *   s = s ⊕ (P_i || 0^c) # Note: |P_i| + c == b, since |P_i| == r
 51           *   s = f(s)
 52           * end for
 53           */
 54          for block in padded_blocks {
 55              // s = s ⊕ (P_i || 0^c)
 56              for (j, bit) in block.into_iter().enumerate() {
 57                  s[j] = &s[j] ^ &bit;
 58              }
 59              // s = f(s)
 60              s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
 61          }
 62  
 63          /* The second part of the sponge construction (the squeezing phase):
 64           *
 65           * Z = s[0..r-1]
 66           * while |Z| < d do // d is the digest length
 67           *   s = f(s)
 68           *   Z = Z || s[0..r-1]
 69           * end while
 70           * return Z[0..d-1]
 71           */
 72          // Z = s[0..r-1]
 73          let mut z = s[..bitrate].to_vec();
 74          // while |Z| < l do
 75          while z.len() < VARIANT {
 76              // s = f(s)
 77              s = Self::permutation_f::<PERMUTATION_WIDTH, NUM_ROUNDS>(s, &self.round_constants, &self.rotl);
 78              // Z = Z || s[0..r-1]
 79              z.extend(s.iter().take(bitrate).cloned());
 80          }
 81          // return Z[0..d-1]
 82          z.truncate(VARIANT);
 83          z
 84      }
 85  }
 86  
 87  impl<E: Environment, const TYPE: u8, const VARIANT: usize> Keccak<E, TYPE, VARIANT> {
 88      /// In Keccak, `pad` is a multi-rate padding, defined as `pad(M) = M || 0x01 || 0x00…0x00 || 0x80`,
 89      /// where `M` is the input data, and `0x01 || 0x00…0x00 || 0x80` is the padding.
 90      /// The padding extends the input data to a multiple of the bitrate `r`, defined as `r = b - c`,
 91      /// where `b` is the width of the permutation, and `c` is the capacity.
 92      fn pad_keccak(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
 93          debug_assert!(bitrate > 0, "The bitrate must be positive");
 94  
 95          // Resize the input to a multiple of 8.
 96          let mut padded_input = input.to_vec();
 97          padded_input.resize(input.len().div_ceil(8) * 8, Boolean::constant(false));
 98  
 99          // Step 1: Append the bit "1" to the message.
100          padded_input.push(Boolean::constant(true));
101  
102          // Step 2: Append "0" bits until the length of the message is congruent to r-1 mod r.
103          while (padded_input.len() % bitrate) != (bitrate - 1) {
104              padded_input.push(Boolean::constant(false));
105          }
106  
107          // Step 3: Append the bit "1" to the message.
108          padded_input.push(Boolean::constant(true));
109  
110          // Construct the padded blocks.
111          let mut result = Vec::new();
112          for block in padded_input.chunks(bitrate) {
113              result.push(block.to_vec());
114          }
115          result
116      }
117  
118      /// In SHA-3, `pad` is a SHAKE, defined as `pad(M) = M || 0x06 || 0x00…0x00 || 0x80`,
119      /// where `M` is the input data, and `0x06 || 0x00…0x00 || 0x80` is the padding.
120      /// The padding extends the input data to a multiple of the bitrate `r`, defined as `r = b - c`,
121      /// where `b` is the width of the permutation, and `c` is the capacity.
122      fn pad_sha3(input: &[Boolean<E>], bitrate: usize) -> Vec<Vec<Boolean<E>>> {
123          debug_assert!(bitrate > 1, "The bitrate must be greater than 1");
124  
125          // Resize the input to a multiple of 8.
126          let mut padded_input = input.to_vec();
127          padded_input.resize(input.len().div_ceil(8) * 8, Boolean::constant(false));
128  
129          // Step 1: Append the "0x06" byte to the message.
130          padded_input.push(Boolean::constant(false));
131          padded_input.push(Boolean::constant(true));
132          padded_input.push(Boolean::constant(true));
133          padded_input.push(Boolean::constant(false));
134  
135          // Step 2: Append "0" bits until the length of the message is congruent to r-1 mod r.
136          while (padded_input.len() % bitrate) != (bitrate - 1) {
137              padded_input.push(Boolean::constant(false));
138          }
139  
140          // Step 3: Append the bit "1" to the message.
141          padded_input.push(Boolean::constant(true));
142  
143          // Construct the padded blocks.
144          let mut result = Vec::new();
145          for block in padded_input.chunks(bitrate) {
146              result.push(block.to_vec());
147          }
148          result
149      }
150  
151      /// The permutation `f` is a function that takes a fixed-length input and produces a fixed-length output,
152      /// defined as `f = Keccak-f[b]`, where `b := 25 * 2^l` is the width of the permutation,
153      /// and `l` is the log width of the permutation.
154      ///
155      /// The round function `Rnd` is applied `12 + 2l` times, where `l` is the log width of the permutation.
156      fn permutation_f<const WIDTH: usize, const NUM_ROUNDS: usize>(
157          input: Vec<Boolean<E>>,
158          round_constants: &[U64<E>],
159          rotl: &[usize],
160      ) -> Vec<Boolean<E>> {
161          debug_assert_eq!(input.len(), WIDTH, "The input vector must have {WIDTH} bits");
162          debug_assert_eq!(
163              round_constants.len(),
164              NUM_ROUNDS,
165              "The round constants vector must have {NUM_ROUNDS} elements"
166          );
167  
168          // Partition the input into 64-bit chunks.
169          let mut a = input.chunks(64).map(U64::from_bits_le).collect::<Vec<_>>();
170          // Permute the input.
171          for round_constant in round_constants.iter().take(NUM_ROUNDS) {
172              a = Self::round(a, round_constant, rotl);
173          }
174          // Return the permuted input.
175          let mut bits = Vec::with_capacity(input.len());
176          a.iter().for_each(|e| e.write_bits_le(&mut bits));
177          bits
178      }
179  
180      /// The round function `Rnd` is defined as follows:
181      /// ```text
182      /// Rnd = ι ◦ χ ◦ π ◦ ρ ◦ θ
183      /// ```
184      /// where `◦` denotes function composition.
185      fn round(a: Vec<U64<E>>, round_constant: &U64<E>, rotl: &[usize]) -> Vec<U64<E>> {
186          debug_assert_eq!(a.len(), MODULO * MODULO, "The input vector 'a' must have {} elements", MODULO * MODULO);
187  
188          /* The first part of Algorithm 1, θ:
189           *
190           * for x = 0 to 4 do
191           *   C[x] = a[x, 0]
192           *   for y = 1 to 4 do
193           *     C[x] = C[x] ⊕ a[x, y]
194           *   end for
195           * end for
196           */
197          let mut c = Vec::with_capacity(MODULO);
198          for x in 0..MODULO {
199              c.push(&a[x] ^ &a[x + MODULO] ^ &a[x + (2 * MODULO)] ^ &a[x + (3 * MODULO)] ^ &a[x + (4 * MODULO)]);
200          }
201  
202          /* The second part of Algorithm 1, θ:
203           *
204           * for x = 0 to 4 do
205           *   D[x] = C[x−1] ⊕ ROT(C[x+1],1)
206           *   for y = 0 to 4 do
207           *     A[x, y] = a[x, y] ⊕ D[x]
208           *   end for
209           * end for
210           */
211          let mut d = Vec::with_capacity(MODULO);
212          for x in 0..MODULO {
213              d.push(&c[(x + 4) % MODULO] ^ Self::rotate_left(&c[(x + 1) % MODULO], 63));
214          }
215          let mut a_1 = Vec::with_capacity(MODULO * MODULO);
216          for y in 0..MODULO {
217              for x in 0..MODULO {
218                  a_1.push(&a[x + (y * MODULO)] ^ &d[x]);
219              }
220          }
221  
222          /* Algorithm 3, π:
223           *
224           * for x = 0 to 4 do
225           *   for y = 0 to 4 do
226           *     (X, Y) = (y, (2*x + 3*y) mod 5)
227           *     A[X, Y] = a[x, y]
228           *   end for
229           * end for
230           *
231           * Algorithm 2, ρ:
232           *
233           * A[0, 0] = a[0, 0]
234           * (x, y) = (1, 0)
235           * for t = 0 to 23 do
236           *   A[x, y] = ROT(a[x, y], (t + 1)(t + 2)/2)
237           *   (x, y) = (y, (2*x + 3*y) mod 5)
238           * end for
239           */
240          let mut a_2 = a_1.clone();
241          for y in 0..MODULO {
242              for x in 0..MODULO {
243                  // This step combines the π and ρ steps into one.
244                  a_2[y + ((((2 * x) + (3 * y)) % MODULO) * MODULO)] =
245                      Self::rotate_left(&a_1[x + (y * MODULO)], rotl[x + (y * MODULO)]);
246              }
247          }
248  
249          /* Algorithm 4, χ:
250           *
251           * for y = 0 to 4 do
252           *   for x = 0 to 4 do
253           *     A[x, y] = a[x, y] ⊕ ((¬a[x+1, y]) ∧ a[x+2, y])
254           *   end for
255           * end for
256           */
257          let mut a_3 = Vec::with_capacity(MODULO * MODULO);
258          for y in 0..MODULO {
259              for x in 0..MODULO {
260                  let a = &a_2[x + (y * MODULO)];
261                  let b = &a_2[((x + 1) % MODULO) + (y * MODULO)];
262                  let c = &a_2[((x + 2) % MODULO) + (y * MODULO)];
263                  a_3.push(a ^ ((!b) & c));
264              }
265          }
266  
267          /* ι:
268           *
269           * A[0, 0] = A[0, 0] ⊕ RC
270           */
271          a_3[0] = &a_3[0] ^ round_constant;
272          a_3
273      }
274  
275      /// Performs a rotate left operation on the given `u64` value.
276      fn rotate_left(value: &U64<E>, n: usize) -> U64<E> {
277          // Perform the rotation.
278          let mut bits_le = value.to_bits_le();
279          bits_le.rotate_left(n);
280          // Return the rotated value.
281          U64::from_bits_le(&bits_le)
282      }
283  }
284  
#[cfg(test)]
mod tests {
    use super::*;
    use alphavm_circuit_types::environment::Circuit;
    use console::Rng;

    const ITERATIONS: usize = 3;

    /// Message lengths (in bits) hashed with constant inputs. Every circuit cost is expected to be zero.
    const CONSTANT_CASES: [usize; 19] = [1, 2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256, 511, 512, 513, 1023, 1024, 1025];

    /// `(message length in bits, expected private variables == expected constraints)` for
    /// Keccak-256 over public or private inputs. Both modes have identical costs.
    const WITNESS_CASES: [(usize, u64); 15] = [
        (1, 138157),
        (2, 139108),
        (3, 139741),
        (4, 140318),
        (5, 140879),
        (6, 141350),
        (7, 141787),
        (8, 142132),
        (16, 144173),
        (32, 145394),
        (64, 146650),
        (128, 149248),
        (256, 150848),
        (512, 151424),
        (1024, 152448),
    ];

    /// Hashes random messages with the console hasher `$console` and the circuit hasher `$circuit`.
    /// Asserts that the two digests agree.
    ///
    /// The tested lengths are a fixed list followed by five random lengths in `1..1024`.
    /// Both hasher expressions are re-evaluated for every message, and the circuit is reset after each one.
    macro_rules! check_equivalence {
        ($console:expr, $circuit:expr) => {
            use console::Hash as H;

            let rng = &mut TestRng::default();

            let mut lengths = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64, 128, 256, 512, 1024];
            lengths.extend((0..5).map(|_| rng.gen_range(1..1024)));

            for num_inputs in lengths {
                println!("Checking equivalence for {num_inputs} inputs");

                // Random preimage, as native bits and as private circuit booleans.
                let message = (0..num_inputs).map(|_| Uniform::rand(rng)).collect::<Vec<bool>>();
                let circuit_message =
                    message.iter().map(|&bit| Boolean::<Circuit>::new(Mode::Private, bit)).collect::<Vec<_>>();

                let expected = $console.hash(&message).expect("Failed to hash console input");
                let candidate = $circuit.hash(&circuit_message);
                assert_eq!(expected, candidate.eject_value());
                Circuit::reset();
            }
        };
    }

    /// Runs Keccak-256 `ITERATIONS` times on random `num_inputs`-bit messages injected in `mode`.
    /// For each run, checks the digest against the console implementation and asserts the
    /// expected circuit costs inside a named scope.
    fn check_hash(
        mode: Mode,
        num_inputs: usize,
        num_constants: u64,
        num_public: u64,
        num_private: u64,
        num_constraints: u64,
        rng: &mut TestRng,
    ) {
        use console::Hash as H;

        let native = console::Keccak256::default();
        let keccak = Keccak256::<Circuit>::new();

        for i in 0..ITERATIONS {
            // Random preimage, as native bits and as circuit booleans in the requested mode.
            let message = (0..num_inputs).map(|_| Uniform::rand(rng)).collect::<Vec<bool>>();
            let circuit_message = message.iter().map(|&bit| Boolean::<Circuit>::new(mode, bit)).collect::<Vec<_>>();

            // Reference digest from the console implementation.
            let expected = native.hash(&message).expect("Failed to hash native input");

            // Circuit digest, with its cost measured inside a named scope.
            Circuit::scope(format!("Keccak {mode} {i}"), || {
                let candidate = keccak.hash(&circuit_message);
                assert_eq!(expected, candidate.eject_value());
                let case = format!("(mode = {mode}, num_inputs = {num_inputs})");
                assert_scope!(case, num_constants, num_public, num_private, num_constraints);
            });
            Circuit::reset();
        }
    }

    #[test]
    fn test_keccak_256_hash_constant() {
        let mut rng = TestRng::default();

        for num_inputs in CONSTANT_CASES {
            check_hash(Mode::Constant, num_inputs, 0, 0, 0, 0, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_public() {
        let mut rng = TestRng::default();

        for (num_inputs, cost) in WITNESS_CASES {
            check_hash(Mode::Public, num_inputs, 0, 0, cost, cost, &mut rng);
        }
    }

    #[test]
    fn test_keccak_256_hash_private() {
        let mut rng = TestRng::default();

        for (num_inputs, cost) in WITNESS_CASES {
            check_hash(Mode::Private, num_inputs, 0, 0, cost, cost, &mut rng);
        }
    }

    #[test]
    fn test_keccak_224_equivalence() {
        check_equivalence!(console::Keccak224::default(), Keccak224::<Circuit>::new());
    }

    #[test]
    fn test_keccak_256_equivalence() {
        check_equivalence!(console::Keccak256::default(), Keccak256::<Circuit>::new());
    }

    #[test]
    fn test_keccak_384_equivalence() {
        check_equivalence!(console::Keccak384::default(), Keccak384::<Circuit>::new());
    }

    #[test]
    fn test_keccak_512_equivalence() {
        check_equivalence!(console::Keccak512::default(), Keccak512::<Circuit>::new());
    }

    #[test]
    fn test_sha3_224_equivalence() {
        check_equivalence!(console::Sha3_224::default(), Sha3_224::<Circuit>::new());
    }

    #[test]
    fn test_sha3_256_equivalence() {
        check_equivalence!(console::Sha3_256::default(), Sha3_256::<Circuit>::new());
    }

    #[test]
    fn test_sha3_384_equivalence() {
        check_equivalence!(console::Sha3_384::default(), Sha3_384::<Circuit>::new());
    }

    #[test]
    fn test_sha3_512_equivalence() {
        check_equivalence!(console::Sha3_512::default(), Sha3_512::<Circuit>::new());
    }
}