// algorithms/src/fft/domain.rs
   1  // Copyright (c) 2025 ADnet Contributors
   2  // This file is part of the AlphaVM library.
   3  
   4  // Licensed under the Apache License, Version 2.0 (the "License");
   5  // you may not use this file except in compliance with the License.
   6  // You may obtain a copy of the License at:
   7  
   8  // http://www.apache.org/licenses/LICENSE-2.0
   9  
  10  // Unless required by applicable law or agreed to in writing, software
  11  // distributed under the License is distributed on an "AS IS" BASIS,
  12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13  // See the License for the specific language governing permissions and
  14  // limitations under the License.
  15  
  16  //! This module contains an `EvaluationDomain` abstraction for
  17  //! performing various kinds of polynomial arithmetic on top of
  18  //! the scalar field.
  19  //!
  20  //! In pairing-based SNARKs like GM17, we need to calculate
  21  //! a quotient polynomial over a target polynomial with roots
  22  //! at distinct points associated with each constraint of the
  23  //! constraint system. In order to be efficient, we choose these
  24  //! roots to be the powers of a 2^n root of unity in the field.
  25  //! This allows us to perform polynomial operations in O(n)
  26  //! by performing an O(n log n) FFT over such a domain.
  27  
  28  use crate::{
  29      cfg_chunks_mut, cfg_into_iter, cfg_iter, cfg_iter_mut,
  30      fft::{DomainCoeff, SparsePolynomial},
  31  };
  32  use alphavm_fields::{FftField, FftParameters, Field, batch_inversion};
  33  #[cfg(not(feature = "serial"))]
  34  use alphavm_utilities::max_available_threads;
  35  use alphavm_utilities::{execute_with_max_available_threads, serialize::*};
  36  
  37  use rand::Rng;
  38  use std::{borrow::Cow, fmt};
  39  
  40  use anyhow::{Result, ensure};
  41  
  42  #[cfg(not(feature = "serial"))]
  43  use rayon::prelude::*;
  44  
  45  #[cfg(feature = "serial")]
  46  use itertools::Itertools;
  47  
  48  /// Returns the ceiling of the base-2 logarithm of `x`.
  49  ///
  50  /// ```
  51  /// use alphavm_algorithms::fft::domain::log2;
  52  ///
  53  /// assert_eq!(log2(16), 4);
  54  /// assert_eq!(log2(17), 5);
  55  /// assert_eq!(log2(1), 0);
  56  /// assert_eq!(log2(0), 0);
  57  /// assert_eq!(log2(usize::MAX), (core::mem::size_of::<usize>() * 8) as u32);
  58  /// assert_eq!(log2(1 << 15), 15);
  59  /// assert_eq!(log2(2usize.pow(18)), 18);
  60  /// ```
  61  pub fn log2(x: usize) -> u32 {
  62      if x == 0 {
  63          0
  64      } else if x.is_power_of_two() {
  65          1usize.leading_zeros() - x.leading_zeros()
  66      } else {
  67          0usize.leading_zeros() - x.leading_zeros()
  68      }
  69  }
  70  
// Minimum size of a parallelized chunk; below this, the per-task scheduling
// overhead outweighs the benefit of splitting work across threads.
#[allow(unused)]
#[cfg(not(feature = "serial"))]
const MIN_PARALLEL_CHUNK_SIZE: usize = 1 << 7;
  75  
