algorithms/src/msm/variable_base/batched.rs
// Copyright (c) 2025 ADnet Contributors
// This file is part of the AlphaVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use alphavm_curves::{AffineCurve, ProjectiveCurve};
use alphavm_fields::{Field, One, PrimeField, Zero};
use alphavm_utilities::{BigInteger, BitIteratorBE, cfg_into_iter};

#[cfg(not(feature = "serial"))]
use rayon::prelude::*;

#[cfg(target_arch = "x86_64")]
use crate::{prefetch_slice, prefetch_slice_write};

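/// The position of a base element within a bucket: `bucket_index` is the
/// bucket the base is assigned to, and `scalar_index` is the base's position
/// in the slice of bases.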
#[derive(Copy, Clone, Debug)]
pub struct BucketPosition {
    pub bucket_index: u32,
    pub scalar_index: u32,
}

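// Equality and ordering deliberately compare only `bucket_index`, so that
// sorting a slice of `BucketPosition`s groups entries belonging to the same
// bucket together while `scalar_index` rides along as a payload.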
impl Eq for BucketPosition {}

impl PartialEq for BucketPosition {
    fn eq(&self, other: &Self) -> bool {
        self.bucket_index == other.bucket_index
    }
}

impl Ord for BucketPosition {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.bucket_index.cmp(&other.bucket_index)
    }
}

impl PartialOrd for BucketPosition {
    #[allow(clippy::non_canonical_partial_ord_impl)]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.bucket_index.partial_cmp(&other.bucket_index)
    }
}

/// Returns a batch size large enough to amortize the cost of a field
/// inversion, while attempting to reduce strain on the CPU cache.
#[inline]
const fn batch_size(msm_size: usize) -> usize {
    // These values were determined empirically using performance benchmarks for
    // BLS12-377 on Intel, AMD, and M1 machines, by taking the L1 and L2 cache
    // sizes and dividing them by the size of a group element (i.e. 96 bytes).
    //
    // As the algorithm itself requires caching additional values beyond the group
    // elements, the ideal batch size is smaller than the raw quotient, to
    // accommodate those values. In general, undershooting this heuristic was
    // found to be better than overshooting it.
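    //
    // Concretely, 300 group elements occupy about 300 * 96 B ≈ 28 KiB (fitting
    // a 32 KiB L1 with room to spare), and 3000 occupy about 281 KiB (well
    // within a 1 MiB L2).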
    if cfg!(target_arch = "x86_64") && msm_size < 500_000 {
        // Assumes an L1 cache size of 32KiB. Larger caches are not negatively
        // impacted by this value; smaller L1 caches are.
        300
    } else {
        // Assumes an L2 cache size of 1MiB. Larger caches are not negatively
        // impacted by this value; smaller L2 caches are.
        3000
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method sets
/// `bases[j] = bases[j] + bases[k]`. The state of `bases[k]` becomes
/// unspecified.
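///
/// Both passes below implement Montgomery's batch-inversion trick: the first
/// loop accumulates the product of every addition's denominator into
/// `inversion_tmp`, a single field inversion is performed, and the second loop
/// walks the batch in reverse, peeling off one factor at a time to recover
/// each individual inverse. A minimal sketch of the trick itself, assuming a
/// `Copy` field type `F` with `one`, `inverse`, and the usual operator impls
/// (a hypothetical helper, not part of this crate):
///
/// ```ignore
/// fn batch_inverse<F: Field>(xs: &mut [F]) {
///     // Forward pass: prefix[i] = x_0 * ... * x_{i-1}; acc = product of all.
///     let mut acc = F::one();
///     let prefix: Vec<F> = xs.iter().map(|x| { let p = acc; acc *= *x; p }).collect();
///     // One inversion for the entire batch.
///     let mut inv = acc.inverse().unwrap();
///     // Reverse pass: recover each individual inverse.
///     for (x, p) in xs.iter_mut().zip(prefix).rev() {
///         let next_inv = inv * *x; // inverse of x_0 * ... * x_{i-1}
///         *x = inv * p;            // = 1 / x_i
///         inv = next_inv;
///     }
/// }
/// ```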
#[inline]
fn batch_add_in_place_same_slice<G: AffineCurve>(bases: &mut [G], index: &[(u32, u32)]) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], &mut y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], &mut x[*idy as usize])
        };
        G::batch_add_loop_1(a, b, &half, &mut inversion_tmp);
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter().rev();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter().rev() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], x[*idy as usize])
        };
        G::batch_add_loop_2(a, b, &mut inversion_tmp);
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method performs one of
/// two actions:
/// * `addition_result[i] = bases[j] + bases[k]`, or
/// * `addition_result[i] = bases[j]`, when `k` is the sentinel `!0u32`
///   (encoding a plain copy with no addition).
///
/// It uses `scratch_space` to store intermediate values, and clears it after
/// use.
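///
/// The `(j, !0u32)` copy instructions are generated by `batch_add` when a
/// bucket holds an odd number of entries (including a single entry) at some
/// level of the addition tree.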
#[inline]
fn batch_add_write<G: AffineCurve>(
    bases: &[G],
    index: &[(u32, u32)],
    addition_result: &mut Vec<G>,
    scratch_space: &mut Vec<Option<G>>,
) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice_write!(G, bases, bases, prefetch_iter);

        if *idy == !0u32 {
            addition_result.push(bases[*idx as usize]);
            scratch_space.push(None);
        } else {
            let (mut a, mut b) = (bases[*idx as usize], bases[*idy as usize]);
            G::batch_add_loop_1(&mut a, &mut b, &half, &mut inversion_tmp);
            addition_result.push(a);
            scratch_space.push(Some(b));
        }
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

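    // Unwind in reverse: each iteration peels one factor off `inversion_tmp`
    // to recover the individual inverse needed to finish that addition.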
    for (a, op_b) in addition_result.iter_mut().rev().zip(scratch_space.iter().rev()) {
        if let Some(b) = op_b {
            G::batch_add_loop_2(a, *b, &mut inversion_tmp);
        }
    }
    scratch_space.clear();
}

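/// Sums the bases assigned to each bucket via batched affine additions.
/// `bucket_positions` maps each base (by `scalar_index`) to its bucket; the
/// entries of every bucket are repeatedly paired up and added, halving each
/// bucket's population per round, until every bucket holds at most one point.
/// Returns one point per bucket (zero for empty buckets).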
#[inline]
pub(super) fn batch_add<G: AffineCurve>(
    num_buckets: usize,
    bases: &[G],
    bucket_positions: &mut [BucketPosition],
) -> Vec<G> {
    assert!(bases.len() >= bucket_positions.len());
    assert!(!bases.is_empty());

    // Fetch the ideal batch size for the number of bases.
    let batch_size = batch_size(bases.len());

    // Sort the buckets by their bucket index (not scalar index).
    bucket_positions.sort_unstable();

    let mut num_scalars = bucket_positions.len();
    let mut all_ones = true;
    let mut new_scalar_length = 0;
    let mut global_counter = 0;
    let mut local_counter = 1;
    let mut number_of_bases_in_batch = 0;

    let mut instr = Vec::<(u32, u32)>::with_capacity(batch_size);
    let mut new_bases = Vec::with_capacity(bases.len());
    let mut scratch_space = Vec::with_capacity(batch_size / 2);

    // In the first pass, perform one level of the addition tree out of place,
    // writing the results to the vector `new_bases`.
    while global_counter < num_scalars {
        let current_bucket = bucket_positions[global_counter].bucket_index;
        // Measure the run of entries sharing the current bucket index.
        while global_counter + 1 < num_scalars && bucket_positions[global_counter + 1].bucket_index == current_bucket {
            global_counter += 1;
            local_counter += 1;
        }
        if current_bucket >= num_buckets as u32 {
            // Out-of-range bucket indices (e.g. the `!0u32` marker for a zero
            // digit) are skipped.
            local_counter = 1;
        } else if local_counter > 1 {
            // `all_ones` stays true only if this run pairs down to a single
            // entry, i.e. its next length is 1.
            if local_counter > 2 {
                all_ones = false;
            }
            let is_odd = local_counter % 2 == 1;
            let half = local_counter / 2;
            // Pair up adjacent entries of the run into addition instructions.
            for i in 0..half {
                instr.push((
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                ));
                bucket_positions[new_scalar_length + i] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
            }
            // A leftover entry is copied through; `!0u32` encodes a copy/no-op.
            if is_odd {
                instr.push((bucket_positions[global_counter].scalar_index, !0u32));
                bucket_positions[new_scalar_length + half] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
            }
            // Reset the local_counter and update state
            new_scalar_length += half + (local_counter % 2);
            number_of_bases_in_batch += half;
            local_counter = 1;

            // When the number of bases in a batch crosses the threshold, perform a batch
            // addition.
            if number_of_bases_in_batch >= batch_size / 2 {
                batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);

                instr.clear();
                number_of_bases_in_batch = 0;
            }
        } else {
            // A singleton bucket entry is copied through unchanged.
            instr.push((bucket_positions[global_counter].scalar_index, !0u32));
            bucket_positions[new_scalar_length] =
                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
            new_scalar_length += 1;
        }
        global_counter += 1;
    }
    if !instr.is_empty() {
        batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
        instr.clear();
    }
    global_counter = 0;
    number_of_bases_in_batch = 0;
    local_counter = 1;
    num_scalars = new_scalar_length;
    new_scalar_length = 0;

    // Subsequent passes perform the remaining levels of the addition tree in
    // place on `new_bases`, until every bucket is reduced to a single point.
    while !all_ones {
        all_ones = true;
        while global_counter < num_scalars {
            let current_bucket = bucket_positions[global_counter].bucket_index;
            while global_counter + 1 < num_scalars
                && bucket_positions[global_counter + 1].bucket_index == current_bucket
            {
                global_counter += 1;
                local_counter += 1;
            }
            if current_bucket >= num_buckets as u32 {
                local_counter = 1;
            } else if local_counter > 1 {
                // `all_ones` stays true only if this run pairs down to a single
                // entry, i.e. its next length is 1.
                if local_counter != 2 {
                    all_ones = false;
                }
                let is_odd = local_counter % 2 == 1;
                let half = local_counter / 2;
                for i in 0..half {
                    instr.push((
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                    ));
                    bucket_positions[new_scalar_length + i] =
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i];
                }
                if is_odd {
                    bucket_positions[new_scalar_length + half] = bucket_positions[global_counter];
                }
                // Reset the local_counter and update state
                new_scalar_length += half + (local_counter % 2);
                number_of_bases_in_batch += half;
                local_counter = 1;

                if number_of_bases_in_batch >= batch_size / 2 {
                    batch_add_in_place_same_slice(&mut new_bases, &instr);
                    instr.clear();
                    number_of_bases_in_batch = 0;
                }
            } else {
                bucket_positions[new_scalar_length] = bucket_positions[global_counter];
                new_scalar_length += 1;
            }
            global_counter += 1;
        }
        // If there are any remaining unprocessed instructions, perform a final
        // batch addition for this pass.
        if !instr.is_empty() {
            batch_add_in_place_same_slice(&mut new_bases, &instr);
            instr.clear();
        }
        global_counter = 0;
        number_of_bases_in_batch = 0;
        local_counter = 1;
        num_scalars = new_scalar_length;
        new_scalar_length = 0;
    }

    // Scatter the fully reduced points into one slot per bucket.
    let mut res = vec![Zero::zero(); num_buckets];
    for bucket_position in bucket_positions.iter().take(num_scalars) {
        res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
    }
    res
}

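/// Computes the bucketed sum for the window of `c` bits beginning at bit
/// `w_start`: each base is assigned to a bucket according to its scalar's
/// digit in this window, the buckets are summed with `batch_add`, and the
/// bucket sums are combined into a single window sum.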
#[inline]
fn batched_window<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    w_start: usize,
    c: usize,
) -> (G::Projective, usize) {
    // We don't need the "zero" bucket, so we only have 2^c - 1 buckets.
    let window_size = if (w_start % c) != 0 { w_start % c } else { c };
    let num_buckets = (1 << window_size) - 1;

    let mut bucket_positions: Vec<_> = scalars
        .iter()
        .enumerate()
        .map(|(scalar_index, &scalar)| {
            let mut scalar = scalar;

            // We right-shift by w_start, thus getting rid of the lower bits.
            scalar.divn(w_start as u32);

            // We take the remaining bits modulo 2^c to obtain this window's digit.
            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;

            // A zero digit maps to the out-of-range bucket index `!0u32`,
            // which `batch_add` skips.
            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
        })
        .collect();

    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);

    let mut res = G::Projective::zero();
    let mut running_sum = G::Projective::zero();
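    // The window sum is B_1 + 2*B_2 + ... + m*B_m, where bucket `i` holds the
    // bases whose digit equals `i + 1`. Traversing the buckets from the top,
    // `running_sum` is the suffix sum B_m + ... + B_i, and adding it to `res`
    // once per bucket contributes each B_i exactly i times. For m = 3:
    // (B_3) + (B_3 + B_2) + (B_3 + B_2 + B_1) = 3*B_3 + 2*B_2 + 1*B_1.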
    for b in buckets.into_iter().rev() {
        running_sum.add_assign_mixed(&b);
        res += &running_sum;
    }

    (res, window_size)
}

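/// Computes the multi-scalar multiplication `sum_i scalars[i] * bases[i]`.
/// Small instances (fewer than 15 bases) use a shared double-and-add loop;
/// larger instances use a batched-affine variant of the windowed bucket
/// (Pippenger) method, processing the windows in parallel.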
pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
    if bases.len() < 15 {
        let num_bits = G::ScalarField::size_in_bits();
        let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
        let mut bits =
            scalars.iter().map(|s| BitIteratorBE::new(s.as_ref()).skip(bigint_size - num_bits)).collect::<Vec<_>>();
        let mut sum = G::Projective::zero();

        let mut encountered_one = false;
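        // Skip the leading doublings: until the first set bit is seen, `sum`
        // is zero and doubling it would be wasted work.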
        for _ in 0..num_bits {
            if encountered_one {
                sum.double_in_place();
            }
            for (bits, base) in bits.iter_mut().zip(bases) {
                if let Some(true) = bits.next() {
                    sum.add_assign_mixed(base);
                    encountered_one = true;
                }
            }
        }
        debug_assert!(bits.iter_mut().all(|b| b.next().is_none()));
        sum
    } else {
        // Determine the window size `c` (chosen empirically).
        let c = match scalars.len() < 32 {
            true => 1,
            false => crate::msm::ln_without_floats(scalars.len()) + 2,
        };

        let num_bits = <G::ScalarField as PrimeField>::size_in_bits();

        // Each window is of size `c`.
        // We divide up the bits 0..num_bits into windows of size `c`, and
        // process each such window in parallel.
        let window_sums: Vec<_> =
            cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();

        // Split off the sum for the lowest window; it is added at the end
        // without any doubling.
        let (lowest, window_sums) = window_sums.split_first().unwrap();

        // We traverse the remaining windows from high to low.
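        // Since window i covers bits [i*c, (i+1)*c), doubling `window_size` (= c)
        // times after each addition evaluates the Horner form
        //   total = (((S_k) * 2^c + S_{k-1}) * 2^c + ... + S_1) * 2^c,
        // i.e. total = S_1 * 2^c + S_2 * 2^{2c} + ... + S_k * 2^{kc}.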
        window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
            total += sum_i;
            for _ in 0..*window_size {
                total.double_in_place();
            }
            total
        }) + lowest.0
    }
}