/ algorithms / src / snark / varuna / tests.rs
tests.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the AlphaVM library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  #[cfg(any(test, feature = "test"))]
 17  mod varuna {
 18      use crate::{
 19          snark::varuna::{
 20              AHPForR1CS, CircuitVerifyingKey, VarunaHidingMode, VarunaNonHidingMode, VarunaSNARK, VarunaVersion,
 21              mode::SNARKMode, proof::proof_size, test_circuit::TestCircuit,
 22          },
 23          traits::{AlgebraicSponge, SNARK},
 24      };
 25  
 26      use std::collections::BTreeMap;
 27  
 28      use alphavm_curves::bls12_377::{Bls12_377, Fq, Fr};
 29      use alphavm_utilities::{
 30          CanonicalSerialize, ToBytes,
 31          rand::{TestRng, Uniform},
 32      };
 33  
 34      type FS = crate::crypto_hash::PoseidonSponge<Fq, 2, 1>;
 35  
 36      type VarunaSonicInst = VarunaSNARK<Bls12_377, FS, VarunaHidingMode>;
 37      type VarunaSonicPoSWInst = VarunaSNARK<Bls12_377, FS, VarunaNonHidingMode>;
 38  
 39      macro_rules! impl_varuna_test {
 40          ($test_struct: ident, $snark_inst: tt, $snark_mode: tt) => {
 41              struct $test_struct {}
 42              impl $test_struct {
 43                  pub(crate) fn test_circuit(num_constraints: usize, num_variables: usize, pk_size_expectation: usize, varuna_version: VarunaVersion, rng: &mut alphavm_utilities::rand::TestRng) {
 44                      let random = Fr::rand(rng);
 45  
 46                      let max_degree = AHPForR1CS::<Fr, $snark_mode>::max_degree(100, 25, 300).unwrap();
 47                      let universal_srs = $snark_inst::universal_setup(max_degree).unwrap();
 48                      let universal_prover = &universal_srs.to_universal_prover().unwrap();
 49                      let universal_verifier = &universal_srs.to_universal_verifier().unwrap();
 50                      let fs_parameters = FS::sample_parameters();
 51  
 52                      let wrong_varuna_version = match varuna_version {
 53                          VarunaVersion::V1 => VarunaVersion::V2,
 54                          VarunaVersion::V2 => VarunaVersion::V1,
 55                      };
 56  
 57                      for i in 0..5 {
 58                          let mul_depth = 1;
 59                          println!("running test with SM::ZK: {}, mul_depth: {}, num_constraints: {}, num_variables: {}, varuna_version: {:?}", $snark_mode::ZK, mul_depth + i, num_constraints + i, num_variables + i, varuna_version);
 60                          let (circ, public_inputs) = TestCircuit::gen_rand(mul_depth + i, num_constraints + i, num_variables + i, rng);
 61                          let mut fake_inputs = public_inputs.clone();
 62                          fake_inputs[public_inputs.len() - 1] = random;
 63  
 64                          let (index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap();
 65                          println!("Called circuit setup");
 66  
 67                          let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap();
 68                          assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circ, &index_vk, &certificate).unwrap());
 69                          println!("verified vk");
 70  
 71                          if i == 0 {
 72                              assert_eq!(pk_size_expectation, index_pk.to_bytes_le().unwrap().len(), "Update me if serialization has changed");
 73                          }
 74                          assert_eq!(664, index_vk.to_bytes_le().unwrap().len(), "Update me if serialization has changed");
 75  
 76                          let proof = $snark_inst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circ, rng).unwrap();
 77                          println!("Called prover");
 78  
 79                          assert!($snark_inst::verify(universal_verifier, &fs_parameters, &index_vk, varuna_version, public_inputs.clone(), &proof).unwrap());
 80                          println!("Called verifier");
 81                          eprintln!("\nShould not verify with fake inputs (i.e. verifier messages should print below):");
 82                          assert!(!$snark_inst::verify(universal_verifier, &fs_parameters, &index_vk, varuna_version, fake_inputs, &proof).unwrap());
 83                          eprintln!("\nShould not verify with wrong varuna version (i.e. verifier messages should print below):");
 84                          assert!(!$snark_inst::verify(universal_verifier, &fs_parameters, &index_vk, wrong_varuna_version, public_inputs, &proof).unwrap());
 85                      }
 86  
 87                      for circuit_batch_size in (0..4).map(|i| 2usize.pow(i)) {
 88                          for instance_batch_size in (0..4).map(|i| 2usize.pow(i)) {
 89                              println!("running test with circuit_batch_size: {circuit_batch_size} and instance_batch_size: {instance_batch_size}");
 90                              let mut constraints = BTreeMap::new();
 91                              let mut inputs = BTreeMap::new();
 92  
 93                              for i in 0..circuit_batch_size {
 94                                  let (circuit_batch, input_batch): (Vec<_>, Vec<_>) = (0..instance_batch_size)
 95                                  .map(|_| {
 96                                      let mul_depth = 2 + i;
 97                                      let (circ, inputs) = TestCircuit::gen_rand(mul_depth, num_constraints + 100*i, num_variables, rng);
 98                                      (circ, inputs)
 99                                  })
100                                  .unzip();
101                                  let circuit_id = AHPForR1CS::<Fr, $snark_mode>::index(&circuit_batch[0]).unwrap().id;
102                                  constraints.insert(circuit_id, circuit_batch);
103                                  inputs.insert(circuit_id, input_batch);
104                              }
105                              let unique_instances = constraints.values().map(|instances| &instances[0]).collect::<Vec<_>>();
106  
107                              let index_keys =
108                                  $snark_inst::batch_circuit_setup(&universal_srs, unique_instances.as_slice()).unwrap();
109                              println!("Called circuit setup");
110  
111                              let mut pks_to_constraints = BTreeMap::new();
112                              let mut vks_to_inputs = BTreeMap::new();
113  
114                              for (index_pk, index_vk) in index_keys.iter() {
115                                  let certificate = $snark_inst::prove_vk(universal_prover, &fs_parameters, &index_vk, &index_pk).unwrap();
116                                  let circuits = constraints[&index_pk.circuit.id].as_slice();
117                                  assert!($snark_inst::verify_vk(universal_verifier, &fs_parameters, &circuits[0], &index_vk, &certificate).unwrap());
118                                  pks_to_constraints.insert(index_pk, circuits);
119                                  vks_to_inputs.insert(index_vk, inputs[&index_pk.circuit.id].as_slice());
120                              }
121                              println!("verified vks");
122  
123                              let proof =
124                                  $snark_inst::prove_batch(universal_prover, &fs_parameters, varuna_version, &pks_to_constraints, rng).unwrap();
125                              println!("Called prover");
126  
127                              if varuna_version == VarunaVersion::V2 {
128                                  let batch_sizes = proof.batch_sizes();
129                                  let mut proof_bytes = vec![];
130                                  proof.serialize_compressed(&mut proof_bytes).unwrap();
131                                  let actual_size = proof_size::<Bls12_377>(&batch_sizes, VarunaVersion::V2, $snark_mode::ZK).unwrap();
132                                  assert_eq!(proof_bytes.len(), actual_size);
133                                  println!("Compressed size is as expected ({actual_size} B)");
134                              }
135  
136                              assert!(
137                                  $snark_inst::verify_batch(universal_verifier, &fs_parameters, varuna_version, &vks_to_inputs, &proof).unwrap(),
138                                  "Batch verification failed with {instance_batch_size} instances and {circuit_batch_size} circuits for circuits: {constraints:?}"
139                              );
140                              println!("Called verifier");
141                              eprintln!("\nShould not verify with wrong inputs (i.e. verifier messages should print below):");
142                              let mut fake_instance_inputs = Vec::with_capacity(vks_to_inputs.len());
143                              for instance_input in vks_to_inputs.values() {
144                                  let mut fake_instance_input = Vec::with_capacity(instance_input.len());
145                                  for input in instance_input.iter() {
146                                      let mut fake_input = input.clone();
147                                      fake_input[input.len() - 1] = Fr::rand(rng);
148                                      fake_instance_input.push(fake_input);
149                                  }
150                                  fake_instance_inputs.push(fake_instance_input);
151                              }
152                              let mut vks_to_fake_inputs = BTreeMap::new();
153                              for (i, vk) in vks_to_inputs.keys().enumerate() {
154                                  vks_to_fake_inputs.insert(*vk, fake_instance_inputs[i].as_slice());
155                              }
156                              assert!(
157                                  !$snark_inst::verify_batch(
158                                      universal_verifier,
159                                      &fs_parameters,
160                                      varuna_version,
161                                      &vks_to_fake_inputs,
162                                      &proof,
163                                  )
164                                  .unwrap()
165                              );
166                              eprintln!("\nShould not verify with wrong varuna version (i.e. verifier messages should print below):");
167                              assert!(
168                                  !$snark_inst::verify_batch(
169                                      universal_verifier,
170                                      &fs_parameters,
171                                      wrong_varuna_version,
172                                      &vks_to_inputs,
173                                      &proof,
174                                  )
175                                  .unwrap()
176                              );
177                          }
178                      }
179                  }
180  
181                  pub(crate) fn test_serde_json(num_constraints: usize, num_variables: usize, rng: &mut TestRng) {
182                      use std::str::FromStr;
183  
184                      let max_degree = AHPForR1CS::<Fr, $snark_mode>::max_degree(100, 25, 300).unwrap();
185                      let universal_srs = $snark_inst::universal_setup(max_degree).unwrap();
186  
187                      let mul_depth = 1;
188                      let (circ, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
189  
190                      let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap();
191                      println!("Called circuit setup");
192  
193                      // Serialize
194                      let expected_string = index_vk.to_string();
195                      let candidate_string = serde_json::to_string(&index_vk).unwrap();
196                      assert_eq!(
197                          expected_string,
198                          serde_json::Value::from_str(&candidate_string).unwrap().as_str().unwrap()
199                      );
200  
201                      // Deserialize
202                      assert_eq!(index_vk, CircuitVerifyingKey::from_str(&expected_string).unwrap());
203                      assert_eq!(index_vk, serde_json::from_str(&candidate_string).unwrap());
204                  }
205  
206                  pub(crate) fn test_bincode(num_constraints: usize, num_variables: usize, rng: &mut TestRng) {
207                      use alphavm_utilities::{FromBytes, ToBytes};
208  
209                      let max_degree = AHPForR1CS::<Fr, $snark_mode>::max_degree(100, 25, 300).unwrap();
210                      let universal_srs = $snark_inst::universal_setup(max_degree).unwrap();
211  
212                      let mul_depth = 1;
213                      let (circ, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
214  
215                      let (_index_pk, index_vk) = $snark_inst::circuit_setup(&universal_srs, &circ).unwrap();
216                      println!("Called circuit setup");
217  
218                      // Serialize
219                      let expected_bytes = index_vk.to_bytes_le().unwrap();
220                      let candidate_bytes = bincode::serialize(&index_vk).unwrap();
221                      // TODO (howardwu): Serialization - Handle the inconsistency between ToBytes and Serialize (off by a length encoding).
222                      assert_eq!(&expected_bytes[..], &candidate_bytes[8..]);
223  
224                      // Deserialize
225                      assert_eq!(index_vk, CircuitVerifyingKey::read_le(&expected_bytes[..]).unwrap());
226                      assert_eq!(index_vk, bincode::deserialize(&candidate_bytes[..]).unwrap());
227                  }
228              }
229          };
230      }
231  
232      impl_varuna_test!(SonicPCTest, VarunaSonicInst, VarunaHidingMode);
233      impl_varuna_test!(SonicPCPoswTest, VarunaSonicPoSWInst, VarunaNonHidingMode);
234  
235      #[test]
236      fn prove_and_verify_with_tall_matrix_big() {
237          let num_constraints = 100;
238          let num_variables = 25;
239          let pk_size_zk = 91971;
240          let pk_size_posw = 91633;
241          let mut rng = TestRng::default();
242  
243          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng);
244          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V1, &mut rng);
245  
246          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V2, &mut rng);
247          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V2, &mut rng);
248  
249          SonicPCTest::test_serde_json(num_constraints, num_variables, &mut rng);
250          SonicPCPoswTest::test_serde_json(num_constraints, num_variables, &mut rng);
251  
252          SonicPCTest::test_bincode(num_constraints, num_variables, &mut rng);
253          SonicPCPoswTest::test_bincode(num_constraints, num_variables, &mut rng);
254      }
255  
256      #[test]
257      fn prove_and_verify_with_tall_matrix_small() {
258          let num_constraints = 26;
259          let num_variables = 25;
260          let pk_size_zk = 25428;
261          let pk_size_posw = 25090;
262          let mut rng = TestRng::default();
263  
264          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng);
265          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V1, &mut rng);
266  
267          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V2, &mut rng);
268          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V2, &mut rng);
269  
270          SonicPCTest::test_serde_json(num_constraints, num_variables, &mut rng);
271          SonicPCPoswTest::test_serde_json(num_constraints, num_variables, &mut rng);
272  
273          SonicPCTest::test_bincode(num_constraints, num_variables, &mut rng);
274          SonicPCPoswTest::test_bincode(num_constraints, num_variables, &mut rng);
275      }
276  
277      #[test]
278      fn prove_and_verify_with_squat_matrix_big() {
279          let num_constraints = 25;
280          let num_variables = 100;
281          let pk_size_zk = 53523;
282          let pk_size_posw = 53185;
283          let mut rng = TestRng::default();
284  
285          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng);
286          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V1, &mut rng);
287  
288          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V2, &mut rng);
289          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V2, &mut rng);
290  
291          SonicPCTest::test_serde_json(num_constraints, num_variables, &mut rng);
292          SonicPCPoswTest::test_serde_json(num_constraints, num_variables, &mut rng);
293  
294          SonicPCTest::test_bincode(num_constraints, num_variables, &mut rng);
295          SonicPCPoswTest::test_bincode(num_constraints, num_variables, &mut rng);
296      }
297  
298      #[test]
299      fn prove_and_verify_with_squat_matrix_small() {
300          let num_constraints = 25;
301          let num_variables = 26;
302          let pk_size_zk = 25284;
303          let pk_size_posw = 24946;
304          let mut rng = TestRng::default();
305  
306          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng);
307          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V1, &mut rng);
308  
309          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V2, &mut rng);
310          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V2, &mut rng);
311  
312          SonicPCTest::test_serde_json(num_constraints, num_variables, &mut rng);
313          SonicPCPoswTest::test_serde_json(num_constraints, num_variables, &mut rng);
314  
315          SonicPCTest::test_bincode(num_constraints, num_variables, &mut rng);
316          SonicPCPoswTest::test_bincode(num_constraints, num_variables, &mut rng);
317      }
318  
319      #[test]
320      fn prove_and_verify_with_square_matrix() {
321          let num_constraints = 25;
322          let num_variables = 25;
323          let pk_size_zk = 25284;
324          let pk_size_posw = 24946;
325          let mut rng = TestRng::default();
326  
327          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V1, &mut rng);
328          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V1, &mut rng);
329  
330          SonicPCTest::test_circuit(num_constraints, num_variables, pk_size_zk, VarunaVersion::V2, &mut rng);
331          SonicPCPoswTest::test_circuit(num_constraints, num_variables, pk_size_posw, VarunaVersion::V2, &mut rng);
332  
333          SonicPCTest::test_serde_json(num_constraints, num_variables, &mut rng);
334          SonicPCPoswTest::test_serde_json(num_constraints, num_variables, &mut rng);
335  
336          SonicPCTest::test_bincode(num_constraints, num_variables, &mut rng);
337          SonicPCPoswTest::test_bincode(num_constraints, num_variables, &mut rng);
338      }
339  }
340  
341  #[cfg(any(test, feature = "test"))]
342  mod varuna_hiding {
343      use crate::{
344          crypto_hash::PoseidonSponge,
345          snark::varuna::{
346              CircuitVerifyingKey, VarunaHidingMode, VarunaSNARK, VarunaVersion, ahp::AHPForR1CS,
347              test_circuit::TestCircuit,
348          },
349          traits::{AlgebraicSponge, SNARK},
350      };
351      use alphavm_curves::bls12_377::{Bls12_377, Fq, Fr};
352      use alphavm_utilities::{
353          FromBytes, ToBytes,
354          rand::{TestRng, Uniform},
355      };
356  
357      use std::str::FromStr;
358  
359      type VarunaInst = VarunaSNARK<Bls12_377, FS, VarunaHidingMode>;
360      type FS = PoseidonSponge<Fq, 2, 1>;
361  
362      fn test_circuit_n_times(
363          num_constraints: usize,
364          num_variables: usize,
365          num_times: usize,
366          varuna_version: VarunaVersion,
367          rng: &mut TestRng,
368      ) {
369          let max_degree = AHPForR1CS::<Fr, VarunaHidingMode>::max_degree(100, 25, 300).unwrap();
370          let universal_srs = VarunaInst::universal_setup(max_degree).unwrap();
371          let universal_prover = &universal_srs.to_universal_prover().unwrap();
372          let universal_verifier = &universal_srs.to_universal_verifier().unwrap();
373          let fs_parameters = FS::sample_parameters();
374  
375          let wrong_varuna_version = match varuna_version {
376              VarunaVersion::V1 => VarunaVersion::V2,
377              VarunaVersion::V2 => VarunaVersion::V1,
378          };
379  
380          for _ in 0..num_times {
381              let mul_depth = 2;
382              let (circuit, public_inputs) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
383              let mut fake_inputs = public_inputs.clone();
384              fake_inputs[public_inputs.len() - 1] = Fr::rand(rng);
385  
386              let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap();
387              println!("Called circuit setup");
388  
389              let proof =
390                  VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap();
391              println!("Called prover");
392  
393              assert!(
394                  VarunaInst::verify(
395                      universal_verifier,
396                      &fs_parameters,
397                      &index_vk,
398                      varuna_version,
399                      public_inputs.clone(),
400                      &proof,
401                  )
402                  .unwrap()
403              );
404              println!("Called verifier");
405              eprintln!("\nShould not verify with fake inputs (i.e. verifier messages should print below):");
406              assert!(
407                  !VarunaInst::verify(
408                      universal_verifier,
409                      &fs_parameters,
410                      &index_vk,
411                      varuna_version,
412                      fake_inputs.clone(),
413                      &proof
414                  )
415                  .unwrap()
416              );
417              eprintln!("\nShould not verify with wrong varuna version (i.e. verifier messages should print below):");
418              assert!(
419                  !VarunaInst::verify(
420                      universal_verifier,
421                      &fs_parameters,
422                      &index_vk,
423                      wrong_varuna_version,
424                      public_inputs.clone(),
425                      &proof,
426                  )
427                  .unwrap()
428              );
429          }
430      }
431  
432      fn test_circuit(num_constraints: usize, num_variables: usize, varuna_version: VarunaVersion, rng: &mut TestRng) {
433          test_circuit_n_times(num_constraints, num_variables, 100, varuna_version, rng)
434      }
435  
436      fn test_serde_json(num_constraints: usize, num_variables: usize, rng: &mut TestRng) {
437          let max_degree = AHPForR1CS::<Fr, VarunaHidingMode>::max_degree(100, 25, 300).unwrap();
438          let universal_srs = VarunaInst::universal_setup(max_degree).unwrap();
439  
440          let mul_depth = 1;
441          let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
442  
443          let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap();
444          println!("Called circuit setup");
445  
446          // Serialize
447          let expected_string = index_vk.to_string();
448          let candidate_string = serde_json::to_string(&index_vk).unwrap();
449          assert_eq!(expected_string, serde_json::Value::from_str(&candidate_string).unwrap().as_str().unwrap());
450  
451          // Deserialize
452          assert_eq!(index_vk, CircuitVerifyingKey::from_str(&expected_string).unwrap());
453          assert_eq!(index_vk, serde_json::from_str(&candidate_string).unwrap());
454      }
455  
456      fn test_bincode(num_constraints: usize, num_variables: usize, rng: &mut TestRng) {
457          let max_degree = AHPForR1CS::<Fr, VarunaHidingMode>::max_degree(100, 25, 300).unwrap();
458          let universal_srs = VarunaInst::universal_setup(max_degree).unwrap();
459  
460          let mul_depth = 1;
461          let (circuit, _) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
462  
463          let (_index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap();
464          println!("Called circuit setup");
465  
466          // Serialize
467          let expected_bytes = index_vk.to_bytes_le().unwrap();
468          let candidate_bytes = bincode::serialize(&index_vk).unwrap();
469          // TODO (howardwu): Serialization - Handle the inconsistency between ToBytes and
470          // Serialize (off by a length encoding).
471          assert_eq!(&expected_bytes[..], &candidate_bytes[8..]);
472  
473          // Deserialize
474          assert_eq!(index_vk, CircuitVerifyingKey::read_le(&expected_bytes[..]).unwrap());
475          assert_eq!(index_vk, bincode::deserialize(&candidate_bytes[..]).unwrap());
476      }
477  
478      #[test]
479      fn prove_and_verify_with_tall_matrix_big() {
480          let num_constraints = 100;
481          let num_variables = 25;
482          let mut rng = TestRng::default();
483  
484          test_circuit(num_constraints, num_variables, VarunaVersion::V1, &mut rng);
485          test_circuit(num_constraints, num_variables, VarunaVersion::V2, &mut rng);
486          test_serde_json(num_constraints, num_variables, &mut rng);
487          test_bincode(num_constraints, num_variables, &mut rng);
488      }
489  
490      #[test]
491      fn prove_and_verify_with_tall_matrix_small() {
492          let num_constraints = 26;
493          let num_variables = 25;
494          let mut rng = TestRng::default();
495  
496          test_circuit(num_constraints, num_variables, VarunaVersion::V1, &mut rng);
497          test_circuit(num_constraints, num_variables, VarunaVersion::V2, &mut rng);
498          test_serde_json(num_constraints, num_variables, &mut rng);
499          test_bincode(num_constraints, num_variables, &mut rng);
500      }
501  
502      #[test]
503      fn prove_and_verify_with_squat_matrix_big() {
504          let num_constraints = 25;
505          let num_variables = 100;
506          let mut rng = TestRng::default();
507  
508          test_circuit(num_constraints, num_variables, VarunaVersion::V1, &mut rng);
509          test_circuit(num_constraints, num_variables, VarunaVersion::V2, &mut rng);
510          test_serde_json(num_constraints, num_variables, &mut rng);
511          test_bincode(num_constraints, num_variables, &mut rng);
512      }
513  
514      #[test]
515      fn prove_and_verify_with_squat_matrix_small() {
516          let num_constraints = 25;
517          let num_variables = 26;
518          let mut rng = TestRng::default();
519  
520          test_circuit(num_constraints, num_variables, VarunaVersion::V1, &mut rng);
521          test_circuit(num_constraints, num_variables, VarunaVersion::V2, &mut rng);
522          test_serde_json(num_constraints, num_variables, &mut rng);
523          test_bincode(num_constraints, num_variables, &mut rng);
524      }
525  
526      #[test]
527      fn prove_and_verify_with_square_matrix() {
528          let num_constraints = 25;
529          let num_variables = 25;
530          let mut rng = TestRng::default();
531  
532          test_circuit(num_constraints, num_variables, VarunaVersion::V1, &mut rng);
533          test_circuit(num_constraints, num_variables, VarunaVersion::V2, &mut rng);
534          test_serde_json(num_constraints, num_variables, &mut rng);
535          test_bincode(num_constraints, num_variables, &mut rng);
536      }
537  
538      #[test]
539      fn prove_and_verify_with_large_matrix() {
540          let num_constraints = 1 << 16;
541          let num_variables = 1 << 16;
542          let mut rng = TestRng::default();
543  
544          test_circuit_n_times(num_constraints, num_variables, 1, VarunaVersion::V1, &mut rng);
545          test_circuit_n_times(num_constraints, num_variables, 1, VarunaVersion::V2, &mut rng);
546      }
547  
548      #[test]
549      fn check_indexing() {
550          let rng = &mut TestRng::default();
551          let mul_depth = 2;
552          let num_constraints = 1 << 13;
553          let num_variables = 1 << 13;
554          let (circuit, public_inputs) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
555  
556          let max_degree = AHPForR1CS::<Fr, VarunaHidingMode>::max_degree(100, 25, 300).unwrap();
557          let universal_srs = VarunaInst::universal_setup(max_degree).unwrap();
558          let universal_prover = &universal_srs.to_universal_prover().unwrap();
559          let universal_verifier = &universal_srs.to_universal_verifier().unwrap();
560          let fs_parameters = FS::sample_parameters();
561          for varuna_version in [VarunaVersion::V1, VarunaVersion::V2] {
562              let wrong_varuna_version = match varuna_version {
563                  VarunaVersion::V1 => VarunaVersion::V2,
564                  VarunaVersion::V2 => VarunaVersion::V1,
565              };
566              let (index_pk, index_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap();
567              println!("Called circuit setup");
568  
569              let proof =
570                  VarunaInst::prove(universal_prover, &fs_parameters, &index_pk, varuna_version, &circuit, rng).unwrap();
571              println!("Called prover");
572  
573              universal_srs.download_powers_for(0..2usize.pow(18)).unwrap();
574              let (new_pk, new_vk) = VarunaInst::circuit_setup(&universal_srs, &circuit).unwrap();
575              assert_eq!(index_pk, new_pk);
576              assert_eq!(index_vk, new_vk);
577              assert!(
578                  VarunaInst::verify(
579                      universal_verifier,
580                      &fs_parameters,
581                      &index_vk,
582                      varuna_version,
583                      public_inputs.clone(),
584                      &proof,
585                  )
586                  .unwrap()
587              );
588              assert!(
589                  VarunaInst::verify(
590                      universal_verifier,
591                      &fs_parameters,
592                      &new_vk,
593                      varuna_version,
594                      public_inputs.clone(),
595                      &proof
596                  )
597                  .unwrap()
598              );
599              assert!(
600                  !VarunaInst::verify(
601                      universal_verifier,
602                      &fs_parameters,
603                      &index_vk,
604                      wrong_varuna_version,
605                      public_inputs.clone(),
606                      &proof,
607                  )
608                  .unwrap()
609              );
610          }
611      }
612  
613      #[test]
614      fn test_srs_downloads() {
615          let rng = &mut TestRng::default();
616  
617          let max_degree = AHPForR1CS::<Fr, VarunaHidingMode>::max_degree(100, 25, 300).unwrap();
618          let universal_srs = VarunaInst::universal_setup(max_degree).unwrap();
619          let universal_prover = &universal_srs.to_universal_prover().unwrap();
620          let universal_verifier = &universal_srs.to_universal_verifier().unwrap();
621          let fs_parameters = FS::sample_parameters();
622          let varuna_version = VarunaVersion::V2;
623  
624          // Indexing, proving, and verifying for a circuit with 1 << 15 constraints and 1
625          // << 15 variables.
626          let mul_depth = 2;
627          let num_constraints = 2usize.pow(15) - 10;
628          let num_variables = 2usize.pow(15) - 10;
629          let (circuit1, public_inputs1) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
630          let (pk1, vk1) = VarunaInst::circuit_setup(&universal_srs, &circuit1).unwrap();
631          println!("Called circuit setup");
632  
633          let proof1 = VarunaInst::prove(universal_prover, &fs_parameters, &pk1, varuna_version, &circuit1, rng).unwrap();
634          println!("Called prover");
635          assert!(
636              VarunaInst::verify(
637                  universal_verifier,
638                  &fs_parameters,
639                  &vk1,
640                  varuna_version,
641                  public_inputs1.clone(),
642                  &proof1
643              )
644              .unwrap()
645          );
646  
647          /***************************************************************************
648           * * */
649  
650          // Indexing, proving, and verifying for a circuit with 1 << 19 constraints and 1
651          // << 19 variables.
652          let mul_depth = 2;
653          let num_constraints = 2usize.pow(19) - 10;
654          let num_variables = 2usize.pow(19) - 10;
655          let (circuit2, public_inputs2) = TestCircuit::gen_rand(mul_depth, num_constraints, num_variables, rng);
656          let (pk2, vk2) = VarunaInst::circuit_setup(&universal_srs, &circuit2).unwrap();
657          println!("Called circuit setup");
658  
659          let proof2 = VarunaInst::prove(universal_prover, &fs_parameters, &pk2, varuna_version, &circuit2, rng).unwrap();
660          println!("Called prover");
661          assert!(
662              VarunaInst::verify(universal_verifier, &fs_parameters, &vk2, varuna_version, public_inputs2, &proof2)
663                  .unwrap()
664          );
665          /***************************************************************************
666           * * */
667          assert!(
668              VarunaInst::verify(universal_verifier, &fs_parameters, &vk1, varuna_version, public_inputs1, &proof1)
669                  .unwrap()
670          );
671      }
672  }
673  
674  mod varuna_test_vectors {
675      use crate::{
676          fft::EvaluationDomain,
677          snark::varuna::{AHPForR1CS, TestCircuit, VarunaNonHidingMode, VarunaSNARK, VarunaVersion, ahp::verifier},
678          traits::snark::SNARK,
679      };
680      use alphavm_curves::bls12_377::{Bls12_377, Fq, Fr};
681      use alphavm_fields::One;
682      use std::{collections::BTreeMap, fs, ops::Deref, path::PathBuf, str::FromStr, sync::Arc};
683  
684      type FS = crate::crypto_hash::PoseidonSponge<Fq, 2, 1>;
685      type MM = VarunaNonHidingMode;
686      type VarunaSonicInst = VarunaSNARK<Bls12_377, FS, MM>;
687  
688      // Create the path for the `resources` folder.
689      fn resources_path(create_dir: bool) -> PathBuf {
690          let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
691          path.push("src");
692          path.push("snark");
693          path.push("varuna");
694          path.push("resources");
695  
696          // Create the `resources` folder, if it does not exist.
697          if !path.exists() {
698              if create_dir {
699                  fs::create_dir(&path).unwrap_or_else(|_| panic!("Failed to create resources folder: {path:?}"));
700              } else {
701                  panic!("Resources folder does not exist: {path:?}");
702              }
703          }
704  
705          path
706      }
707  
708      // Create the file path.
709      fn test_vector_path(folder: &str, file: &str, circuit: &str, create_dir: bool) -> PathBuf {
710          let mut path = resources_path(create_dir);
711  
712          // Construct the path where the test data lives.
713          path.push(circuit);
714          path.push(folder);
715  
716          // Create the test folder if it does not exist if specified, otherwise panic.
717          if !path.exists() {
718              if create_dir {
719                  fs::create_dir(&path).unwrap_or_else(|_| panic!("Failed to create resources folder: {path:?}"));
720              } else {
721                  panic!("Resources folder does not exist: {path:?}");
722              }
723          }
724  
725          // Construct the path for the test file.
726          path.push(file);
727          path.set_extension("txt");
728  
729          path
730      }
731  
732      // Loads the given `test_folder/test_file` and asserts the given `candidate`
733      // matches the expected values.
734      #[track_caller]
735      fn assert_test_vector_equality(test_folder: &str, test_file: &str, candidate: &str, circuit: &str) {
736          // Get the path to the test file.
737          let path = test_vector_path(test_folder, test_file, circuit, false);
738  
739          // Assert the test file is equal to the expected value.
740          expect_test::expect_file![path].assert_eq(candidate);
741      }
742  
743      // Create a test vector from a trusted revision of Varuna.
744      fn create_test_vector(folder: &str, file: &str, data: &str, circuit: &str) {
745          // Get the path to the test file.
746          let path = test_vector_path(folder, file, circuit, true);
747  
748          // Write the test vector to file.
749          fs::write(&path, data).unwrap_or_else(|_| panic!("Failed to write to file: {path:?}"));
750      }
751  
752      // Tests varuna against the test vectors in all circuits in the resources
753      // folder.
754      fn test_varuna_with_all_circuits(create_test_vectors: bool) {
755          let entries = fs::read_dir(resources_path(create_test_vectors)).expect("Failed to read resources folder");
756          entries.into_iter().for_each(|entry| {
757              let path = entry.unwrap().path();
758              if path.is_dir() {
759                  let circuit = path.file_name().unwrap().to_str().unwrap();
760                  test_circuit_with_test_vectors(create_test_vectors, circuit);
761              }
762          });
763      }
764  
765      // Test Varuna against test vectors for a specific circuit.
766      fn test_circuit_with_test_vectors(create_test_vectors: bool, circuit: &str) {
767          // Initialize the parts of the witness used in the multiplicative constraints.
768          let witness_path = format!("src/snark/varuna/resources/{circuit}/witness.input");
769          let instance_file = fs::read_to_string(witness_path).expect("Could not read the file");
770          let witness: Vec<u128> = serde_json::from_str(instance_file.lines().next().unwrap()).unwrap();
771          let (a, b) = (witness[0], witness[1]);
772  
773          // Initialize challenges from file.
774          let challenges_path = format!("src/snark/varuna/resources/{circuit}/challenges.input");
775          let challenges_file = fs::read_to_string(challenges_path).expect("Could not read the file");
776          let mut challenges = Vec::new();
777          for line in challenges_file.lines() {
778              challenges.push(line)
779          }
780          let (alpha, _eta_a, eta_b, eta_c, beta, delta_a, delta_b, delta_c, _gamma) = (
781              Fr::from_str(challenges[0]).unwrap(),
782              Fr::from_str(challenges[1]).unwrap(),
783              Fr::from_str(challenges[2]).unwrap(),
784              Fr::from_str(challenges[3]).unwrap(),
785              Fr::from_str(challenges[4]).unwrap(),
786              vec![Fr::from_str(challenges[5]).unwrap()],
787              vec![Fr::from_str(challenges[6]).unwrap()],
788              vec![Fr::from_str(challenges[7]).unwrap()],
789              Fr::from_str(challenges[8]).unwrap(),
790          );
791  
792          let circuit_combiner = Fr::one();
793          let instance_combiners = vec![Fr::one()];
794  
795          // Create sample circuit which corresponds to instance.input file.
796          let mul_depth = 3;
797          let num_constraints = 7;
798          let num_variables = 7;
799  
800          // Create a fixed seed rng that matches those the test vectors were generated
801          // with.
802          let rng = &mut alphavm_utilities::rand::TestRng::fixed(4730);
803          let max_degree =
804              AHPForR1CS::<Fr, MM>::max_degree(num_constraints, num_variables, num_variables * num_constraints).unwrap();
805          let universal_srs = VarunaSonicInst::universal_setup(max_degree).unwrap();
806          let (circ, _) =
807              TestCircuit::generate_circuit_with_fixed_witness(a, b, mul_depth, num_constraints, num_variables);
808          println!("Circuit: {circ:?}");
809          let (index_pk, _index_vk) = VarunaSonicInst::circuit_setup(&universal_srs, &circ).unwrap();
810          let mut keys_to_constraints = BTreeMap::new();
811          keys_to_constraints.insert(index_pk.circuit.deref(), std::slice::from_ref(&circ));
812  
813          // Begin the Varuna protocol execution.
814          let prover_state = AHPForR1CS::<_, MM>::init_prover(&keys_to_constraints, rng).unwrap();
815          let mut prover_state = AHPForR1CS::<_, MM>::prover_first_round(prover_state, rng).unwrap();
816          let first_round_oracles = Arc::new(prover_state.first_round_oracles.as_ref().unwrap());
817  
818          // Get private witness polynomial coefficients.
819          let (_, w_poly) = first_round_oracles.batches.iter().next().unwrap();
820          let w_lde = format!("{:?}", w_poly[0].0.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
821          if create_test_vectors {
822              create_test_vector("polynomials", "w_lde", &w_lde, circuit);
823          }
824  
825          // Generate test vectors from assignments.
826          let assignments = AHPForR1CS::<_, MM>::calculate_assignments(&mut prover_state).unwrap();
827  
828          // Get full witness polynomial coefficients.
829          let (_, z_poly) = assignments.iter().next().unwrap();
830          let z_lde = format!("{:?}", z_poly[0].coeffs().iter().collect::<Vec<_>>());
831          if create_test_vectors {
832              create_test_vector("polynomials", "z_lde", &z_lde, circuit);
833          }
834  
835          let combiners = verifier::BatchCombiners::<Fr> { circuit_combiner, instance_combiners };
836          let first_round_batch_combiners = BTreeMap::from_iter([(index_pk.circuit.id, combiners)]);
837          let verifier_first_msg = verifier::FirstMessage::<Fr> { first_round_batch_combiners };
838  
839          let (second_oracles, prover_state) =
840              AHPForR1CS::<_, MM>::prover_second_round::<_>(&verifier_first_msg, prover_state, rng).unwrap();
841  
842          // Get round 2 rowcheck polynomial oracle coefficients.
843          let h_0 = format!("{:?}", second_oracles.h_0.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
844          if create_test_vectors {
845              create_test_vector("polynomials", "h_0", &h_0, circuit);
846          }
847  
848          let verifier_second_msg = verifier::SecondMessage::<Fr> { alpha, eta_b: Some(eta_b), eta_c: Some(eta_c) };
849          let (_prover_third_message, third_oracles, prover_state) = AHPForR1CS::<_, MM>::prover_third_round(
850              &verifier_first_msg,
851              &verifier_second_msg,
852              &None,
853              prover_state,
854              rng,
855              VarunaVersion::V1,
856          )
857          .unwrap();
858  
859          // Get coefficients round 3 univariate rowcheck polynomial oracles.
860          let g_1 = format!("{:?}", third_oracles.g_1.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
861          if create_test_vectors {
862              create_test_vector("polynomials", "g_1", &g_1, circuit);
863          }
864          let h_1 = format!("{:?}", third_oracles.h_1.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
865          if create_test_vectors {
866              create_test_vector("polynomials", "h_1", &h_1, circuit);
867          }
868  
869          let verifier_third_msg = verifier::ThirdMessage::<Fr> { beta };
870          let (_prover_fourth_message, fourth_oracles, prover_state) =
871              AHPForR1CS::<_, MM>::prover_fourth_round(&verifier_second_msg, &verifier_third_msg, prover_state, rng)
872                  .unwrap();
873  
874          // Create round 4 rational sumcheck oracle polynomials.
875          let (_, gm_polys) = fourth_oracles.gs.iter().next().unwrap();
876          let g_a = format!("{:?}", gm_polys.g_a.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
877          let g_b = format!("{:?}", gm_polys.g_b.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
878          let g_c = format!("{:?}", gm_polys.g_b.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
879          if create_test_vectors {
880              create_test_vector("polynomials", "g_a", &g_a, circuit);
881              create_test_vector("polynomials", "g_b", &g_b, circuit);
882              create_test_vector("polynomials", "g_c", &g_c, circuit);
883          }
884  
885          // Create the verifier's fourth message.
886          let verifier_fourth_msg = verifier::FourthMessage::<Fr> { delta_a, delta_b, delta_c };
887  
888          let mut public_inputs = BTreeMap::new();
889          let public_input = prover_state.public_inputs(&index_pk.circuit).unwrap();
890          public_inputs.insert(index_pk.circuit.id, public_input);
891          let non_zero_a_domain = EvaluationDomain::<Fr>::new(index_pk.circuit.index_info.num_non_zero_a).unwrap();
892          let non_zero_b_domain = EvaluationDomain::<Fr>::new(index_pk.circuit.index_info.num_non_zero_b).unwrap();
893          let non_zero_c_domain = EvaluationDomain::<Fr>::new(index_pk.circuit.index_info.num_non_zero_c).unwrap();
894          let variable_domain =
895              EvaluationDomain::<Fr>::new(index_pk.circuit.index_info.num_public_and_private_variables).unwrap();
896          let constraint_domain = EvaluationDomain::<Fr>::new(index_pk.circuit.index_info.num_constraints).unwrap();
897          let input_domain = EvaluationDomain::<Fr>::new(index_pk.circuit.index_info.num_public_inputs).unwrap();
898  
899          // Get constraint domain elements.
900          let mut constraint_domain_elements = Vec::with_capacity(constraint_domain.size());
901          for el in constraint_domain.elements() {
902              constraint_domain_elements.push(el);
903          }
904          if create_test_vectors {
905              create_test_vector("domain", "R", &format!("{constraint_domain_elements:?}"), circuit);
906          }
907  
908          // Get non_zero_domain elements.
909          let non_zero_domain = *[&non_zero_a_domain, &non_zero_b_domain, &non_zero_c_domain]
910              .iter()
911              .max_by_key(|domain| domain.size)
912              .unwrap();
913          let mut non_zero_domain_elements = Vec::with_capacity(non_zero_domain.size());
914          for el in non_zero_domain.elements() {
915              non_zero_domain_elements.push(el);
916          }
917          if create_test_vectors {
918              create_test_vector("domain", "K", &format!("{non_zero_domain_elements:?}"), circuit);
919          }
920  
921          // Get variable domain elements.
922          let mut variable_domain_elements = Vec::with_capacity(input_domain.size());
923          for el in variable_domain.elements() {
924              variable_domain_elements.push(el);
925          }
926          if create_test_vectors {
927              create_test_vector("domain", "C", &format!("{variable_domain_elements:?}"), circuit);
928          }
929  
930          let fifth_oracles = AHPForR1CS::<_, MM>::prover_fifth_round(verifier_fourth_msg, prover_state, rng).unwrap();
931  
932          // Get coefficients of final oracle polynomial from round 5.
933          let h_2 = format!("{:?}", fifth_oracles.h_2.coeffs().map(|(_, coeff)| coeff).collect::<Vec<_>>());
934          if create_test_vectors {
935              create_test_vector("polynomials", "h_2", &h_2, circuit);
936          }
937  
938          // Check the intermediate oracle polynomials against the test vectors.
939          assert_test_vector_equality("polynomials", "w_lde", &w_lde, circuit);
940          assert_test_vector_equality("polynomials", "z_lde", &z_lde, circuit);
941          assert_test_vector_equality("polynomials", "h_0", &h_0, circuit);
942          assert_test_vector_equality("polynomials", "h_1", &h_1, circuit);
943          assert_test_vector_equality("polynomials", "g_1", &g_1, circuit);
944          assert_test_vector_equality("polynomials", "h_2", &h_2, circuit);
945          assert_test_vector_equality("polynomials", "g_a", &g_a, circuit);
946          assert_test_vector_equality("polynomials", "g_b", &g_b, circuit);
947          assert_test_vector_equality("polynomials", "g_c", &g_c, circuit);
948  
949          // Check that the domains match the test vectors.
950          assert_test_vector_equality("domain", "R", &format!("{constraint_domain_elements:?}"), circuit);
951          assert_test_vector_equality("domain", "K", &format!("{non_zero_domain_elements:?}"), circuit);
952          assert_test_vector_equality("domain", "C", &format!("{variable_domain_elements:?}"), circuit);
953      }
954  
955      #[test]
956      fn test_varuna_with_prover_test_vectors() {
957          test_varuna_with_all_circuits(false);
958      }
959  }