/// Defines a domain over which finite field (I)FFTs can be performed. Works
/// only for fields that have a large multiplicative subgroup of size that is
/// a power-of-2.
///
/// All fields below are derived from `size` and the field's root of unity;
/// they are cached here so that (I)FFTs do not recompute inverses.
/// NOTE: field order matters for `CanonicalSerialize`/`CanonicalDeserialize`.
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct EvaluationDomain<F: FftField> {
    /// The size of the domain (always a power of two).
    pub size: u64,
    /// `log_2(self.size)`.
    pub log_size_of_group: u32,
    /// Size of the domain as a field element.
    pub size_as_field_element: F,
    /// Inverse of the size in the field (used to scale IFFT outputs).
    pub size_inv: F,
    /// A generator of the subgroup, i.e. a primitive `size`-th root of unity.
    pub group_gen: F,
    /// Inverse of the generator of the subgroup (used for IFFTs).
    pub group_gen_inv: F,
    /// Inverse of the multiplicative generator of the finite field
    /// (used to undo the coset shift in coset IFFTs).
    pub generator_inv: F,
}
  96  
  97  impl<F: FftField> fmt::Debug for EvaluationDomain<F> {
  98      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  99          write!(f, "Multiplicative subgroup of size {}", self.size)
 100      }
 101  }
 102  
 103  impl<F: FftField> EvaluationDomain<F> {
 104      /// Sample an element that is *not* in the domain.
 105      pub fn sample_element_outside_domain<R: Rng>(&self, rng: &mut R) -> F {
 106          let mut t = F::rand(rng);
 107          while self.evaluate_vanishing_polynomial(t).is_zero() {
 108              t = F::rand(rng);
 109          }
 110          t
 111      }
 112  
 113      /// Construct a domain that is large enough for evaluations of a polynomial
 114      /// having `num_coeffs` coefficients.
 115      pub fn new(num_coeffs: usize) -> Option<Self> {
 116          // Compute the size of our evaluation domain
 117          let size = num_coeffs.checked_next_power_of_two()? as u64;
 118          let log_size_of_group = size.trailing_zeros();
 119  
 120          // libfqfft uses > https://github.com/scipr-lab/libfqfft/blob/e0183b2cef7d4c5deb21a6eaf3fe3b586d738fe0/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc#L33
 121          if log_size_of_group > F::FftParameters::TWO_ADICITY {
 122              return None;
 123          }
 124  
 125          // Compute the generator for the multiplicative subgroup.
 126          // It should be the 2^(log_size_of_group) root of unity.
 127          let group_gen = F::get_root_of_unity(size as usize)?;
 128  
 129          // Check that it is indeed the 2^(log_size_of_group) root of unity.
 130          debug_assert_eq!(group_gen.pow([size]), F::one());
 131  
 132          let size_as_field_element = F::from(size);
 133          let size_inv = size_as_field_element.inverse()?;
 134  
 135          Some(EvaluationDomain {
 136              size,
 137              log_size_of_group,
 138              size_as_field_element,
 139              size_inv,
 140              group_gen,
 141              group_gen_inv: group_gen.inverse()?,
 142              generator_inv: F::multiplicative_generator().inverse()?,
 143          })
 144      }
 145  
 146      /// Return the size of a domain that is large enough for evaluations of a
 147      /// polynomial having `num_coeffs` coefficients.
 148      pub fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
 149          let size = num_coeffs.checked_next_power_of_two()?;
 150          if size.trailing_zeros() <= F::FftParameters::TWO_ADICITY { Some(size) } else { None }
 151      }
 152  
 153      /// Return the size of `self`.
 154      pub fn size(&self) -> usize {
 155          self.size as usize
 156      }
 157  
 158      /// Compute an FFT.
 159      pub fn fft<T: DomainCoeff<F>>(&self, coeffs: &[T]) -> Vec<T> {
 160          let mut coeffs = coeffs.to_vec();
 161          self.fft_in_place(&mut coeffs);
 162          coeffs
 163      }
 164  
 165      /// Compute an FFT, modifying the vector in place.
 166      pub fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
 167          execute_with_max_available_threads(|| {
 168              coeffs.resize(self.size(), T::zero());
 169              self.in_order_fft_in_place(&mut *coeffs);
 170          });
 171      }
 172  
 173      /// Compute an IFFT.
 174      pub fn ifft<T: DomainCoeff<F>>(&self, evals: &[T]) -> Vec<T> {
 175          let mut evals = evals.to_vec();
 176          self.ifft_in_place(&mut evals);
 177          evals
 178      }
 179  
 180      /// Compute an IFFT, modifying the vector in place.
 181      #[inline]
 182      pub fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
 183          execute_with_max_available_threads(|| {
 184              evals.resize(self.size(), T::zero());
 185              self.in_order_ifft_in_place(&mut *evals);
 186          });
 187      }
 188  
 189      /// Compute an FFT over a coset of the domain.
 190      pub fn coset_fft<T: DomainCoeff<F>>(&self, coeffs: &[T]) -> Vec<T> {
 191          let mut coeffs = coeffs.to_vec();
 192          self.coset_fft_in_place(&mut coeffs);
 193          coeffs
 194      }
 195  
 196      /// Compute an FFT over a coset of the domain, modifying the input vector
 197      /// in place.
 198      pub fn coset_fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
 199          execute_with_max_available_threads(|| {
 200              Self::distribute_powers(coeffs, F::multiplicative_generator());
 201              self.fft_in_place(coeffs);
 202          });
 203      }
 204  
 205      /// Compute an IFFT over a coset of the domain.
 206      pub fn coset_ifft<T: DomainCoeff<F>>(&self, evals: &[T]) -> Vec<T> {
 207          let mut evals = evals.to_vec();
 208          self.coset_ifft_in_place(&mut evals);
 209          evals
 210      }
 211  
 212      /// Compute an IFFT over a coset of the domain, modifying the input vector
 213      /// in place.
 214      pub fn coset_ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
 215          execute_with_max_available_threads(|| {
 216              evals.resize(self.size(), T::zero());
 217              self.in_order_coset_ifft_in_place(&mut *evals);
 218          });
 219      }
 220  
 221      /// Multiply the `i`-th element of `coeffs` with `g^i`.
 222      fn distribute_powers<T: DomainCoeff<F>>(coeffs: &mut [T], g: F) {
 223          Self::distribute_powers_and_mul_by_const(coeffs, g, F::one());
 224      }
 225  
 226      /// Multiply the `i`-th element of `coeffs` with `c*g^i`.
 227      #[cfg(feature = "serial")]
 228      fn distribute_powers_and_mul_by_const<T: DomainCoeff<F>>(coeffs: &mut [T], g: F, c: F) {
 229          // invariant: pow = c*g^i at the ith iteration of the loop
 230          let mut pow = c;
 231          coeffs.iter_mut().for_each(|coeff| {
 232              *coeff *= pow;
 233              pow *= &g
 234          })
 235      }
 236  
 237      /// Multiply the `i`-th element of `coeffs` with `c*g^i`.
 238      #[cfg(not(feature = "serial"))]
 239      fn distribute_powers_and_mul_by_const<T: DomainCoeff<F>>(coeffs: &mut [T], g: F, c: F) {
 240          let min_parallel_chunk_size = 1024;
 241          let num_cpus_available = max_available_threads();
 242          let num_elem_per_thread = core::cmp::max(coeffs.len() / num_cpus_available, min_parallel_chunk_size);
 243  
 244          cfg_chunks_mut!(coeffs, num_elem_per_thread).enumerate().for_each(|(i, chunk)| {
 245              let offset = c * g.pow([(i * num_elem_per_thread) as u64]);
 246              let mut pow = offset;
 247              chunk.iter_mut().for_each(|coeff| {
 248                  *coeff *= pow;
 249                  pow *= &g
 250              })
 251          });
 252      }
 253  
 254      /// Evaluate all the lagrange polynomials defined by this domain at the
 255      /// point `tau`.
 256      pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
 257          // Evaluate all Lagrange polynomials
 258          let size = self.size as usize;
 259          let t_size = tau.pow([self.size]);
 260          let one = F::one();
 261          if t_size.is_one() {
 262              let mut u = vec![F::zero(); size];
 263              let mut omega_i = one;
 264              for x in u.iter_mut().take(size) {
 265                  if omega_i == tau {
 266                      *x = one;
 267                      break;
 268                  }
 269                  omega_i *= &self.group_gen;
 270              }
 271              u
 272          } else {
 273              let mut l = (t_size - one) * self.size_inv;
 274              let mut r = one;
 275              let mut u = vec![F::zero(); size];
 276              let mut ls = vec![F::zero(); size];
 277              for i in 0..size {
 278                  u[i] = tau - r;
 279                  ls[i] = l;
 280                  l *= &self.group_gen;
 281                  r *= &self.group_gen;
 282              }
 283  
 284              batch_inversion(u.as_mut_slice());
 285              cfg_iter_mut!(u).zip_eq(ls).for_each(|(tau_minus_r, l)| {
 286                  *tau_minus_r = l * *tau_minus_r;
 287              });
 288              u
 289          }
 290      }
 291  
 292      /// Return the sparse vanishing polynomial.
 293      pub fn vanishing_polynomial(&self) -> SparsePolynomial<F> {
 294          let coeffs = [(0, -F::one()), (self.size(), F::one())];
 295          SparsePolynomial::from_coefficients(coeffs)
 296      }
 297  
 298      /// This evaluates the vanishing polynomial for this domain at tau.
 299      /// For multiplicative subgroups, this polynomial is `z(X) = X^self.size -
 300      /// 1`.
 301      pub fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
 302          tau.pow([self.size]) - F::one()
 303      }
 304  
 305      /// Return an iterator over the elements of the domain.
 306      pub fn elements(&self) -> Elements<F> {
 307          Elements { cur_elem: F::one(), cur_pow: 0, domain: *self }
 308      }
 309  
 310      /// The target polynomial is the zero polynomial in our
 311      /// evaluation domain, so we must perform division over
 312      /// a coset.
 313      pub fn divide_by_vanishing_poly_on_coset_in_place(&self, evals: &mut [F]) {
 314          let i = self.evaluate_vanishing_polynomial(F::multiplicative_generator()).inverse().unwrap();
 315  
 316          cfg_iter_mut!(evals).for_each(|eval| *eval *= &i);
 317      }
 318  
 319      /// Given an index in the `other` subdomain, return an index into this
 320      /// domain `self` This assumes the `other`'s elements are also `self`'s
 321      /// first elements
 322      pub fn reindex_by_subdomain(&self, other: &Self, index: usize) -> Result<usize> {
 323          ensure!(self.size() > other.size(), "other.size() must be smaller than self.size()");
 324  
 325          // Let this subgroup be G, and the subgroup we're re-indexing by be S.
 326          // Since its a subgroup, the 0th element of S is at index 0 in G, the first
 327          // element of S is at index |G|/|S|, the second at 2*|G|/|S|, etc.
 328          // Thus for an index i that corresponds to S, the index in G is i*|G|/|S|
 329          let period = self.size() / other.size();
 330          if index < other.size() {
 331              Ok(index * period)
 332          } else {
 333              // Let i now be the index of this element in G \ S
 334              // Let x be the number of elements in G \ S, for every element in S. Then x =
 335              // (|G|/|S| - 1). At index i in G \ S, the number of elements in S
 336              // that appear before the index in G to which i corresponds to, is
 337              // floor(i / x) + 1. The +1 is because index 0 of G is S_0, so the
 338              // position is offset by at least one. The floor(i / x) term is
 339              // because after x elements in G \ S, there is one more element from S
 340              // that will have appeared in G.
 341              let i = index - other.size();
 342              let x = period - 1;
 343              Ok(i + (i / x) + 1)
 344          }
 345      }
 346  
 347      /// Perform O(n) multiplication of two polynomials that are presented by
 348      /// their evaluations in the domain.
 349      /// Returns the evaluations of the product over the domain.
 350      pub fn mul_polynomials_in_evaluation_domain(&self, self_evals: Vec<F>, other_evals: &[F]) -> Result<Vec<F>> {
 351          let mut result = self_evals;
 352  
 353          ensure!(result.len() == other_evals.len());
 354          cfg_iter_mut!(result).zip_eq(other_evals).for_each(|(a, b)| *a *= b);
 355  
 356          Ok(result)
 357      }
 358  }
 359  
