batched.rs
// Copyright (c) 2025 ADnet Contributors
// This file is part of the AlphaVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use alphavm_curves::{AffineCurve, ProjectiveCurve};
use alphavm_fields::{Field, One, PrimeField, Zero};
use alphavm_utilities::{BigInteger, BitIteratorBE, cfg_into_iter};

#[cfg(not(feature = "serial"))]
use rayon::prelude::*;

#[cfg(target_arch = "x86_64")]
use crate::{prefetch_slice, prefetch_slice_write};

#[derive(Copy, Clone, Debug)]
pub struct BucketPosition {
    pub bucket_index: u32,
    pub scalar_index: u32,
}

impl Eq for BucketPosition {}

impl PartialEq for BucketPosition {
    fn eq(&self, other: &Self) -> bool {
        self.bucket_index == other.bucket_index
    }
}

impl Ord for BucketPosition {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.bucket_index.cmp(&other.bucket_index)
    }
}

impl PartialOrd for BucketPosition {
    #[allow(clippy::non_canonical_partial_ord_impl)]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.bucket_index.partial_cmp(&other.bucket_index)
    }
}

/// Returns a batch size large enough to amortize the cost of an inversion,
/// while attempting to reduce strain on the CPU cache.
#[inline]
const fn batch_size(msm_size: usize) -> usize {
    // These values were determined empirically, using performance benchmarks for
    // BLS12-377 on Intel, AMD, and M1 machines, by taking the L1 and L2 cache
    // sizes and dividing them by the size of a group element (i.e. 96 bytes).
    //
    // As the algorithm itself requires caching additional values beyond the group
    // elements, the ideal batch size is smaller than expected, to accommodate
    // those values. In general, it was found that undershooting this heuristic
    // is better than overshooting it.
    if cfg!(target_arch = "x86_64") && msm_size < 500_000 {
        // Assumes an L1 cache size of 32KiB. Note that machines with larger
        // caches are not negatively impacted by this value, but machines with
        // smaller L1 caches are.
        300
    } else {
        // Assumes an L2 cache size of 1MiB. Note that machines with larger
        // caches are not negatively impacted by this value, but machines with
        // smaller L2 caches are.
        3000
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method sets
/// `bases[j] = bases[j] + bases[k]`. The state of `bases[k]` becomes
/// unspecified.
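///
/// For example (a sketch of the intended semantics only, with hypothetical
/// points `P`, `Q`, `R`): given `bases = [P, Q, R]` and `index = [(0, 2)]`,
/// after the call `bases[0] == P + R` holds, and `bases[2]` must no longer
/// be read.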
#[inline]
fn batch_add_in_place_same_slice<G: AffineCurve>(bases: &mut [G], index: &[(u32, u32)]) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data, separated by an inversion.
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], &mut y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], &mut x[*idy as usize])
        };
        G::batch_add_loop_1(a, b, &half, &mut inversion_tmp);
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter().rev();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter().rev() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], x[*idy as usize])
        };
        G::batch_add_loop_2(a, b, &mut inversion_tmp);
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method performs one of
/// two actions:
/// * if `k != !0u32`, it sets `addition_result[i] = bases[j] + bases[k]`;
/// * if `k == !0u32`, it sets `addition_result[i] = bases[j]`.
///
/// It uses `scratch_space` to store intermediate values, and clears it after
/// use.
#[inline]
fn batch_add_write<G: AffineCurve>(
    bases: &[G],
    index: &[(u32, u32)],
    addition_result: &mut Vec<G>,
    scratch_space: &mut Vec<Option<G>>,
) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data, separated by an inversion.
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice_write!(G, bases, bases, prefetch_iter);

        if *idy == !0u32 {
            addition_result.push(bases[*idx as usize]);
            scratch_space.push(None);
        } else {
            let (mut a, mut b) = (bases[*idx as usize], bases[*idy as usize]);
            G::batch_add_loop_1(&mut a, &mut b, &half, &mut inversion_tmp);
            addition_result.push(a);
            scratch_space.push(Some(b));
        }
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    for (a, op_b) in addition_result.iter_mut().rev().zip(scratch_space.iter().rev()) {
        if let Some(b) = op_b {
            G::batch_add_loop_2(a, *b, &mut inversion_tmp);
        }
    }
    scratch_space.clear();
}

#[inline]
pub(super) fn batch_add<G: AffineCurve>(
    num_buckets: usize,
    bases: &[G],
    bucket_positions: &mut [BucketPosition],
) -> Vec<G> {
    assert!(bases.len() >= bucket_positions.len());
    assert!(!bases.is_empty());

    // Fetch the ideal batch size for the number of bases.
    let batch_size = batch_size(bases.len());

    // Sort the buckets by their bucket index (not scalar index).
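    // Sorting groups entries with equal `bucket_index` contiguously, so that
    // all points belonging to one bucket can be paired up in a single pass.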
    bucket_positions.sort_unstable();

    let mut num_scalars = bucket_positions.len();
    let mut all_ones = true;
    let mut new_scalar_length = 0;
    let mut global_counter = 0;
    let mut local_counter = 1;
    let mut number_of_bases_in_batch = 0;

    let mut instr = Vec::<(u32, u32)>::with_capacity(batch_size);
    let mut new_bases = Vec::with_capacity(bases.len());
    let mut scratch_space = Vec::with_capacity(batch_size / 2);

    // In the first loop, copy the results of the first in-place addition tree to
    // the vector `new_bases`.
    while global_counter < num_scalars {
        let current_bucket = bucket_positions[global_counter].bucket_index;
        while global_counter + 1 < num_scalars && bucket_positions[global_counter + 1].bucket_index == current_bucket {
            global_counter += 1;
            local_counter += 1;
        }
        if current_bucket >= num_buckets as u32 {
            local_counter = 1;
        } else if local_counter > 1 {
            // `all_ones` is false if a bucket still has more than one point
            // after this pass (i.e. its run is longer than 2).
            if local_counter > 2 {
                all_ones = false;
            }
            let is_odd = local_counter % 2 == 1;
            let half = local_counter / 2;
            for i in 0..half {
                instr.push((
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                ));
                bucket_positions[new_scalar_length + i] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
            }
            if is_odd {
                instr.push((bucket_positions[global_counter].scalar_index, !0u32));
                bucket_positions[new_scalar_length + half] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
            }
            // Reset the local_counter and update state.
            new_scalar_length += half + (local_counter % 2);
            number_of_bases_in_batch += half;
            local_counter = 1;

            // When the number of bases in a batch crosses the threshold, perform a batch
            // addition.
            if number_of_bases_in_batch >= batch_size / 2 {
                // We need instructions for copying data in the case of no-ops.
                // We encode no-ops/copies as !0u32.
                batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);

                instr.clear();
                number_of_bases_in_batch = 0;
            }
        } else {
            instr.push((bucket_positions[global_counter].scalar_index, !0u32));
            bucket_positions[new_scalar_length] =
                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
            new_scalar_length += 1;
        }
        global_counter += 1;
    }
    if !instr.is_empty() {
        batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
        instr.clear();
    }
    global_counter = 0;
    number_of_bases_in_batch = 0;
    local_counter = 1;
    num_scalars = new_scalar_length;
    new_scalar_length = 0;

    // Next, perform all the updates in place.
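    // Each pass pairs up the remaining points in every bucket, roughly halving
    // the number of points per bucket; the loop ends once every bucket has been
    // reduced to a single point (`all_ones`).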
    while !all_ones {
        all_ones = true;
        while global_counter < num_scalars {
            let current_bucket = bucket_positions[global_counter].bucket_index;
            while global_counter + 1 < num_scalars
                && bucket_positions[global_counter + 1].bucket_index == current_bucket
            {
                global_counter += 1;
                local_counter += 1;
            }
            if current_bucket >= num_buckets as u32 {
                local_counter = 1;
            } else if local_counter > 1 {
                // `all_ones` is false if a bucket still has more than one point
                // after this pass (i.e. its run is longer than 2).
                if local_counter != 2 {
                    all_ones = false;
                }
                let is_odd = local_counter % 2 == 1;
                let half = local_counter / 2;
                for i in 0..half {
                    instr.push((
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                    ));
                    bucket_positions[new_scalar_length + i] =
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i];
                }
                if is_odd {
                    bucket_positions[new_scalar_length + half] = bucket_positions[global_counter];
                }
                // Reset the local_counter and update state.
                new_scalar_length += half + (local_counter % 2);
                number_of_bases_in_batch += half;
                local_counter = 1;

                if number_of_bases_in_batch >= batch_size / 2 {
                    batch_add_in_place_same_slice(&mut new_bases, &instr);
                    instr.clear();
                    number_of_bases_in_batch = 0;
                }
            } else {
                bucket_positions[new_scalar_length] = bucket_positions[global_counter];
                new_scalar_length += 1;
            }
            global_counter += 1;
        }
        // If there are any remaining unprocessed instructions, perform a final
        // batch addition.
        if !instr.is_empty() {
            batch_add_in_place_same_slice(&mut new_bases, &instr);
            instr.clear();
        }
        global_counter = 0;
        number_of_bases_in_batch = 0;
        local_counter = 1;
        num_scalars = new_scalar_length;
        new_scalar_length = 0;
    }

    let mut res = vec![Zero::zero(); num_buckets];
    for bucket_position in bucket_positions.iter().take(num_scalars) {
        res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
    }
    res
}

#[inline]
fn batched_window<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    w_start: usize,
    c: usize,
) -> (G::Projective, usize) {
    // We don't need the "zero" bucket, so we only have 2^c - 1 buckets.
    let window_size = if (w_start % c) != 0 { w_start % c } else { c };
    let num_buckets = (1 << window_size) - 1;

    let mut bucket_positions: Vec<_> = scalars
        .iter()
        .enumerate()
        .map(|(scalar_index, &scalar)| {
            let mut scalar = scalar;

            // We right-shift by w_start, thus getting rid of the lower bits.
            scalar.divn(w_start as u32);

            // We mod the remaining bits by the window size.
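            // Only the lowest limb is needed here, since `c` is always far
            // smaller than the 64-bit limb width for any practical MSM size.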
            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;

            // A scalar of 0 maps to bucket index !0u32, which is >= num_buckets
            // and is therefore skipped by `batch_add`.
            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
        })
        .collect();

    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);

    let mut res = G::Projective::zero();
    let mut running_sum = G::Projective::zero();
    for b in buckets.into_iter().rev() {
        running_sum.add_assign_mixed(&b);
        res += &running_sum;
    }

    (res, window_size)
}

pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
    if bases.len() < 15 {
        let num_bits = G::ScalarField::size_in_bits();
        let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
        let mut bits =
            scalars.iter().map(|s| BitIteratorBE::new(s.as_ref()).skip(bigint_size - num_bits)).collect::<Vec<_>>();
        let mut sum = G::Projective::zero();

        let mut encountered_one = false;
        for _ in 0..num_bits {
            if encountered_one {
                sum.double_in_place();
            }
            for (bits, base) in bits.iter_mut().zip(bases) {
                if let Some(true) = bits.next() {
                    sum.add_assign_mixed(base);
                    encountered_one = true;
                }
            }
        }
        debug_assert!(bits.iter_mut().all(|b| b.next().is_none()));
        sum
    } else {
        // Determine the window bit size `c` (chosen empirically).
        let c = match scalars.len() < 32 {
            true => 1,
            false => crate::msm::ln_without_floats(scalars.len()) + 2,
        };

        let num_bits = <G::ScalarField as PrimeField>::size_in_bits();

        // Each window is of size `c`.
        // We divide up the bits 0..num_bits into windows of size `c`, and
        // process each such window in parallel.
        let window_sums: Vec<_> =
            cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();

        // We store the sum for the lowest window.
        let (lowest, window_sums) = window_sums.split_first().unwrap();

        // We're traversing windows from high to low.
        window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
            total += sum_i;
            for _ in 0..*window_size {
                total.double_in_place();
            }
            total
        }) + lowest.0
    }
}
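
// The following is a minimal reference sketch (not part of the library's
// public API or test suite): a naive O(n * b) double-and-add MSM that could be
// used to cross-check `msm` in tests. It relies only on the traits already
// imported above; the module and function names are illustrative.
#[cfg(test)]
mod naive_reference {
    use super::*;

    /// Computes the MSM one scalar at a time, MSB-first, for cross-checking.
    #[allow(dead_code)]
    pub(crate) fn naive_msm<G: AffineCurve>(
        bases: &[G],
        scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    ) -> G::Projective {
        let mut acc = G::Projective::zero();
        for (base, scalar) in bases.iter().zip(scalars) {
            let mut point = G::Projective::zero();
            // Classic double-and-add over the big-endian bits of the scalar.
            for bit in BitIteratorBE::new(scalar.as_ref()) {
                point.double_in_place();
                if bit {
                    point.add_assign_mixed(base);
                }
            }
            acc += &point;
        }
        acc
    }
}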