/ algorithms / examples / msm.rs
msm.rs
  1  // Copyright (c) 2025 ADnet Contributors
  2  // This file is part of the AlphaVM library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  use alphavm_algorithms::msm::*;
 17  use alphavm_curves::{
 18      bls12_377::{Fr, G1Projective},
 19      traits::ProjectiveCurve,
 20  };
 21  use alphavm_fields::PrimeField;
 22  use alphavm_utilities::{
 23      cfg_into_iter,
 24      rand::{TestRng, Uniform},
 25  };
 26  
 27  use anyhow::Result;
 28  #[cfg(not(feature = "serial"))]
 29  use rayon::prelude::*;
 30  
 31  const DEFAULT_POWER_OF_TWO: usize = 20;
 32  
 33  /// Run the following command to perform the MSM(s).
 34  /// `cargo run --release --example msm [variant] [power of 2] [number of MSM
 35  /// iterations]`
 36  pub fn main() -> Result<()> {
 37      let args: Vec<String> = std::env::args().collect();
 38      if args.len() < 4 {
 39          eprintln!("Invalid number of arguments. Given: {} - Required: 3", args.len() - 1);
 40          return Ok(());
 41      }
 42  
 43      // Parse the power of two to sample.
 44      let power_of_two = match args[2].as_str().parse::<usize>() {
 45          Ok(power_of_two) => power_of_two,
 46          Err(_) => {
 47              eprintln!("Failed to parse the power of 2, using the default: 1 << {DEFAULT_POWER_OF_TWO}");
 48              DEFAULT_POWER_OF_TWO
 49          }
 50      };
 51  
 52      println!("\nSampling 1 << {power_of_two} pairs for the vMSM...");
 53  
 54      // Sample the bases and scalars.
 55      let samples = 1 << power_of_two;
 56  
 57      let scalars = cfg_into_iter!(0..samples)
 58          .step_by(1 << 16)
 59          .flat_map(|_| {
 60              let rng = &mut TestRng::fixed(123456789);
 61              (0..(1 << 16)).map(|_| Fr::rand(rng).to_bigint()).collect::<Vec<_>>()
 62          })
 63          .collect::<Vec<_>>();
 64  
 65      println!("Sampled 1 << {power_of_two} scalars.");
 66  
 67      let bases = G1Projective::batch_normalization_into_affine(
 68          cfg_into_iter!(0..samples)
 69              .step_by(1 << 16)
 70              .flat_map(|_| {
 71                  let rng = &mut TestRng::fixed(123456789);
 72                  (0..(1 << 16)).map(|_| G1Projective::rand(rng)).collect::<Vec<_>>()
 73              })
 74              .collect::<Vec<_>>(),
 75      );
 76  
 77      println!("Sampled 1 << {power_of_two} bases.");
 78  
 79      // Parse the number of MSM iterations.
 80      let num_iterations = match args[3].as_str().parse::<usize>() {
 81          Ok(num_iterations) => num_iterations,
 82          Err(_) => {
 83              eprintln!("\nFailed to parse the number of iterations, using the default: 1");
 84              1
 85          }
 86      };
 87  
 88      println!("\nPerforming the vMSM...");
 89  
 90      for i in 0..num_iterations {
 91          let timer = std::time::Instant::now();
 92  
 93          // Parse the variant.
 94          match args[1].as_str() {
 95              "batched" => batched::msm(bases.as_slice(), scalars.as_slice()),
 96              "standard" => standard::msm(bases.as_slice(), scalars.as_slice()),
 97              _ => panic!("Invalid variant: use 'batched' or 'standard'"),
 98          };
 99  
100          println!("{i} - Performed the vMSM in {} milliseconds.", timer.elapsed().as_millis());
101      }
102  
103      Ok(())
104  }