// algorithms/src/msm/variable_base/batched.rs
// Copyright (c) 2025-2026 ACDC Network
// This file is part of the alphavm library.
//
// Alpha Chain | Delta Chain Protocol
// International Monetary Graphite.
//
// Derived from Aleo (https://aleo.org) and ProvableHQ (https://provable.com).
// They built world-class ZK infrastructure. We installed the EASY button.
// Their cryptography: elegant. Our modifications: bureaucracy-compatible.
// Original brilliance: theirs. Robert's Rules: ours. Bugs: definitely ours.
//
// Original Aleo/ProvableHQ code is subject to Apache 2.0: https://www.apache.org/licenses/LICENSE-2.0
// All modifications and new work: CC0 1.0 Universal Public Domain Dedication.
// No rights reserved. No permission required. No warranty. No refunds.
//
// https://creativecommons.org/publicdomain/zero/1.0/
// SPDX-License-Identifier: CC0-1.0

use alphavm_curves::{AffineCurve, ProjectiveCurve};
use alphavm_fields::{Field, One, PrimeField, Zero};
use alphavm_utilities::{cfg_into_iter, BigInteger, BitIteratorBE};

#[cfg(not(feature = "serial"))]
use rayon::prelude::*;

#[cfg(target_arch = "x86_64")]
use crate::{prefetch_slice, prefetch_slice_write};

#[derive(Copy, Clone, Debug)]
pub struct BucketPosition {
    pub bucket_index: u32,
    pub scalar_index: u32,
}

impl Eq for BucketPosition {}

impl PartialEq for BucketPosition {
    fn eq(&self, other: &Self) -> bool {
        self.bucket_index == other.bucket_index
    }
}

impl Ord for BucketPosition {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.bucket_index.cmp(&other.bucket_index)
    }
}

impl PartialOrd for BucketPosition {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
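
// A minimal test sketch (not from the upstream code) of the ordering semantics
// above: equality and ordering deliberately ignore `scalar_index`, so sorting
// groups all entries for the same bucket together while `scalar_index` keeps
// track of where each scalar originally lived.
#[cfg(test)]
mod bucket_position_ordering_sketch {
    use super::BucketPosition;

    #[test]
    fn sort_groups_entries_by_bucket_index() {
        let mut positions = [
            BucketPosition { bucket_index: 2, scalar_index: 0 },
            BucketPosition { bucket_index: 0, scalar_index: 1 },
            BucketPosition { bucket_index: 2, scalar_index: 2 },
        ];
        positions.sort_unstable();
        let buckets: Vec<u32> = positions.iter().map(|p| p.bucket_index).collect();
        assert_eq!(buckets, vec![0, 2, 2]);
    }
}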

/// Returns a batch size sufficient to amortize the cost of a single field
/// inversion, while attempting to reduce strain on the CPU cache.
#[inline]
const fn batch_size(msm_size: usize) -> usize {
    // These values were determined empirically, using performance benchmarks for
    // BLS12-377 on Intel, AMD, and M1 machines: take the L1 and L2 cache sizes
    // and divide them by the size of a group element (i.e. 96 bytes).
    //
    // As the algorithm itself requires caching additional values beyond the group
    // elements, the ideal batch size is smaller than expected, to accommodate
    // those values. In general, it was found that undershooting this heuristic is
    // better than overshooting it.
    if cfg!(target_arch = "x86_64") && msm_size < 500_000 {
        // Assumes an L1 cache size of 32KiB. Note that larger cache sizes
        // are not negatively impacted by this value, however smaller L1 cache sizes
        // are.
        300
    } else {
        // Assumes an L2 cache size of 1MiB. Note that larger cache sizes
        // are not negatively impacted by this value, however smaller L2 cache sizes
        // are.
        3000
    }
}
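
// Worked example of the heuristic above, assuming the 96-byte group element
// size stated in the comments: 32 KiB of L1 holds 32 * 1024 / 96 ≈ 341
// elements, undershot to 300; 1 MiB of L2 holds 1024 * 1024 / 96 ≈ 10922
// elements, undershot much further (to leave room for the algorithm's
// auxiliary values) to 3000. A test sketch pinning down the thresholds:
#[cfg(test)]
mod batch_size_sketch {
    use super::batch_size;

    #[test]
    fn thresholds_match_the_documented_heuristic() {
        // Large MSMs always use the L2-sized batch.
        assert_eq!(batch_size(1_000_000), 3000);
        // Small MSMs use the L1-sized batch on x86_64 and the L2-sized batch elsewhere.
        let expected_small = if cfg!(target_arch = "x86_64") { 300 } else { 3000 };
        assert_eq!(batch_size(10_000), expected_small);
    }
}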

/// If `(j, k)` is the `i`-th entry in `index`, then this method sets
/// `bases[j] = bases[j] + bases[k]`. The state of `bases[k]` becomes
/// unspecified.
#[inline]
fn batch_add_in_place_same_slice<G: AffineCurve>(bases: &mut [G], index: &[(u32, u32)]) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

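    // The prefetch iterator is advanced once so it runs one entry ahead of the
    // main loop, pulling the bases for the next index pair into cache while the
    // current pair is being processed.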
    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], &mut y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], &mut x[*idy as usize])
        };
        G::batch_add_loop_1(a, b, &half, &mut inversion_tmp);
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter().rev();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter().rev() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], x[*idy as usize])
        };
        G::batch_add_loop_2(a, b, &mut inversion_tmp);
    }
}
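
// The two passes above are an instance of Montgomery batch inversion: each
// affine addition needs the inverse of one field element, so the first pass
// accumulates all denominators into a running product, a single inversion is
// performed in the middle, and the second pass walks backwards peeling off
// each individual inverse. A minimal standalone sketch of that idea over
// arbitrary nonzero field elements (illustrative only, not part of this
// module's API):
#[cfg(test)]
#[allow(dead_code)]
fn montgomery_batch_inverse_sketch<F: Field>(values: &mut [F]) {
    // Forward pass: `prefixes[i]` holds values[0] * ... * values[i - 1],
    // and `acc` ends as the product of every element.
    let mut acc = F::one();
    let mut prefixes = Vec::with_capacity(values.len());
    for v in values.iter() {
        prefixes.push(acc);
        acc *= *v;
    }
    // One inversion for the whole batch (all inputs must be nonzero).
    let mut inv = acc.inverse().unwrap();
    // Backward pass: for each element, prefix * suffix-inverse = its inverse.
    for (v, prefix) in values.iter_mut().zip(prefixes).rev() {
        let v_inv = inv * prefix;
        inv *= *v;
        *v = v_inv;
    }
}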

/// If `(j, k)` is the `i`-th entry in `index`, then this method performs one of
/// two actions:
/// * `addition_result[i] = bases[j] + bases[k]`, if `k != !0u32`;
/// * `addition_result[i] = bases[j]`, if `k == !0u32` (a copy/no-op).
///
/// It uses `scratch_space` to store intermediate values, and clears it after
/// use.
#[inline]
fn batch_add_write<G: AffineCurve>(
    bases: &[G],
    index: &[(u32, u32)],
    addition_result: &mut Vec<G>,
    scratch_space: &mut Vec<Option<G>>,
) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice_write!(G, bases, bases, prefetch_iter);

        if *idy == !0u32 {
            addition_result.push(bases[*idx as usize]);
            scratch_space.push(None);
        } else {
            let (mut a, mut b) = (bases[*idx as usize], bases[*idy as usize]);
            G::batch_add_loop_1(&mut a, &mut b, &half, &mut inversion_tmp);
            addition_result.push(a);
            scratch_space.push(Some(b));
        }
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    for (a, op_b) in addition_result.iter_mut().rev().zip(scratch_space.iter().rev()) {
        if let Some(b) = op_b {
            G::batch_add_loop_2(a, *b, &mut inversion_tmp);
        }
    }
    scratch_space.clear();
}

#[inline]
pub(super) fn batch_add<G: AffineCurve>(
    num_buckets: usize,
    bases: &[G],
    bucket_positions: &mut [BucketPosition],
) -> Vec<G> {
    assert!(bases.len() >= bucket_positions.len());
    assert!(!bases.is_empty());
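
    // Outline: after sorting, each run of equal-bucket entries is collapsed as
    // a binary addition tree. Every pass pairs adjacent entries, batches the
    // pairwise affine additions behind one shared field inversion, and repeats
    // until each bucket is left holding a single point.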

    // Fetch the ideal batch size for the number of bases.
    let batch_size = batch_size(bases.len());

    // Sort the buckets by their bucket index (not scalar index).
    bucket_positions.sort_unstable();

    let mut num_scalars = bucket_positions.len();
    let mut all_ones = true;
    let mut new_scalar_length = 0;
    let mut global_counter = 0;
    let mut local_counter = 1;
    let mut number_of_bases_in_batch = 0;

    let mut instr = Vec::<(u32, u32)>::with_capacity(batch_size);
    let mut new_bases = Vec::with_capacity(bases.len());
    let mut scratch_space = Vec::with_capacity(batch_size / 2);

    // In the first loop, copy the results of the first in-place addition tree to
    // the vector `new_bases`.
    while global_counter < num_scalars {
        let current_bucket = bucket_positions[global_counter].bucket_index;
        while global_counter + 1 < num_scalars && bucket_positions[global_counter + 1].bucket_index == current_bucket {
            global_counter += 1;
            local_counter += 1;
        }
        if current_bucket >= num_buckets as u32 {
            local_counter = 1;
        } else if local_counter > 1 {
            // `all_ones` is false if this run leaves more than one entry after the pass.
            if local_counter > 2 {
                all_ones = false;
            }
            let is_odd = local_counter % 2 == 1;
            let half = local_counter / 2;
            for i in 0..half {
                instr.push((
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                ));
                bucket_positions[new_scalar_length + i] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
            }
            if is_odd {
                instr.push((bucket_positions[global_counter].scalar_index, !0u32));
                bucket_positions[new_scalar_length + half] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
            }
            // Reset the local_counter and update state.
            new_scalar_length += half + (local_counter % 2);
            number_of_bases_in_batch += half;
            local_counter = 1;

            // When the number of bases in a batch crosses the threshold, perform a batch
            // addition.
            if number_of_bases_in_batch >= batch_size / 2 {
                // Instructions with a second index of !0u32 encode noops/copies.
                batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);

                instr.clear();
                number_of_bases_in_batch = 0;
            }
        } else {
            instr.push((bucket_positions[global_counter].scalar_index, !0u32));
            bucket_positions[new_scalar_length] =
                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
            new_scalar_length += 1;
        }
        global_counter += 1;
    }
    if !instr.is_empty() {
        batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
        instr.clear();
    }
    global_counter = 0;
    number_of_bases_in_batch = 0;
    local_counter = 1;
    num_scalars = new_scalar_length;
    new_scalar_length = 0;

    // Next, perform all the updates in place.
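    // Each pass of the loop below halves the length of every remaining run by
    // pairing adjacent entries, so after O(log n) passes every bucket holds a
    // single point and the loop exits with `all_ones` still true.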
    while !all_ones {
        all_ones = true;
        while global_counter < num_scalars {
            let current_bucket = bucket_positions[global_counter].bucket_index;
            while global_counter + 1 < num_scalars
                && bucket_positions[global_counter + 1].bucket_index == current_bucket
            {
                global_counter += 1;
                local_counter += 1;
            }
            if current_bucket >= num_buckets as u32 {
                local_counter = 1;
            } else if local_counter > 1 {
                // `all_ones` is false if this run leaves more than one entry after the pass.
                if local_counter != 2 {
                    all_ones = false;
                }
                let is_odd = local_counter % 2 == 1;
                let half = local_counter / 2;
                for i in 0..half {
                    instr.push((
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                    ));
                    bucket_positions[new_scalar_length + i] =
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i];
                }
                if is_odd {
                    bucket_positions[new_scalar_length + half] = bucket_positions[global_counter];
                }
                // Reset the local_counter and update state.
                new_scalar_length += half + (local_counter % 2);
                number_of_bases_in_batch += half;
                local_counter = 1;

                if number_of_bases_in_batch >= batch_size / 2 {
                    batch_add_in_place_same_slice(&mut new_bases, &instr);
                    instr.clear();
                    number_of_bases_in_batch = 0;
                }
            } else {
                bucket_positions[new_scalar_length] = bucket_positions[global_counter];
                new_scalar_length += 1;
            }
            global_counter += 1;
        }
        // If there are any remaining unprocessed instructions, perform the final batch
        // addition.
        if !instr.is_empty() {
            batch_add_in_place_same_slice(&mut new_bases, &instr);
            instr.clear();
        }
        global_counter = 0;
        number_of_bases_in_batch = 0;
        local_counter = 1;
        num_scalars = new_scalar_length;
        new_scalar_length = 0;
    }

    let mut res = vec![Zero::zero(); num_buckets];
    for bucket_position in bucket_positions.iter().take(num_scalars) {
        res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
    }
    res
}

#[inline]
fn batched_window<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    w_start: usize,
    c: usize,
) -> (G::Projective, usize) {
    // We don't need the "zero" bucket, so we only have 2^window_size - 1 buckets.
    let window_size = if !w_start.is_multiple_of(c) { w_start % c } else { c };
    let num_buckets = (1 << window_size) - 1;

    let mut bucket_positions: Vec<_> = scalars
        .iter()
        .enumerate()
        .map(|(scalar_index, &scalar)| {
            let mut scalar = scalar;

            // We right-shift by w_start, thus getting rid of the lower bits.
            scalar.divn(w_start as u32);

            // We mod the remaining bits by 2^c to isolate this window's digit.
            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;

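            // A zero digit has no bucket: `(0 - 1) as u32` wraps to !0u32, which
            // `batch_add` treats as out of range (>= num_buckets) and skips.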
            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
        })
        .collect();

    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);

    let mut res = G::Projective::zero();
    let mut running_sum = G::Projective::zero();
    for b in buckets.into_iter().rev() {
        running_sum.add_assign_mixed(&b);
        res += &running_sum;
    }

    (res, window_size)
}
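
// The bucket-combination loop in `batched_window` uses the running-sum trick:
// to compute 1*B_1 + 2*B_2 + ... + n*B_n, walk the buckets from highest to
// lowest, maintaining `running_sum = B_n + ... + B_i` and adding it into the
// result at every step. A sketch of the identity over plain integers:
#[cfg(test)]
mod running_sum_sketch {
    #[test]
    fn matches_the_direct_weighted_sum() {
        let buckets: [u64; 4] = [7, 3, 9, 2]; // B_1, ..., B_4
        let direct: u64 = buckets.iter().enumerate().map(|(i, b)| (i as u64 + 1) * b).sum();

        let (mut res, mut running_sum) = (0u64, 0u64);
        for b in buckets.iter().rev() {
            running_sum += b;
            res += running_sum;
        }
        assert_eq!(res, direct);
    }
}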

pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
    if bases.len() < 15 {
        let num_bits = G::ScalarField::size_in_bits();
        let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
        let mut bits =
            scalars.iter().map(|s| BitIteratorBE::new(s.as_ref()).skip(bigint_size - num_bits)).collect::<Vec<_>>();
        let mut sum = G::Projective::zero();

        let mut encountered_one = false;
        for _ in 0..num_bits {
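            // Skip doubling until the first set bit has been seen; leading zero
            // bits shared by all scalars then cost nothing.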
            if encountered_one {
                sum.double_in_place();
            }
            for (bits, base) in bits.iter_mut().zip(bases) {
                if let Some(true) = bits.next() {
                    sum.add_assign_mixed(base);
                    encountered_one = true;
                }
            }
        }
        debug_assert!(bits.iter_mut().all(|b| b.next().is_none()));
        sum
    } else {
        // Determine the window size `c`, in bits (chosen empirically).
        let c = match scalars.len() < 32 {
            true => 1,
            false => crate::msm::ln_without_floats(scalars.len()) + 2,
        };

        let num_bits = <G::ScalarField as PrimeField>::size_in_bits();

        // Each window is of size `c`.
        // We divide up the bits 0..num_bits into windows of size `c`, and
        // process each such window in parallel.
        let window_sums: Vec<_> =
            cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();

        // We store the sum for the lowest window separately.
        let (lowest, window_sums) = window_sums.split_first().unwrap();

        // We traverse the remaining windows from high to low, doubling as we go.
        window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
            total += sum_i;
            for _ in 0..*window_size {
                total.double_in_place();
            }
            total
        }) + lowest.0
    }
}
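
// The final fold in `msm` recombines window sums Horner-style: starting from
// the highest window, add the window's sum and double `c` times to shift it up,
// finishing with the lowest window. A sketch of the identity over plain
// integers (window width `c = 3` bits, lowest window first):
#[cfg(test)]
mod window_recombination_sketch {
    #[test]
    fn horner_fold_matches_the_weighted_sum() {
        let c = 3u32;
        let windows: [u64; 4] = [5, 2, 7, 1];
        let direct: u64 = windows.iter().enumerate().map(|(i, w)| w << (c * i as u32)).sum();

        // Mirror the fold in `msm`: traverse from high to low, doubling `c`
        // times after each addition, then add the lowest window at the end.
        let (lowest, rest) = windows.split_first().unwrap();
        let folded = rest.iter().rev().fold(0u64, |mut total, w| {
            total += w;
            for _ in 0..c {
                total *= 2;
            }
            total
        }) + lowest;
        assert_eq!(folded, direct);
    }
}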