impl<F: FftField> EvaluationDomain<F> {
    /// Precomputes the forward roots of unity for this domain, so that
    /// repeated FFTs can reuse the table instead of rebuilding it each call.
    pub fn precompute_fft(&self) -> FFTPrecomputation<F> {
        execute_with_max_available_threads(|| FFTPrecomputation {
            roots: self.roots_of_unity(self.group_gen),
            domain: *self,
        })
    }

    /// Precomputes the inverse roots of unity for this domain, so that
    /// repeated IFFTs can reuse the table instead of rebuilding it each call.
    pub fn precompute_ifft(&self) -> IFFTPrecomputation<F> {
        execute_with_max_available_threads(|| IFFTPrecomputation {
            inverse_roots: self.roots_of_unity(self.group_gen_inv),
            domain: *self,
        })
    }

    /// In-place FFT of `x_s`, with input and output both in natural order.
    /// Attempts the CUDA path first (when built with it); on failure, falls
    /// through to the CPU implementation.
    pub(crate) fn in_order_fft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold and check that the type is Fr
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Forward,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            // A CUDA failure is silently ignored; the CPU path below runs instead.
            if result.is_ok() {
                return;
            }
        }

        let pc = self.precompute_fft();
        self.fft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc)
    }

    /// FFT using a caller-supplied precomputed root table; the input is
    /// zero-padded up to the domain size before transforming.
    // NOTE(review): assumes `x_s.len() <= self.size()`; a longer input makes
    // the subtraction below underflow — confirm callers uphold this.
    pub fn in_order_fft_with_pc<T: DomainCoeff<F>>(&self, x_s: &[T], pc: &FFTPrecomputation<F>) -> Vec<T> {
        let mut x_s = x_s.to_vec();
        if self.size() != x_s.len() {
            x_s.extend(core::iter::repeat_n(T::zero(), self.size() - x_s.len()));
        }
        self.fft_helper_in_place_with_pc(&mut x_s, FFTOrder::II, pc);
        x_s
    }

    /// In-place IFFT of `x_s`, including the final division by the domain
    /// size. Attempts the CUDA path first; on failure, falls back to the CPU.
    pub(crate) fn in_order_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }

        let pc = self.precompute_ifft();
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc);
        // The helper leaves the result unscaled; normalize by 1/size here.
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }

    /// In-place coset IFFT of `x_s`: IFFT followed by undoing the coset shift
    /// (multiplying element `i` by `size_inv * generator_inv^i`).
    pub(crate) fn in_order_coset_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Coset,
            );
            if result.is_ok() {
                return;
            }
        }

        let pc = self.precompute_ifft();
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc);
        let coset_shift = self.generator_inv;
        // Combines the 1/size normalization with the inverse coset shift.
        Self::distribute_powers_and_mul_by_const(x_s, coset_shift, self.size_inv);
    }

    /// Like [`Self::in_order_fft_in_place`], but reuses a caller-supplied
    /// precomputation instead of building one.
    #[allow(unused)]
    pub(crate) fn in_order_fft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &FFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Forward,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }

        self.fft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp)
    }

    /// FFT with in-order input and out-of-order (bit-reversed) output,
    /// reusing a caller-supplied precomputation.
    pub(crate) fn out_order_fft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &FFTPrecomputation<F>,
    ) {
        self.fft_helper_in_place_with_pc(x_s, FFTOrder::IO, pre_comp)
    }

    /// IFFT (in-order input and output) reusing a caller-supplied
    /// precomputation; includes the final division by the domain size.
    pub(crate) fn in_order_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }

        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp);
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }

    /// IFFT with out-of-order (bit-reversed) input and in-order output,
    /// reusing a caller-supplied precomputation; includes the 1/size scaling.
    pub(crate) fn out_order_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::OI, pre_comp);
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }

    /// Coset IFFT reusing a caller-supplied precomputation; see
    /// [`Self::in_order_coset_ifft_in_place`].
    #[allow(unused)]
    pub(crate) fn in_order_coset_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Coset,
            );
            if result.is_ok() {
                return;
            }
        }

        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp);
        let coset_shift = self.generator_inv;
        Self::distribute_powers_and_mul_by_const(x_s, coset_shift, self.size_inv);
    }

    /// Dispatches a forward FFT according to the requested ordering.
    // The variant names appear to encode (input, output) ordering with
    // I = in-order and O = bit-reversed — confirm against `FFTOrder`'s docs.
    fn fft_helper_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        ord: FFTOrder,
        pre_comp: &FFTPrecomputation<F>,
    ) {
        use FFTOrder::*;
        let pc = pre_comp.precomputation_for_subdomain(self).unwrap();

        let log_len = log2(x_s.len());

        if ord == OI {
            self.oi_helper_with_roots(x_s, &pc.roots);
        } else {
            self.io_helper_with_roots(x_s, &pc.roots);
        }

        // The io helper leaves its output bit-reversed; undo that permutation
        // when the caller asked for in-order output.
        if ord == II {
            derange_helper(x_s, log_len);
        }
    }

    // Handles doing an IFFT with handling of being in order and out of order.
    // The results here must all be divided by |x_s|,
    // which is left up to the caller to do.
    fn ifft_helper_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        ord: FFTOrder,
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        use FFTOrder::*;
        let pc = pre_comp.precomputation_for_subdomain(self).unwrap();

        let log_len = log2(x_s.len());

        // The oi helper consumes bit-reversed input; pre-permute when the
        // caller's input is in order.
        if ord == II {
            derange_helper(x_s, log_len);
        }

        if ord == IO {
            self.io_helper_with_roots(x_s, &pc.inverse_roots);
        } else {
            self.oi_helper_with_roots(x_s, &pc.inverse_roots);
        }
    }

    /// Computes the first `self.size / 2` roots of unity for the entire domain.
    /// e.g. for the domain [1, g, g^2, ..., g^{n - 1}], it computes
    // [1, g, g^2, ..., g^{(n/2) - 1}]
    #[cfg(feature = "serial")]
    pub fn roots_of_unity(&self, root: F) -> Vec<F> {
        compute_powers_serial((self.size as usize) / 2, root)
    }

    /// Computes the first `self.size / 2` roots of unity.
    #[cfg(not(feature = "serial"))]
    pub fn roots_of_unity(&self, root: F) -> Vec<F> {
        // TODO: check if this method can replace parallel compute powers.
        let log_size = log2(self.size as usize);
        // early exit for short inputs
        if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
            compute_powers_serial((self.size as usize) / 2, root)
        } else {
            let mut temp = root;
            // w, w^2, w^4, w^8, ..., w^(2^(log_size - 1))
            let log_powers: Vec<F> = (0..(log_size - 1))
                .map(|_| {
                    let old_value = temp;
                    temp.square_in_place();
                    old_value
                })
                .collect();

            // allocate the return array and start the recursion
            let mut powers = vec![F::zero(); 1 << (log_size - 1)];
            Self::roots_of_unity_recursive(&mut powers, &log_powers);
            powers
        }
    }

    /// Fills `out` with the powers of `log_powers[0]`, splitting the work
    /// recursively: each half is computed in parallel, then recombined by a
    /// pairwise product (out[j*|lo| + i] = hi[j] * lo[i]).
    #[cfg(not(feature = "serial"))]
    fn roots_of_unity_recursive(out: &mut [F], log_powers: &[F]) {
        assert_eq!(out.len(), 1 << log_powers.len());
        // base case: just compute the powers sequentially,
        // g = log_powers[0], out = [1, g, g^2, ...]
        if log_powers.len() <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE as usize {
            out[0] = F::one();
            for idx in 1..out.len() {
                out[idx] = out[idx - 1] * log_powers[0];
            }
            return;
        }

        // recursive case:
        // 1. split log_powers in half
        let (lr_lo, lr_hi) = log_powers.split_at(log_powers.len().div_ceil(2));
        let mut scr_lo = vec![F::default(); 1 << lr_lo.len()];
        let mut scr_hi = vec![F::default(); 1 << lr_hi.len()];
        // 2. compute each half individually
        rayon::join(
            || Self::roots_of_unity_recursive(&mut scr_lo, lr_lo),
            || Self::roots_of_unity_recursive(&mut scr_hi, lr_hi),
        );
        // 3. recombine halves
        // At this point, out is a blank slice.
        out.par_chunks_mut(scr_lo.len()).zip(&scr_hi).for_each(|(out_chunk, scr_hi)| {
            for (out_elem, scr_lo) in out_chunk.iter_mut().zip(&scr_lo) {
                *out_elem = *scr_hi * scr_lo;
            }
        });
    }

    /// Decimation-in-frequency butterfly: (lo, hi) <- (lo + hi, (lo - hi) * root).
    #[inline(always)]
    fn butterfly_fn_io<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        let neg = *lo - *hi;
        *lo += *hi;
        *hi = neg;
        *hi *= *root;
    }

    /// Decimation-in-time butterfly: (lo, hi) <- (lo + hi*root, lo - hi*root).
    #[inline(always)]
    fn butterfly_fn_oi<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        *hi *= *root;
        let neg = *lo - *hi;
        *lo += *hi;
        *hi = neg;
    }

    /// Applies the butterfly `g` across `xi` in clusters of `chunk_size`,
    /// pairing the first `gap` elements of each cluster with the second `gap`,
    /// and striding through `roots` by `step`.
    #[allow(clippy::too_many_arguments)]
    fn apply_butterfly<T: DomainCoeff<F>, G: Fn(((&mut T, &mut T), &F)) + Copy + Sync + Send>(
        g: G,
        xi: &mut [T],
        roots: &[F],
        step: usize,
        chunk_size: usize,
        num_chunks: usize,
        max_threads: usize,
        gap: usize,
    ) {
        cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
            let (lo, hi) = cxi.split_at_mut(gap);
            // If the chunk is sufficiently big that parallelism helps,
            // we parallelize the butterfly operation within the chunk.

            if gap > MIN_GAP_SIZE_FOR_PARALLELISATION && num_chunks < max_threads {
                cfg_iter_mut!(lo).zip(hi).zip(cfg_iter!(roots).step_by(step)).for_each(g);
            } else {
                lo.iter_mut().zip(hi).zip(roots.iter().step_by(step)).for_each(g);
            }
        });
    }

    /// FFT pass with in-order input: gap starts at n/2 and halves each round.
    /// The output is left in bit-reversed order (callers derange if needed).
    #[allow(clippy::unnecessary_to_owned)]
    fn io_helper_with_roots<T: DomainCoeff<F>>(&self, xi: &mut [T], roots: &[F]) {
        let mut roots = std::borrow::Cow::Borrowed(roots);

        let mut step = 1;
        let mut first = true;

        #[cfg(not(feature = "serial"))]
        let max_threads = alphavm_utilities::parallel::max_available_threads();
        #[cfg(feature = "serial")]
        let max_threads = 1;

        let mut gap = xi.len() / 2;
        while gap > 0 {
            // each butterfly cluster uses 2*gap positions
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;

            // Only compact roots to achieve cache locality/compactness if
            // the roots lookup is done a significant amount of times
            // Which also implies a large lookup stride.
            if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION {
                if !first {
                    // Keep only every (step*2)-th root, so subsequent rounds
                    // index the compacted table with stride 1.
                    roots = Cow::Owned(cfg_into_iter!(roots.into_owned()).step_by(step * 2).collect());
                }
                step = 1;
                roots.to_mut().shrink_to_fit();
            } else {
                step = num_chunks;
            }
            first = false;

            Self::apply_butterfly(
                Self::butterfly_fn_io,
                xi,
                &roots[..],
                step,
                chunk_size,
                num_chunks,
                max_threads,
                gap,
            );

            gap /= 2;
        }
    }

    /// FFT pass with bit-reversed input: gap starts at 1 and doubles each
    /// round, producing in-order output.
    fn oi_helper_with_roots<T: DomainCoeff<F>>(&self, xi: &mut [T], roots_cache: &[F]) {
        // The `cmp::min` is only necessary for the case where
        // `MIN_NUM_CHUNKS_FOR_COMPACTION = 1`. Else, notice that we compact
        // the roots cache by a stride of at least `MIN_NUM_CHUNKS_FOR_COMPACTION`.

        let compaction_max_size =
            core::cmp::min(roots_cache.len() / 2, roots_cache.len() / MIN_NUM_CHUNKS_FOR_COMPACTION);
        let mut compacted_roots = vec![F::default(); compaction_max_size];

        #[cfg(not(feature = "serial"))]
        let max_threads = alphavm_utilities::parallel::max_available_threads();
        #[cfg(feature = "serial")]
        let max_threads = 1;

        let mut gap = 1;
        while gap < xi.len() {
            // each butterfly cluster uses 2*gap positions
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;

            // Only compact roots to achieve cache locality/compactness if
            // the roots lookup is done a significant amount of times
            // Which also implies a large lookup stride.
            let (roots, step) = if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION && gap < xi.len() / 2 {
                cfg_iter_mut!(compacted_roots[..gap])
                    .zip(cfg_iter!(roots_cache[..(gap * num_chunks)]).step_by(num_chunks))
                    .for_each(|(a, b)| *a = *b);
                (&compacted_roots[..gap], 1)
            } else {
                (roots_cache, num_chunks)
            };

            Self::apply_butterfly(Self::butterfly_fn_oi, xi, roots, step, chunk_size, num_chunks, max_threads, gap);

            gap *= 2;
        }
    }
}
 776  
