batched.rs
// Copyright (c) 2025-2026 ACDC Network
// This file is part of the alphavm library.
//
// Alpha Chain | Delta Chain Protocol
// International Monetary Graphite.
//
// Derived from Aleo (https://aleo.org) and ProvableHQ (https://provable.com).
// They built world-class ZK infrastructure. We installed the EASY button.
// Their cryptography: elegant. Our modifications: bureaucracy-compatible.
// Original brilliance: theirs. Robert's Rules: ours. Bugs: definitely ours.
//
// Original Aleo/ProvableHQ code subject to Apache 2.0 https://www.apache.org/licenses/LICENSE-2.0
// All modifications and new work: CC0 1.0 Universal Public Domain Dedication.
// No rights reserved. No permission required. No warranty. No refunds.
//
// https://creativecommons.org/publicdomain/zero/1.0/
// SPDX-License-Identifier: CC0-1.0

use alphavm_curves::{AffineCurve, ProjectiveCurve};
use alphavm_fields::{Field, One, PrimeField, Zero};
use alphavm_utilities::{cfg_into_iter, BigInteger, BitIteratorBE};

#[cfg(not(feature = "serial"))]
use rayon::prelude::*;

#[cfg(target_arch = "x86_64")]
use crate::{prefetch_slice, prefetch_slice_write};

#[derive(Copy, Clone, Debug)]
pub struct BucketPosition {
    pub bucket_index: u32,
    pub scalar_index: u32,
}

impl Eq for BucketPosition {}

impl PartialEq for BucketPosition {
    fn eq(&self, other: &Self) -> bool {
        self.bucket_index == other.bucket_index
    }
}

impl Ord for BucketPosition {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.bucket_index.cmp(&other.bucket_index)
    }
}

impl PartialOrd for BucketPosition {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Returns a batch size large enough to amortize the cost of an inversion,
/// while attempting to reduce strain on the CPU cache.
#[inline]
const fn batch_size(msm_size: usize) -> usize {
    // These values were determined empirically using performance benchmarks for
    // BLS12-377 on Intel, AMD, and M1 machines, by taking the L1 and L2 cache
    // sizes and dividing them by the size of a group element (i.e. 96 bytes).
    //
    // As the algorithm itself requires caching additional values beyond the group
    // elements, the ideal batch size is less than expected, to accommodate
    // those values. In general, it was found that undershooting is better than
    // overshooting this heuristic.
    if cfg!(target_arch = "x86_64") && msm_size < 500_000 {
        // Assumes an L1 cache size of 32KiB. Note that larger cache sizes
        // are not negatively impacted by this value, however smaller L1 cache sizes
        // are.
        300
    } else {
        // Assumes an L2 cache size of 1MiB. Note that larger cache sizes
        // are not negatively impacted by this value, however smaller L2 cache sizes
        // are.
        3000
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method sets
/// `bases[j] = bases[j] + bases[k]`. The state of `bases[k]` becomes
/// unspecified.
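///
/// For example (illustrative only), `index = [(0, 3), (1, 2)]` requests
/// `bases[0] += bases[3]` and `bases[1] += bases[2]`, with both affine additions
/// sharing a single field inversion via the two-pass batch-inversion trick below.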
#[inline]
fn batch_add_in_place_same_slice<G: AffineCurve>(bases: &mut [G], index: &[(u32, u32)]) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], &mut y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], &mut x[*idy as usize])
        };
        G::batch_add_loop_1(a, b, &half, &mut inversion_tmp);
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter().rev();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter().rev() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], x[*idy as usize])
        };
        G::batch_add_loop_2(a, b, &mut inversion_tmp);
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method performs one of
/// two actions:
/// * `addition_result[i] = bases[j] + bases[k]`
/// * `addition_result[i] = bases[j]`
///
/// It uses `scratch_space` to store intermediate values, and clears it after
/// use.
#[inline]
fn batch_add_write<G: AffineCurve>(
    bases: &[G],
    index: &[(u32, u32)],
    addition_result: &mut Vec<G>,
    scratch_space: &mut Vec<Option<G>>,
) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice_write!(G, bases, bases, prefetch_iter);

        if *idy == !0u32 {
            addition_result.push(bases[*idx as usize]);
            scratch_space.push(None);
        } else {
            let (mut a, mut b) = (bases[*idx as usize], bases[*idy as usize]);
            G::batch_add_loop_1(&mut a, &mut b, &half, &mut inversion_tmp);
            addition_result.push(a);
            scratch_space.push(Some(b));
        }
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    for (a, op_b) in addition_result.iter_mut().rev().zip(scratch_space.iter().rev()) {
        if let Some(b) = op_b {
            G::batch_add_loop_2(a, *b, &mut inversion_tmp);
        }
    }
    scratch_space.clear();
}

#[inline]
pub(super) fn batch_add<G: AffineCurve>(
    num_buckets: usize,
    bases: &[G],
    bucket_positions: &mut [BucketPosition],
) -> Vec<G> {
    assert!(bases.len() >= bucket_positions.len());
    assert!(!bases.is_empty());

    // Fetch the ideal batch size for the number of bases.
    let batch_size = batch_size(bases.len());

    // Sort the buckets by their bucket index (not scalar index).
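    // Note that `BucketPosition` orders and compares on `bucket_index` alone, so this
    // unstable sort groups together all positions that feed the same bucket; the relative
    // order of entries within a bucket is irrelevant to the pairwise additions below.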
    bucket_positions.sort_unstable();

    let mut num_scalars = bucket_positions.len();
    let mut all_ones = true;
    let mut new_scalar_length = 0;
    let mut global_counter = 0;
    let mut local_counter = 1;
    let mut number_of_bases_in_batch = 0;

    let mut instr = Vec::<(u32, u32)>::with_capacity(batch_size);
    let mut new_bases = Vec::with_capacity(bases.len());
    let mut scratch_space = Vec::with_capacity(batch_size / 2);

    // In the first loop, copy the results of the first in-place addition tree to
    // the vector `new_bases`.
    while global_counter < num_scalars {
        let current_bucket = bucket_positions[global_counter].bucket_index;
        while global_counter + 1 < num_scalars && bucket_positions[global_counter + 1].bucket_index == current_bucket {
            global_counter += 1;
            local_counter += 1;
        }
        if current_bucket >= num_buckets as u32 {
            local_counter = 1;
        } else if local_counter > 1 {
            // `all_ones` is false if the next run length is not 1
            if local_counter > 2 {
                all_ones = false;
            }
            let is_odd = local_counter % 2 == 1;
            let half = local_counter / 2;
            for i in 0..half {
                instr.push((
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                ));
                bucket_positions[new_scalar_length + i] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
            }
            if is_odd {
                instr.push((bucket_positions[global_counter].scalar_index, !0u32));
                bucket_positions[new_scalar_length + half] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
            }
            // Reset the local_counter and update state
            new_scalar_length += half + (local_counter % 2);
            number_of_bases_in_batch += half;
            local_counter = 1;

            // When the number of bases in a batch crosses the threshold, perform a batch
            // addition.
            if number_of_bases_in_batch >= batch_size / 2 {
                // We need instructions for copying data in the case of noops.
                // We encode noops/copies as !0u32
                batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);

                instr.clear();
                number_of_bases_in_batch = 0;
            }
        } else {
            instr.push((bucket_positions[global_counter].scalar_index, !0u32));
            bucket_positions[new_scalar_length] =
                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
            new_scalar_length += 1;
        }
        global_counter += 1;
    }
    if !instr.is_empty() {
        batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
        instr.clear();
    }
    global_counter = 0;
    number_of_bases_in_batch = 0;
    local_counter = 1;
    num_scalars = new_scalar_length;
    new_scalar_length = 0;

    // Next, perform all the updates in place.
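    // Each pass of the loop below adds adjacent pairs within every bucket, roughly halving
    // the number of surviving entries per bucket; the passes repeat until every bucket has
    // collapsed to a single point (i.e. `all_ones` remains true for an entire pass).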
    while !all_ones {
        all_ones = true;
        while global_counter < num_scalars {
            let current_bucket = bucket_positions[global_counter].bucket_index;
            while global_counter + 1 < num_scalars
                && bucket_positions[global_counter + 1].bucket_index == current_bucket
            {
                global_counter += 1;
                local_counter += 1;
            }
            if current_bucket >= num_buckets as u32 {
                local_counter = 1;
            } else if local_counter > 1 {
                // `all_ones` is false if the next run length is not 1
                if local_counter != 2 {
                    all_ones = false;
                }
                let is_odd = local_counter % 2 == 1;
                let half = local_counter / 2;
                for i in 0..half {
                    instr.push((
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                    ));
                    bucket_positions[new_scalar_length + i] =
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i];
                }
                if is_odd {
                    bucket_positions[new_scalar_length + half] = bucket_positions[global_counter];
                }
                // Reset the local_counter and update state
                new_scalar_length += half + (local_counter % 2);
                number_of_bases_in_batch += half;
                local_counter = 1;

                if number_of_bases_in_batch >= batch_size / 2 {
                    batch_add_in_place_same_slice(&mut new_bases, &instr);
                    instr.clear();
                    number_of_bases_in_batch = 0;
                }
            } else {
                bucket_positions[new_scalar_length] = bucket_positions[global_counter];
                new_scalar_length += 1;
            }
            global_counter += 1;
        }
        // If there are any remaining unprocessed instructions, proceed to perform batch
        // addition.
        if !instr.is_empty() {
            batch_add_in_place_same_slice(&mut new_bases, &instr);
            instr.clear();
        }
        global_counter = 0;
        number_of_bases_in_batch = 0;
        local_counter = 1;
        num_scalars = new_scalar_length;
        new_scalar_length = 0;
    }

    let mut res = vec![Zero::zero(); num_buckets];
    for bucket_position in bucket_positions.iter().take(num_scalars) {
        res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
    }
    res
}

#[inline]
fn batched_window<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    w_start: usize,
    c: usize,
) -> (G::Projective, usize) {
    // We don't need the "zero" bucket, so we only have 2^c - 1 buckets
    let window_size = if !w_start.is_multiple_of(c) { w_start % c } else { c };
    let num_buckets = (1 << window_size) - 1;

    let mut bucket_positions: Vec<_> = scalars
        .iter()
        .enumerate()
        .map(|(scalar_index, &scalar)| {
            let mut scalar = scalar;

            // We right-shift by w_start, thus getting rid of the lower bits.
            scalar.divn(w_start as u32);

            // We mod the remaining bits by the window size.
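            // A zero digit maps to bucket index `(0 - 1) as u32 == !0u32`, which is
            // `>= num_buckets` and is therefore dropped by `batch_add` (the "zero"
            // bucket is never materialized).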
            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;

            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
        })
        .collect();

    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);

    let mut res = G::Projective::zero();
    let mut running_sum = G::Projective::zero();
    for b in buckets.into_iter().rev() {
        running_sum.add_assign_mixed(&b);
        res += &running_sum;
    }

    (res, window_size)
}

pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
    if bases.len() < 15 {
        let num_bits = G::ScalarField::size_in_bits();
        let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
        let mut bits =
            scalars.iter().map(|s| BitIteratorBE::new(s.as_ref()).skip(bigint_size - num_bits)).collect::<Vec<_>>();
        let mut sum = G::Projective::zero();

        let mut encountered_one = false;
        for _ in 0..num_bits {
            if encountered_one {
                sum.double_in_place();
            }
            for (bits, base) in bits.iter_mut().zip(bases) {
                if let Some(true) = bits.next() {
                    sum.add_assign_mixed(base);
                    encountered_one = true;
                }
            }
        }
        debug_assert!(bits.iter_mut().all(|b| b.next().is_none()));
        sum
    } else {
        // Determine the bucket size `c` (chosen empirically).
        let c = match scalars.len() < 32 {
            true => 1,
            false => crate::msm::ln_without_floats(scalars.len()) + 2,
        };

        let num_bits = <G::ScalarField as PrimeField>::size_in_bits();

        // Each window is of size `c`.
        // We divide up the bits 0..num_bits into windows of size `c`, and
        // in parallel process each such window.
        let window_sums: Vec<_> =
            cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();

        // We store the sum for the lowest window.
        let (lowest, window_sums) = window_sums.split_first().unwrap();

        // We're traversing windows from high to low.
        window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
            total += sum_i;
            for _ in 0..*window_size {
                total.double_in_place();
            }
            total
        }) + lowest.0
    }
}
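
// What follows is an illustrative sketch, not part of the code derived from Aleo/ProvableHQ:
// a naive double-and-add MSM that can serve as a correctness reference for `msm` in tests.
// The helper name is arbitrary, and it only reuses APIs already exercised above
// (`BitIteratorBE`, `add_assign_mixed`, `double_in_place`); generating test inputs is left
// to the caller.
#[cfg(test)]
#[allow(dead_code)]
fn naive_msm_reference<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
) -> G::Projective {
    let mut total = G::Projective::zero();
    for (base, scalar) in bases.iter().zip(scalars) {
        // Standard double-and-add over the big-endian bits of the scalar.
        let mut acc = G::Projective::zero();
        for bit in BitIteratorBE::new(scalar.as_ref()) {
            acc.double_in_place();
            if bit {
                acc.add_assign_mixed(base);
            }
        }
        total += &acc;
    }
    total
}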