/// The minimum number of chunks at which root compaction
/// is beneficial.
const MIN_NUM_CHUNKS_FOR_COMPACTION: usize = 1 << 7;

/// The minimum size of a chunk at which parallelization of `butterfly`s is
/// beneficial. This value was chosen empirically.
const MIN_GAP_SIZE_FOR_PARALLELISATION: usize = 1 << 10;

// The minimum (log) size at which to parallelize. Judging by the constant's
// name this gates the parallel roots-of-unity computation — the use site is
// not visible in this chunk, so confirm there.
#[cfg(not(feature = "serial"))]
const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: u32 = 7;
 788  
 789  #[inline]
 790  pub(super) fn bitrev(a: u64, log_len: u32) -> u64 {
 791      a.reverse_bits() >> (64 - log_len)
 792  }
 793  
 794  pub(crate) fn derange<T>(xi: &mut [T]) {
 795      derange_helper(xi, log2(xi.len()))
 796  }
 797  
 798  fn derange_helper<T>(xi: &mut [T], log_len: u32) {
 799      for idx in 1..(xi.len() as u64 - 1) {
 800          let ridx = bitrev(idx, log_len);
 801          if idx < ridx {
 802              xi.swap(idx as usize, ridx as usize);
 803          }
 804      }
 805  }
 806  
/// The ordering requirements on an FFT's input and output buffers.
/// Relaxing the order on one side can let the FFT helpers skip a
/// bit-reversal (`derange`) pass.
#[derive(PartialEq, Eq, Debug)]
enum FFTOrder {
    /// Both the input and the output of the FFT must be in-order.
    II,
    /// The input of the FFT must be in-order, but the output does not have to
    /// be.
    IO,
    /// The input of the FFT can be out of order, but the output must be
    /// in-order.
    OI,
}
 818  
 819  pub(crate) fn compute_powers_serial<F: Field>(size: usize, root: F) -> Vec<F> {
 820      compute_powers_and_mul_by_const_serial(size, root, F::one())
 821  }
 822  
 823  pub(crate) fn compute_powers_and_mul_by_const_serial<F: Field>(size: usize, root: F, c: F) -> Vec<F> {
 824      let mut value = c;
 825      (0..size)
 826          .map(|_| {
 827              let old_value = value;
 828              value *= root;
 829              old_value
 830          })
 831          .collect()
 832  }
 833  
 834  #[allow(unused)]
 835  #[cfg(not(feature = "serial"))]
 836  pub(crate) fn compute_powers<F: Field>(size: usize, g: F) -> Vec<F> {
 837      if size < MIN_PARALLEL_CHUNK_SIZE {
 838          return compute_powers_serial(size, g);
 839      }
 840      // compute the number of threads we will be using.
 841      let num_cpus_available = max_available_threads();
 842      let num_elem_per_thread = core::cmp::max(size / num_cpus_available, MIN_PARALLEL_CHUNK_SIZE);
 843      let num_cpus_used = size / num_elem_per_thread;
 844  
 845      // Split up the powers to compute across each thread evenly.
 846      let res: Vec<F> = (0..num_cpus_used)
 847          .into_par_iter()
 848          .flat_map(|i| {
 849              let offset = g.pow([(i * num_elem_per_thread) as u64]);
 850              // Compute the size that this chunks' output should be
 851              // (num_elem_per_thread, unless there are less than num_elem_per_thread elements
 852              // remaining)
 853              let num_elements_to_compute = core::cmp::min(size - i * num_elem_per_thread, num_elem_per_thread);
 854              compute_powers_and_mul_by_const_serial(num_elements_to_compute, g, offset)
 855          })
 856          .collect();
 857      res
 858  }
 859  
/// An iterator over the elements of the domain.
#[derive(Clone)]
pub struct Elements<F: FftField> {
    // The element to be returned by the next call to `next`.
    cur_elem: F,
    // How many elements have been yielded so far; the `Iterator` impl below
    // stops once this reaches `domain.size`.
    cur_pow: u64,
    // The domain whose elements are being enumerated.
    domain: EvaluationDomain<F>,
}
 867  
 868  impl<F: FftField> Iterator for Elements<F> {
 869      type Item = F;
 870  
 871      fn next(&mut self) -> Option<F> {
 872          if self.cur_pow == self.domain.size {
 873              None
 874          } else {
 875              let cur_elem = self.cur_elem;
 876              self.cur_elem *= &self.domain.group_gen;
 877              self.cur_pow += 1;
 878              Some(cur_elem)
 879          }
 880      }
 881  }
 882  
/// A cache of precomputed roots used to accelerate FFTs over `domain`.
/// (The previous doc comment, "An iterator over the elements of the domain",
/// was copy-pasted from `Elements`.)
#[derive(Clone, Eq, PartialEq, Debug, CanonicalDeserialize, CanonicalSerialize)]
pub struct FFTPrecomputation<F: FftField> {
    // Precomputed roots for `domain`; a subdomain's roots are a strided
    // subsequence of these (see `precomputation_for_subdomain`).
    roots: Vec<F>,
    // The domain these roots were computed for.
    domain: EvaluationDomain<F>,
}
 889  
 890  impl<F: FftField> FFTPrecomputation<F> {
 891      pub fn to_ifft_precomputation(&self) -> IFFTPrecomputation<F> {
 892          let mut inverse_roots = self.roots.clone();
 893          alphavm_fields::batch_inversion(&mut inverse_roots);
 894          IFFTPrecomputation { inverse_roots, domain: self.domain }
 895      }
 896  
 897      pub fn precomputation_for_subdomain<'a>(&'a self, domain: &EvaluationDomain<F>) -> Option<Cow<'a, Self>> {
 898          if domain.size() == 1 {
 899              return Some(Cow::Owned(Self { roots: vec![], domain: *domain }));
 900          }
 901          if &self.domain == domain {
 902              Some(Cow::Borrowed(self))
 903          } else if domain.size() < self.domain.size() {
 904              let size_ratio = self.domain.size() / domain.size();
 905              let roots = self.roots.iter().step_by(size_ratio).copied().collect();
 906              Some(Cow::Owned(Self { roots, domain: *domain }))
 907          } else {
 908              None
 909          }
 910      }
 911  }
 912  
/// A cache of precomputed inverse roots used to accelerate IFFTs over
/// `domain`. (The previous doc comment, "An iterator over the elements of
/// the domain", was copy-pasted from `Elements`.)
#[derive(Clone, Eq, PartialEq, Debug, CanonicalSerialize, CanonicalDeserialize)]
pub struct IFFTPrecomputation<F: FftField> {
    // Inverses of the FFT roots for `domain`; see
    // `FFTPrecomputation::to_ifft_precomputation`.
    inverse_roots: Vec<F>,
    // The domain these inverse roots were computed for.
    domain: EvaluationDomain<F>,
}
 919  
 920  impl<F: FftField> IFFTPrecomputation<F> {
 921      pub fn precomputation_for_subdomain<'a>(&'a self, domain: &EvaluationDomain<F>) -> Option<Cow<'a, Self>> {
 922          if domain.size() == 1 {
 923              return Some(Cow::Owned(Self { inverse_roots: vec![], domain: *domain }));
 924          }
 925          if &self.domain == domain {
 926              Some(Cow::Borrowed(self))
 927          } else if domain.size() < self.domain.size() {
 928              let size_ratio = self.domain.size() / domain.size();
 929              let inverse_roots = self.inverse_roots.iter().step_by(size_ratio).copied().collect();
 930              Some(Cow::Owned(Self { inverse_roots, domain: *domain }))
 931          } else {
 932              None
 933          }
 934      }
 935  }
 936  
 937  #[cfg(test)]
 938  mod tests {
 939      #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
 940      use crate::fft::domain::FFTOrder;
 941      use crate::fft::{DensePolynomial, EvaluationDomain};
 942      use alphavm_curves::bls12_377::Fr;
 943      use alphavm_fields::{FftField, Field, One, Zero};
 944      use alphavm_utilities::{TestRng, Uniform};
 945      use rand::Rng;
 946  
 947      #[test]
 948      fn vanishing_polynomial_evaluation() {
 949          let rng = &mut TestRng::default();
 950          for coeffs in 0..10 {
 951              let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
 952              let z = domain.vanishing_polynomial();
 953              for _ in 0..100 {
 954                  let point = rng.r#gen();
 955                  assert_eq!(z.evaluate(point), domain.evaluate_vanishing_polynomial(point))
 956              }
 957          }
 958      }
 959  
 960      #[test]
 961      fn vanishing_polynomial_vanishes_on_domain() {
 962          for coeffs in 0..1000 {
 963              let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
 964              let z = domain.vanishing_polynomial();
 965              for point in domain.elements() {
 966                  assert!(z.evaluate(point).is_zero())
 967              }
 968          }
 969      }
 970  
 971      #[test]
 972      fn size_of_elements() {
 973          for coeffs in 1..10 {
 974              let size = 1 << coeffs;
 975              let domain = EvaluationDomain::<Fr>::new(size).unwrap();
 976              let domain_size = domain.size();
 977              assert_eq!(domain_size, domain.elements().count());
 978          }
 979      }
 980  
 981      #[test]
 982      fn elements_contents() {
 983          for coeffs in 1..10 {
 984              let size = 1 << coeffs;
 985              let domain = EvaluationDomain::<Fr>::new(size).unwrap();
 986              for (i, element) in domain.elements().enumerate() {
 987                  assert_eq!(element, domain.group_gen.pow([i as u64]));
 988              }
 989          }
 990      }
 991  
 992      /// Test that lagrange interpolation for a random polynomial at a random
 993      /// point works.
 994      #[test]
 995      fn non_systematic_lagrange_coefficients_test() {
 996          let mut rng = TestRng::default();
 997          for domain_dimension in 1..10 {
 998              let domain_size = 1 << domain_dimension;
 999              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1000              // Get random point & lagrange coefficients
1001              let random_point = Fr::rand(&mut rng);
1002              let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(random_point);
1003  
1004              // Sample the random polynomial, evaluate it over the domain and the random
1005              // point.
1006              let random_polynomial = DensePolynomial::<Fr>::rand(domain_size - 1, &mut rng);
1007              let polynomial_evaluations = domain.fft(random_polynomial.coeffs());
1008              let actual_evaluations = random_polynomial.evaluate(random_point);
1009  
1010              // Do lagrange interpolation, and compare against the actual evaluation
1011              let mut interpolated_evaluation = Fr::zero();
1012              for i in 0..domain_size {
1013                  interpolated_evaluation += lagrange_coefficients[i] * polynomial_evaluations[i];
1014              }
1015              assert_eq!(actual_evaluations, interpolated_evaluation);
1016          }
1017      }
1018  
1019      /// Test that lagrange coefficients for a point in the domain is correct.
1020      #[test]
1021      fn systematic_lagrange_coefficients_test() {
1022          // This runs in time O(N^2) in the domain size, so keep the domain dimension
1023          // low. We generate lagrange coefficients for each element in the
1024          // domain.
1025          for domain_dimension in 1..5 {
1026              let domain_size = 1 << domain_dimension;
1027              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1028              let all_domain_elements: Vec<Fr> = domain.elements().collect();
1029              for (i, domain_element) in all_domain_elements.iter().enumerate().take(domain_size) {
1030                  let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(*domain_element);
1031                  for (j, lagrange_coefficient) in lagrange_coefficients.iter().enumerate().take(domain_size) {
1032                      // Lagrange coefficient for the evaluation point, which should be 1
1033                      if i == j {
1034                          assert_eq!(*lagrange_coefficient, Fr::one());
1035                      } else {
1036                          assert_eq!(*lagrange_coefficient, Fr::zero());
1037                      }
1038                  }
1039              }
1040          }
1041      }
1042  
1043      /// Tests that the roots of unity result is the same as domain.elements().
1044      #[test]
1045      fn test_roots_of_unity() {
1046          let max_degree = 10;
1047          for log_domain_size in 0..max_degree {
1048              let domain_size = 1 << log_domain_size;
1049              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1050              let actual_roots = domain.roots_of_unity(domain.group_gen);
1051              for &value in &actual_roots {
1052                  assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
1053              }
1054              let expected_roots_elements = domain.elements();
1055              for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
1056                  assert_eq!(expected, actual);
1057              }
1058              assert_eq!(actual_roots.len(), domain_size / 2);
1059          }
1060      }
1061  
1062      /// Tests that the FFTs output the correct result.
1063      #[test]
1064      fn test_fft_correctness() {
1065          // This assumes a correct polynomial evaluation at point procedure.
1066          // It tests consistency of FFT/IFFT, and coset_fft/coset_ifft,
1067          // along with testing that each individual evaluation is correct.
1068  
1069          let mut rng = TestRng::default();
1070  
1071          // Runs in time O(degree^2)
1072          let log_degree = 5;
1073          let degree = 1 << log_degree;
1074          let random_polynomial = DensePolynomial::<Fr>::rand(degree - 1, &mut rng);
1075  
1076          for log_domain_size in log_degree..(log_degree + 2) {
1077              let domain_size = 1 << log_domain_size;
1078              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1079              let polynomial_evaluations = domain.fft(&random_polynomial.coeffs);
1080              let polynomial_coset_evaluations = domain.coset_fft(&random_polynomial.coeffs);
1081              for (i, x) in domain.elements().enumerate() {
1082                  let coset_x = Fr::multiplicative_generator() * x;
1083  
1084                  assert_eq!(polynomial_evaluations[i], random_polynomial.evaluate(x));
1085                  assert_eq!(polynomial_coset_evaluations[i], random_polynomial.evaluate(coset_x));
1086              }
1087  
1088              let randon_polynomial_from_subgroup =
1089                  DensePolynomial::from_coefficients_vec(domain.ifft(&polynomial_evaluations));
1090              let random_polynomial_from_coset =
1091                  DensePolynomial::from_coefficients_vec(domain.coset_ifft(&polynomial_coset_evaluations));
1092  
1093              assert_eq!(
1094                  random_polynomial, randon_polynomial_from_subgroup,
1095                  "degree = {degree}, domain size = {domain_size}"
1096              );
1097              assert_eq!(
1098                  random_polynomial, random_polynomial_from_coset,
1099                  "degree = {degree}, domain size = {domain_size}"
1100              );
1101          }
1102      }
1103  
1104      /// Tests that FFT precomputation is correctly subdomained
1105      #[test]
1106      fn test_fft_precomputation() {
1107          for i in 1..10 {
1108              let big_domain = EvaluationDomain::<Fr>::new(i).unwrap();
1109              let pc = big_domain.precompute_fft();
1110              for j in 1..i {
1111                  let small_domain = EvaluationDomain::<Fr>::new(j).unwrap();
1112                  let small_pc = small_domain.precompute_fft();
1113                  assert_eq!(pc.precomputation_for_subdomain(&small_domain).unwrap().as_ref(), &small_pc);
1114              }
1115          }
1116      }
1117  
1118      /// Tests that IFFT precomputation is correctly subdomained
1119      #[test]
1120      fn test_ifft_precomputation() {
1121          for i in 1..10 {
1122              let big_domain = EvaluationDomain::<Fr>::new(i).unwrap();
1123              let pc = big_domain.precompute_ifft();
1124              for j in 1..i {
1125                  let small_domain = EvaluationDomain::<Fr>::new(j).unwrap();
1126                  let small_pc = small_domain.precompute_ifft();
1127                  assert_eq!(pc.precomputation_for_subdomain(&small_domain).unwrap().as_ref(), &small_pc);
1128              }
1129          }
1130      }
1131  
1132      /// Tests that IFFT precomputation can be correctly computed from
1133      /// FFT precomputation
1134      #[test]
1135      fn test_ifft_precomputation_from_fft() {
1136          for i in 1..10 {
1137              let domain = EvaluationDomain::<Fr>::new(i).unwrap();
1138              let pc = domain.precompute_ifft();
1139              let fft_pc = domain.precompute_fft();
1140              assert_eq!(pc, fft_pc.to_ifft_precomputation())
1141          }
1142      }
1143  
    /// Tests that the FFTs output the correct result.
    ///
    /// Compares the CPU FFT helpers against the CUDA NTT kernels for forward,
    /// inverse, coset-forward, and coset-inverse transforms across a range of
    /// domain sizes.
    ///
    /// NOTE(review): a CUDA failure is only printed, not propagated — the
    /// following `assert_eq!` will then compare stale buffer contents.
    #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
    #[test]
    fn test_fft_correctness_cuda() {
        let mut rng = TestRng::default();
        for log_domain in 2..20 {
            println!("Testing domain size {log_domain}");
            let domain_size = 1 << log_domain;
            let random_polynomial = DensePolynomial::<Fr>::rand(domain_size - 1, &mut rng);
            // Two copies of the coefficients: one transformed on the CPU,
            // one on the GPU.
            let mut polynomial_evaluations = random_polynomial.coeffs.clone();
            let mut polynomial_evaluations_cuda = random_polynomial.coeffs.clone();

            // CPU reference: in-order forward FFT with precomputed roots.
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let pc = domain.precompute_fft();
            domain.fft_helper_in_place_with_pc(&mut polynomial_evaluations, FFTOrder::II, &pc);

            // Forward NTT on the GPU.
            if alphavm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Forward,
                alphavm_algorithms_cuda::NTTType::Standard,
            )
            .is_err()
            {
                println!("cuda error!");
            }

            assert_eq!(polynomial_evaluations, polynomial_evaluations_cuda, "domain size = {domain_size}");

            // iNTT
            // Inverting on the GPU must recover the original coefficients.
            if alphavm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Standard,
            )
            .is_err()
            {
                println!("cuda error!");
            }
            assert_eq!(random_polynomial.coeffs, polynomial_evaluations_cuda, "domain size = {domain_size}");

            // Coset NTT
            // CPU reference: distribute coset powers, then forward FFT.
            polynomial_evaluations = random_polynomial.coeffs.clone();
            let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
            let pc = domain.precompute_fft();
            EvaluationDomain::<Fr>::distribute_powers(&mut polynomial_evaluations, Fr::multiplicative_generator());
            domain.fft_helper_in_place_with_pc(&mut polynomial_evaluations, FFTOrder::II, &pc);

            // GPU coset-forward NTT (the cuda buffer holds the original
            // coefficients again after the round-trip above).
            if alphavm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Forward,
                alphavm_algorithms_cuda::NTTType::Coset,
            )
            .is_err()
            {
                println!("cuda error!");
            }

            assert_eq!(polynomial_evaluations, polynomial_evaluations_cuda, "domain size = {domain_size}");

            // Coset iNTT
            // Coset-inverse on the GPU must again recover the original
            // coefficients.
            if alphavm_algorithms_cuda::NTT::<Fr>(
                domain_size,
                &mut polynomial_evaluations_cuda,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Coset,
            )
            .is_err()
            {
                println!("cuda error!");
            }
            assert_eq!(random_polynomial.coeffs, polynomial_evaluations_cuda, "domain size = {domain_size}");
        }
    }
1224  }