// algorithms/src/fft/domain.rs
   1  // Copyright (c) 2025-2026 ACDC Network
   2  // This file is part of the alphavm library.
   3  //
   4  // Alpha Chain | Delta Chain Protocol
   5  // International Monetary Graphite.
   6  //
   7  // Derived from Aleo (https://aleo.org) and ProvableHQ (https://provable.com).
   8  // They built world-class ZK infrastructure. We installed the EASY button.
   9  // Their cryptography: elegant. Our modifications: bureaucracy-compatible.
  10  // Original brilliance: theirs. Robert's Rules: ours. Bugs: definitely ours.
  11  //
  12  // Original Aleo/ProvableHQ code subject to Apache 2.0 https://www.apache.org/licenses/LICENSE-2.0
  13  // All modifications and new work: CC0 1.0 Universal Public Domain Dedication.
  14  // No rights reserved. No permission required. No warranty. No refunds.
  15  //
  16  // https://creativecommons.org/publicdomain/zero/1.0/
  17  // SPDX-License-Identifier: CC0-1.0
  18  
  19  //! This module contains an `EvaluationDomain` abstraction for
  20  //! performing various kinds of polynomial arithmetic on top of
  21  //! the scalar field.
  22  //!
  23  //! In pairing-based SNARKs like GM17, we need to calculate
  24  //! a quotient polynomial over a target polynomial with roots
  25  //! at distinct points associated with each constraint of the
  26  //! constraint system. In order to be efficient, we choose these
  27  //! roots to be the powers of a 2^n root of unity in the field.
  28  //! This allows us to perform polynomial operations in O(n)
  29  //! by performing an O(n log n) FFT over such a domain.
  30  
  31  use crate::{
  32      cfg_chunks_mut,
  33      cfg_into_iter,
  34      cfg_iter,
  35      cfg_iter_mut,
  36      fft::{DomainCoeff, SparsePolynomial},
  37  };
  38  use alphavm_fields::{batch_inversion, FftField, FftParameters, Field};
  39  #[cfg(not(feature = "serial"))]
  40  use alphavm_utilities::max_available_threads;
  41  use alphavm_utilities::{execute_with_max_available_threads, serialize::*};
  42  
  43  use rand::Rng;
  44  use std::{borrow::Cow, fmt};
  45  
  46  use anyhow::{ensure, Result};
  47  
  48  #[cfg(not(feature = "serial"))]
  49  use rayon::prelude::*;
  50  
  51  #[cfg(feature = "serial")]
  52  use itertools::Itertools;
  53  
  54  /// Returns the ceiling of the base-2 logarithm of `x`.
  55  ///
  56  /// ```
  57  /// use alphavm_algorithms::fft::domain::log2;
  58  ///
  59  /// assert_eq!(log2(16), 4);
  60  /// assert_eq!(log2(17), 5);
  61  /// assert_eq!(log2(1), 0);
  62  /// assert_eq!(log2(0), 0);
  63  /// assert_eq!(log2(usize::MAX), (core::mem::size_of::<usize>() * 8) as u32);
  64  /// assert_eq!(log2(1 << 15), 15);
  65  /// assert_eq!(log2(2usize.pow(18)), 18);
  66  /// ```
  67  pub fn log2(x: usize) -> u32 {
  68      if x == 0 {
  69          0
  70      } else if x.is_power_of_two() {
  71          1usize.leading_zeros() - x.leading_zeros()
  72      } else {
  73          0usize.leading_zeros() - x.leading_zeros()
  74      }
  75  }
  76  
// Minimum size of a parallelized chunk (currently unused; see the local
// `min_parallel_chunk_size` in `distribute_powers_and_mul_by_const`).
#[allow(unused)]
#[cfg(not(feature = "serial"))]
const MIN_PARALLEL_CHUNK_SIZE: usize = 1 << 7;
  81  
/// Defines a domain over which finite field (I)FFTs can be performed. Works
/// only for fields that have a large multiplicative subgroup of size that is
/// a power-of-2.
///
/// The inverse fields (`size_inv`, `group_gen_inv`, `generator_inv`) are
/// cached at construction time so transforms never invert on the fly.
#[derive(Copy, Clone, Hash, Eq, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct EvaluationDomain<F: FftField> {
    /// The size of the domain (always a power of two).
    pub size: u64,
    /// `log_2(self.size)`.
    pub log_size_of_group: u32,
    /// Size of the domain as a field element.
    pub size_as_field_element: F,
    /// Inverse of the size in the field.
    pub size_inv: F,
    /// A generator of the subgroup.
    pub group_gen: F,
    /// Inverse of the generator of the subgroup.
    pub group_gen_inv: F,
    /// Inverse of the multiplicative generator of the finite field.
    pub generator_inv: F,
}
 102  
 103  impl<F: FftField> fmt::Debug for EvaluationDomain<F> {
 104      fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 105          write!(f, "Multiplicative subgroup of size {}", self.size)
 106      }
 107  }
 108  
 109  impl<F: FftField> EvaluationDomain<F> {
 110      /// Sample an element that is *not* in the domain.
 111      pub fn sample_element_outside_domain<R: Rng>(&self, rng: &mut R) -> F {
 112          let mut t = F::rand(rng);
 113          while self.evaluate_vanishing_polynomial(t).is_zero() {
 114              t = F::rand(rng);
 115          }
 116          t
 117      }
 118  
 119      /// Construct a domain that is large enough for evaluations of a polynomial
 120      /// having `num_coeffs` coefficients.
 121      pub fn new(num_coeffs: usize) -> Option<Self> {
 122          // Compute the size of our evaluation domain
 123          let size = num_coeffs.checked_next_power_of_two()? as u64;
 124          let log_size_of_group = size.trailing_zeros();
 125  
 126          // libfqfft uses > https://github.com/scipr-lab/libfqfft/blob/e0183b2cef7d4c5deb21a6eaf3fe3b586d738fe0/libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc#L33
 127          if log_size_of_group > F::FftParameters::TWO_ADICITY {
 128              return None;
 129          }
 130  
 131          // Compute the generator for the multiplicative subgroup.
 132          // It should be the 2^(log_size_of_group) root of unity.
 133          let group_gen = F::get_root_of_unity(size as usize)?;
 134  
 135          // Check that it is indeed the 2^(log_size_of_group) root of unity.
 136          debug_assert_eq!(group_gen.pow([size]), F::one());
 137  
 138          let size_as_field_element = F::from(size);
 139          let size_inv = size_as_field_element.inverse()?;
 140  
 141          Some(EvaluationDomain {
 142              size,
 143              log_size_of_group,
 144              size_as_field_element,
 145              size_inv,
 146              group_gen,
 147              group_gen_inv: group_gen.inverse()?,
 148              generator_inv: F::multiplicative_generator().inverse()?,
 149          })
 150      }
 151  
 152      /// Return the size of a domain that is large enough for evaluations of a
 153      /// polynomial having `num_coeffs` coefficients.
 154      pub fn compute_size_of_domain(num_coeffs: usize) -> Option<usize> {
 155          let size = num_coeffs.checked_next_power_of_two()?;
 156          if size.trailing_zeros() <= F::FftParameters::TWO_ADICITY {
 157              Some(size)
 158          } else {
 159              None
 160          }
 161      }
 162  
    /// Return the size of `self`.
    ///
    /// Note: `size` is stored as `u64`; the `as` cast truncates on targets
    /// where `usize` is narrower than 64 bits.
    pub fn size(&self) -> usize {
        self.size as usize
    }
 167  
 168      /// Compute an FFT.
 169      pub fn fft<T: DomainCoeff<F>>(&self, coeffs: &[T]) -> Vec<T> {
 170          let mut coeffs = coeffs.to_vec();
 171          self.fft_in_place(&mut coeffs);
 172          coeffs
 173      }
 174  
    /// Compute an FFT, modifying the vector in place.
    ///
    /// The input is zero-padded (or, when longer, truncated by `resize`) to
    /// the domain size before the transform.
    pub fn fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            coeffs.resize(self.size(), T::zero());
            self.in_order_fft_in_place(&mut *coeffs);
        });
    }
 182  
 183      /// Compute an IFFT.
 184      pub fn ifft<T: DomainCoeff<F>>(&self, evals: &[T]) -> Vec<T> {
 185          let mut evals = evals.to_vec();
 186          self.ifft_in_place(&mut evals);
 187          evals
 188      }
 189  
    /// Compute an IFFT, modifying the vector in place.
    ///
    /// The input is resized to the domain size before the inverse transform.
    #[inline]
    pub fn ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            evals.resize(self.size(), T::zero());
            self.in_order_ifft_in_place(&mut *evals);
        });
    }
 198  
 199      /// Compute an FFT over a coset of the domain.
 200      pub fn coset_fft<T: DomainCoeff<F>>(&self, coeffs: &[T]) -> Vec<T> {
 201          let mut coeffs = coeffs.to_vec();
 202          self.coset_fft_in_place(&mut coeffs);
 203          coeffs
 204      }
 205  
    /// Compute an FFT over a coset of the domain, modifying the input vector
    /// in place.
    ///
    /// Scaling coefficient `i` by `g^i` (with `g` the field's multiplicative
    /// generator) and then running a plain FFT evaluates the polynomial over
    /// the coset `g * H`.
    pub fn coset_fft_in_place<T: DomainCoeff<F>>(&self, coeffs: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            Self::distribute_powers(coeffs, F::multiplicative_generator());
            self.fft_in_place(coeffs);
        });
    }
 214  
 215      /// Compute an IFFT over a coset of the domain.
 216      pub fn coset_ifft<T: DomainCoeff<F>>(&self, evals: &[T]) -> Vec<T> {
 217          let mut evals = evals.to_vec();
 218          self.coset_ifft_in_place(&mut evals);
 219          evals
 220      }
 221  
    /// Compute an IFFT over a coset of the domain, modifying the input vector
    /// in place.
    ///
    /// The input is resized to the domain size before the transform.
    pub fn coset_ifft_in_place<T: DomainCoeff<F>>(&self, evals: &mut Vec<T>) {
        execute_with_max_available_threads(|| {
            evals.resize(self.size(), T::zero());
            self.in_order_coset_ifft_in_place(&mut *evals);
        });
    }
 230  
    /// Multiply the `i`-th element of `coeffs` with `g^i`.
    fn distribute_powers<T: DomainCoeff<F>>(coeffs: &mut [T], g: F) {
        // Special case of the general routine with constant c = 1.
        Self::distribute_powers_and_mul_by_const(coeffs, g, F::one());
    }
 235  
 236      /// Multiply the `i`-th element of `coeffs` with `c*g^i`.
 237      #[cfg(feature = "serial")]
 238      fn distribute_powers_and_mul_by_const<T: DomainCoeff<F>>(coeffs: &mut [T], g: F, c: F) {
 239          // invariant: pow = c*g^i at the ith iteration of the loop
 240          let mut pow = c;
 241          coeffs.iter_mut().for_each(|coeff| {
 242              *coeff *= pow;
 243              pow *= &g
 244          })
 245      }
 246  
    /// Multiply the `i`-th element of `coeffs` with `c*g^i`.
    #[cfg(not(feature = "serial"))]
    fn distribute_powers_and_mul_by_const<T: DomainCoeff<F>>(coeffs: &mut [T], g: F, c: F) {
        // Chunks are kept at least this large so the per-chunk `g.pow`
        // offset computation is amortized over enough multiplications.
        let min_parallel_chunk_size = 1024;
        let num_cpus_available = max_available_threads();
        let num_elem_per_thread = core::cmp::max(coeffs.len() / num_cpus_available, min_parallel_chunk_size);

        cfg_chunks_mut!(coeffs, num_elem_per_thread).enumerate().for_each(|(i, chunk)| {
            // Chunk `i` starts at global index i*num_elem_per_thread, so its
            // first multiplier is c * g^(i*num_elem_per_thread); subsequent
            // elements just keep multiplying by g.
            let offset = c * g.pow([(i * num_elem_per_thread) as u64]);
            let mut pow = offset;
            chunk.iter_mut().for_each(|coeff| {
                *coeff *= pow;
                pow *= &g
            })
        });
    }
 263  
    /// Evaluate all the lagrange polynomials defined by this domain at the
    /// point `tau`, returning one value per domain element.
    pub fn evaluate_all_lagrange_coefficients(&self, tau: F) -> Vec<F> {
        // Evaluate all Lagrange polynomials
        let size = self.size as usize;
        let t_size = tau.pow([self.size]);
        let one = F::one();
        if t_size.is_one() {
            // tau^|H| == 1 means tau is itself a domain element: exactly one
            // Lagrange polynomial is 1 at tau and all others are 0.
            let mut u = vec![F::zero(); size];
            let mut omega_i = one;
            for x in u.iter_mut().take(size) {
                if omega_i == tau {
                    *x = one;
                    break;
                }
                omega_i *= &self.group_gen;
            }
            u
        } else {
            // General case: L_i(tau) = l_i / (tau - g^i), where
            // l_i = (tau^|H| - 1) * g^i / |H|. We collect all denominators and
            // invert them with a single batch inversion.
            let mut l = (t_size - one) * self.size_inv;
            let mut r = one;
            let mut u = vec![F::zero(); size];
            let mut ls = vec![F::zero(); size];
            for i in 0..size {
                // r tracks g^i; l tracks l_i.
                u[i] = tau - r;
                ls[i] = l;
                l *= &self.group_gen;
                r *= &self.group_gen;
            }

            batch_inversion(u.as_mut_slice());
            // Multiply each inverted denominator by its numerator l_i.
            cfg_iter_mut!(u).zip_eq(ls).for_each(|(tau_minus_r, l)| {
                *tau_minus_r = l * *tau_minus_r;
            });
            u
        }
    }
 301  
 302      /// Return the sparse vanishing polynomial.
 303      pub fn vanishing_polynomial(&self) -> SparsePolynomial<F> {
 304          let coeffs = [(0, -F::one()), (self.size(), F::one())];
 305          SparsePolynomial::from_coefficients(coeffs)
 306      }
 307  
 308      /// This evaluates the vanishing polynomial for this domain at tau.
 309      /// For multiplicative subgroups, this polynomial is `z(X) = X^self.size -
 310      /// 1`.
 311      pub fn evaluate_vanishing_polynomial(&self, tau: F) -> F {
 312          tau.pow([self.size]) - F::one()
 313      }
 314  
    /// Return an iterator over the elements of the domain.
    pub fn elements(&self) -> Elements<F> {
        // Start the cursor at g^0 = 1 with power index 0.
        Elements { cur_elem: F::one(), cur_pow: 0, domain: *self }
    }
 319  
 320      /// The target polynomial is the zero polynomial in our
 321      /// evaluation domain, so we must perform division over
 322      /// a coset.
 323      pub fn divide_by_vanishing_poly_on_coset_in_place(&self, evals: &mut [F]) {
 324          let i = self.evaluate_vanishing_polynomial(F::multiplicative_generator()).inverse().unwrap();
 325  
 326          cfg_iter_mut!(evals).for_each(|eval| *eval *= &i);
 327      }
 328  
    /// Given an index in the `other` subdomain, return an index into this
    /// domain `self` This assumes the `other`'s elements are also `self`'s
    /// first elements
    ///
    /// # Errors
    /// Returns an error unless `self` is strictly larger than `other`
    /// (equal-size domains are rejected, which also keeps the `period - 1`
    /// divisor below nonzero).
    pub fn reindex_by_subdomain(&self, other: &Self, index: usize) -> Result<usize> {
        ensure!(self.size() > other.size(), "other.size() must be smaller than self.size()");

        // Let this subgroup be G, and the subgroup we're re-indexing by be S.
        // Since its a subgroup, the 0th element of S is at index 0 in G, the first
        // element of S is at index |G|/|S|, the second at 2*|G|/|S|, etc.
        // Thus for an index i that corresponds to S, the index in G is i*|G|/|S|
        let period = self.size() / other.size();
        if index < other.size() {
            Ok(index * period)
        } else {
            // Let i now be the index of this element in G \ S
            // Let x be the number of elements in G \ S, for every element in S. Then x =
            // (|G|/|S| - 1). At index i in G \ S, the number of elements in S
            // that appear before the index in G to which i corresponds to, is
            // floor(i / x) + 1. The +1 is because index 0 of G is S_0, so the
            // position is offset by at least one. The floor(i / x) term is
            // because after x elements in G \ S, there is one more element from S
            // that will have appeared in G.
            let i = index - other.size();
            let x = period - 1;
            Ok(i + (i / x) + 1)
        }
    }
 356  
    /// Perform O(n) multiplication of two polynomials that are presented by
    /// their evaluations in the domain.
    /// Returns the evaluations of the product over the domain.
    ///
    /// # Errors
    /// Returns an error if the two evaluation vectors differ in length.
    pub fn mul_polynomials_in_evaluation_domain(&self, self_evals: Vec<F>, other_evals: &[F]) -> Result<Vec<F>> {
        let mut result = self_evals;

        ensure!(result.len() == other_evals.len());
        // Pointwise product of evaluations equals the evaluation of the product.
        cfg_iter_mut!(result).zip_eq(other_evals).for_each(|(a, b)| *a *= b);

        Ok(result)
    }
 368  }
 369  
 370  impl<F: FftField> EvaluationDomain<F> {
 371      pub fn precompute_fft(&self) -> FFTPrecomputation<F> {
 372          execute_with_max_available_threads(|| FFTPrecomputation {
 373              roots: self.roots_of_unity(self.group_gen),
 374              domain: *self,
 375          })
 376      }
 377  
 378      pub fn precompute_ifft(&self) -> IFFTPrecomputation<F> {
 379          execute_with_max_available_threads(|| IFFTPrecomputation {
 380              inverse_roots: self.roots_of_unity(self.group_gen_inv),
 381              domain: *self,
 382          })
 383      }
 384  
    /// Forward FFT in place with in-order input and output.
    pub(crate) fn in_order_fft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold and check that the type is Fr
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            // Try the GPU path first; on any error fall through to the CPU path.
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Forward,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }

        // CPU fallback: precompute the roots table for this call, then run an
        // in-order (II) transform.
        let pc = self.precompute_fft();
        self.fft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc)
    }
 404  
 405      pub fn in_order_fft_with_pc<T: DomainCoeff<F>>(&self, x_s: &[T], pc: &FFTPrecomputation<F>) -> Vec<T> {
 406          let mut x_s = x_s.to_vec();
 407          if self.size() != x_s.len() {
 408              x_s.extend(core::iter::repeat_n(T::zero(), self.size() - x_s.len()));
 409          }
 410          self.fft_helper_in_place_with_pc(&mut x_s, FFTOrder::II, pc);
 411          x_s
 412      }
 413  
    /// Inverse FFT in place with in-order input and output; applies the final
    /// 1/|H| scaling on the CPU path.
    pub(crate) fn in_order_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                // NOTE(review): assumes the CUDA inverse NTT applies the 1/|H|
                // scaling itself — confirm against the kernel.
                return;
            }
        }

        let pc = self.precompute_ifft();
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc);
        // The helper leaves results unscaled; divide by |H| here.
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }
 434  
    /// Inverse FFT over a coset, in place, with in-order input and output.
    pub(crate) fn in_order_coset_ifft_in_place<T: DomainCoeff<F>>(&self, x_s: &mut [T]) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Coset,
            );
            if result.is_ok() {
                return;
            }
        }

        let pc = self.precompute_ifft();
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, &pc);
        // Undo the coset shift (multiply index i by generator^{-i}) and apply
        // the 1/|H| scaling in a single pass.
        let coset_shift = self.generator_inv;
        Self::distribute_powers_and_mul_by_const(x_s, coset_shift, self.size_inv);
    }
 456  
    /// Forward FFT in place (in-order) using a caller-supplied precomputation
    /// table instead of computing one per call.
    #[allow(unused)]
    pub(crate) fn in_order_fft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &FFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            // GPU path ignores `pre_comp`; fall through to CPU on error.
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Forward,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                return;
            }
        }

        self.fft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp)
    }
 480  
    /// Forward FFT in place with IO ordering (output left in the permuted
    /// order, no final derangement), using a caller-supplied precomputation.
    pub(crate) fn out_order_fft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &FFTPrecomputation<F>,
    ) {
        self.fft_helper_in_place_with_pc(x_s, FFTOrder::IO, pre_comp)
    }
 488  
    /// Inverse FFT in place (in-order) using a caller-supplied precomputation;
    /// applies the final 1/|H| scaling on the CPU path.
    pub(crate) fn in_order_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            // GPU path ignores `pre_comp`; fall through to CPU on error.
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Standard,
            );
            if result.is_ok() {
                // NOTE(review): assumes the CUDA inverse NTT applies the 1/|H|
                // scaling itself — confirm against the kernel.
                return;
            }
        }

        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp);
        // The helper leaves results unscaled; divide by |H| here.
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }
 512  
    /// Inverse FFT in place with OI ordering, using a caller-supplied
    /// precomputation; applies the final 1/|H| scaling.
    pub(crate) fn out_order_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::OI, pre_comp);
        cfg_iter_mut!(x_s).for_each(|val| *val *= self.size_inv);
    }
 521  
    /// Inverse FFT over a coset, in place (in-order), using a caller-supplied
    /// precomputation table.
    #[allow(unused)]
    pub(crate) fn in_order_coset_ifft_in_place_with_pc<T: DomainCoeff<F>>(
        &self,
        x_s: &mut [T],
        pre_comp: &IFFTPrecomputation<F>,
    ) {
        #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
        // SNP TODO: how to set threshold
        if self.size >= 32 && std::mem::size_of::<T>() == 32 {
            // GPU path ignores `pre_comp`; fall through to CPU on error.
            let result = alphavm_algorithms_cuda::NTT(
                self.size as usize,
                x_s,
                alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
                alphavm_algorithms_cuda::NTTDirection::Inverse,
                alphavm_algorithms_cuda::NTTType::Coset,
            );
            if result.is_ok() {
                return;
            }
        }

        self.ifft_helper_in_place_with_pc(x_s, FFTOrder::II, pre_comp);
        // Undo the coset shift and apply the 1/|H| scaling in one pass.
        let coset_shift = self.generator_inv;
        Self::distribute_powers_and_mul_by_const(x_s, coset_shift, self.size_inv);
    }
 547  
 548      fn fft_helper_in_place_with_pc<T: DomainCoeff<F>>(
 549          &self,
 550          x_s: &mut [T],
 551          ord: FFTOrder,
 552          pre_comp: &FFTPrecomputation<F>,
 553      ) {
 554          use FFTOrder::*;
 555          let pc = pre_comp.precomputation_for_subdomain(self).unwrap();
 556  
 557          let log_len = log2(x_s.len());
 558  
 559          if ord == OI {
 560              self.oi_helper_with_roots(x_s, &pc.roots);
 561          } else {
 562              self.io_helper_with_roots(x_s, &pc.roots);
 563          }
 564  
 565          if ord == II {
 566              derange_helper(x_s, log_len);
 567          }
 568      }
 569  
 570      // Handles doing an IFFT with handling of being in order and out of order.
 571      // The results here must all be divided by |x_s|,
 572      // which is left up to the caller to do.
 573      fn ifft_helper_in_place_with_pc<T: DomainCoeff<F>>(
 574          &self,
 575          x_s: &mut [T],
 576          ord: FFTOrder,
 577          pre_comp: &IFFTPrecomputation<F>,
 578      ) {
 579          use FFTOrder::*;
 580          let pc = pre_comp.precomputation_for_subdomain(self).unwrap();
 581  
 582          let log_len = log2(x_s.len());
 583  
 584          if ord == II {
 585              derange_helper(x_s, log_len);
 586          }
 587  
 588          if ord == IO {
 589              self.io_helper_with_roots(x_s, &pc.inverse_roots);
 590          } else {
 591              self.oi_helper_with_roots(x_s, &pc.inverse_roots);
 592          }
 593      }
 594  
    /// Computes the first `self.size / 2` roots of unity for the entire domain.
    /// e.g. for the domain [1, g, g^2, ..., g^{n - 1}], it computes
    /// [1, g, g^2, ..., g^{(n/2) - 1}]
    #[cfg(feature = "serial")]
    pub fn roots_of_unity(&self, root: F) -> Vec<F> {
        // Sequential fallback: compute root^0 .. root^{n/2 - 1} directly.
        compute_powers_serial((self.size as usize) / 2, root)
    }
 602  
    /// Computes the first `self.size / 2` roots of unity.
    #[cfg(not(feature = "serial"))]
    pub fn roots_of_unity(&self, root: F) -> Vec<F> {
        // TODO: check if this method can replace parallel compute powers.
        let log_size = log2(self.size as usize);
        // early exit for short inputs — recursion only pays off for larger
        // domains
        if log_size <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE {
            compute_powers_serial((self.size as usize) / 2, root)
        } else {
            let mut temp = root;
            // w, w^2, w^4, w^8, ..., w^(2^(log_size - 1))
            let log_powers: Vec<F> = (0..(log_size - 1))
                .map(|_| {
                    let old_value = temp;
                    temp.square_in_place();
                    old_value
                })
                .collect();

            // allocate the return array and start the recursion
            let mut powers = vec![F::zero(); 1 << (log_size - 1)];
            Self::roots_of_unity_recursive(&mut powers, &log_powers);
            powers
        }
    }
 628  
    /// Divide-and-conquer kernel for `roots_of_unity`: fills `out` with the
    /// powers [1, g, g^2, ...] where `g = log_powers[0]`, recursing in
    /// parallel on the repeated-squaring ladder `log_powers`.
    #[cfg(not(feature = "serial"))]
    fn roots_of_unity_recursive(out: &mut [F], log_powers: &[F]) {
        assert_eq!(out.len(), 1 << log_powers.len());
        // base case: just compute the powers sequentially,
        // g = log_powers[0], out = [1, g, g^2, ...]
        if log_powers.len() <= LOG_ROOTS_OF_UNITY_PARALLEL_SIZE as usize {
            out[0] = F::one();
            for idx in 1..out.len() {
                out[idx] = out[idx - 1] * log_powers[0];
            }
            return;
        }

        // recursive case:
        // 1. split log_powers in half
        let (lr_lo, lr_hi) = log_powers.split_at(log_powers.len().div_ceil(2));
        let mut scr_lo = vec![F::default(); 1 << lr_lo.len()];
        let mut scr_hi = vec![F::default(); 1 << lr_hi.len()];
        // 2. compute each half individually
        rayon::join(
            || Self::roots_of_unity_recursive(&mut scr_lo, lr_lo),
            || Self::roots_of_unity_recursive(&mut scr_hi, lr_hi),
        );
        // 3. recombine halves: every element is (high power) * (low power),
        // i.e. chunk j of `out` is scr_hi[j] * scr_lo[..].
        // At this point, out is a blank slice.
        out.par_chunks_mut(scr_lo.len()).zip(&scr_hi).for_each(|(out_chunk, scr_hi)| {
            for (out_elem, scr_lo) in out_chunk.iter_mut().zip(&scr_lo) {
                *out_elem = *scr_hi * scr_lo;
            }
        });
    }
 660  
    /// Butterfly with the twiddle applied *after* the add/sub:
    /// `(lo, hi) <- (lo + hi, (lo - hi) * root)`.
    #[inline(always)]
    fn butterfly_fn_io<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        let neg = *lo - *hi;
        *lo += *hi;
        *hi = neg;
        *hi *= *root;
    }
 668  
    /// Butterfly with the twiddle applied *before* the add/sub:
    /// `(lo, hi) <- (lo + hi*root, lo - hi*root)`.
    #[inline(always)]
    fn butterfly_fn_oi<T: DomainCoeff<F>>(((lo, hi), root): ((&mut T, &mut T), &F)) {
        *hi *= *root;
        let neg = *lo - *hi;
        *lo += *hi;
        *hi = neg;
    }
 676  
    /// Apply butterfly `g` across `xi`, split into `num_chunks` clusters of
    /// `chunk_size = 2 * gap` elements; `roots` is read with stride `step`.
    #[allow(clippy::too_many_arguments)]
    fn apply_butterfly<T: DomainCoeff<F>, G: Fn(((&mut T, &mut T), &F)) + Copy + Sync + Send>(
        g: G,
        xi: &mut [T],
        roots: &[F],
        step: usize,
        chunk_size: usize,
        num_chunks: usize,
        max_threads: usize,
        gap: usize,
    ) {
        cfg_chunks_mut!(xi, chunk_size).for_each(|cxi| {
            // Each cluster pairs element i of `lo` with element i of `hi`.
            let (lo, hi) = cxi.split_at_mut(gap);
            // If the chunk is sufficiently big that parallelism helps,
            // we parallelize the butterfly operation within the chunk.

            if gap > MIN_GAP_SIZE_FOR_PARALLELISATION && num_chunks < max_threads {
                cfg_iter_mut!(lo).zip(hi).zip(cfg_iter!(roots).step_by(step)).for_each(g);
            } else {
                lo.iter_mut().zip(hi).zip(roots.iter().step_by(step)).for_each(g);
            }
        });
    }
 700  
    /// Forward butterfly passes: in-order input, permuted (bit-reversed-style)
    /// output. `gap` halves every round, from n/2 down to 1.
    #[allow(clippy::unnecessary_to_owned)]
    fn io_helper_with_roots<T: DomainCoeff<F>>(&self, xi: &mut [T], roots: &[F]) {
        // Borrow the roots initially; only clone into an owned, compacted copy
        // when the stride gets large enough to hurt cache locality.
        let mut roots = std::borrow::Cow::Borrowed(roots);

        let mut step = 1;
        let mut first = true;

        #[cfg(not(feature = "serial"))]
        let max_threads = alphavm_utilities::parallel::max_available_threads();
        #[cfg(feature = "serial")]
        let max_threads = 1;

        let mut gap = xi.len() / 2;
        while gap > 0 {
            // each butterfly cluster uses 2*gap positions
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;

            // Only compact roots to achieve cache locality/compactness if
            // the roots lookup is done a significant amount of times
            // Which also implies a large lookup stride.
            if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION {
                if !first {
                    // Keep every (step * 2)-th root; after compaction the
                    // stride resets to 1.
                    roots = Cow::Owned(cfg_into_iter!(roots.into_owned()).step_by(step * 2).collect());
                }
                step = 1;
                roots.to_mut().shrink_to_fit();
            } else {
                step = num_chunks;
            }
            first = false;

            Self::apply_butterfly(
                Self::butterfly_fn_io,
                xi,
                &roots[..],
                step,
                chunk_size,
                num_chunks,
                max_threads,
                gap,
            );

            gap /= 2;
        }
    }
 747  
    /// Inverse-direction butterfly passes: permuted input, in-order output.
    /// `gap` doubles every round, from 1 up to n/2.
    fn oi_helper_with_roots<T: DomainCoeff<F>>(&self, xi: &mut [T], roots_cache: &[F]) {
        // The `cmp::min` is only necessary for the case where
        // `MIN_NUM_CHUNKS_FOR_COMPACTION = 1`. Else, notice that we compact
        // the roots cache by a stride of at least `MIN_NUM_CHUNKS_FOR_COMPACTION`.

        let compaction_max_size =
            core::cmp::min(roots_cache.len() / 2, roots_cache.len() / MIN_NUM_CHUNKS_FOR_COMPACTION);
        // Scratch buffer reused across rounds for the compacted roots.
        let mut compacted_roots = vec![F::default(); compaction_max_size];

        #[cfg(not(feature = "serial"))]
        let max_threads = alphavm_utilities::parallel::max_available_threads();
        #[cfg(feature = "serial")]
        let max_threads = 1;

        let mut gap = 1;
        while gap < xi.len() {
            // each butterfly cluster uses 2*gap positions
            let chunk_size = 2 * gap;
            let num_chunks = xi.len() / chunk_size;

            // Only compact roots to achieve cache locality/compactness if
            // the roots lookup is done a significant amount of times
            // Which also implies a large lookup stride.
            let (roots, step) = if num_chunks >= MIN_NUM_CHUNKS_FOR_COMPACTION && gap < xi.len() / 2 {
                // Gather every num_chunks-th root into the scratch buffer so
                // this round reads it with stride 1.
                cfg_iter_mut!(compacted_roots[..gap])
                    .zip(cfg_iter!(roots_cache[..(gap * num_chunks)]).step_by(num_chunks))
                    .for_each(|(a, b)| *a = *b);
                (&compacted_roots[..gap], 1)
            } else {
                (roots_cache, num_chunks)
            };

            Self::apply_butterfly(Self::butterfly_fn_oi, xi, roots, step, chunk_size, num_chunks, max_threads, gap);

            gap *= 2;
        }
    }
 785  }
 786  
/// The minimum number of chunks at which root compaction
/// is beneficial.
const MIN_NUM_CHUNKS_FOR_COMPACTION: usize = 1 << 7;

/// The minimum size of a chunk at which parallelization of `butterfly`s is
/// beneficial. This value was chosen empirically.
const MIN_GAP_SIZE_FOR_PARALLELISATION: usize = 1 << 10;

/// The minimum size (as a log2, per the `LOG_` prefix) at which to parallelize
/// the computation of the roots of unity.
#[cfg(not(feature = "serial"))]
const LOG_ROOTS_OF_UNITY_PARALLEL_SIZE: u32 = 7;
 798  
 799  #[inline]
 800  pub(super) fn bitrev(a: u64, log_len: u32) -> u64 {
 801      a.reverse_bits() >> (64 - log_len)
 802  }
 803  
 804  pub(crate) fn derange<T>(xi: &mut [T]) {
 805      derange_helper(xi, log2(xi.len()))
 806  }
 807  
 808  fn derange_helper<T>(xi: &mut [T], log_len: u32) {
 809      for idx in 1..(xi.len() as u64 - 1) {
 810          let ridx = bitrev(idx, log_len);
 811          if idx < ridx {
 812              xi.swap(idx as usize, ridx as usize);
 813          }
 814      }
 815  }
 816  
/// The required ordering of the input and output of an FFT invocation.
#[derive(PartialEq, Eq, Debug)]
enum FFTOrder {
    /// Both the input and the output of the FFT must be in-order.
    II,
    /// The input of the FFT must be in-order, but the output does not have to
    /// be.
    IO,
    /// The input of the FFT can be out of order, but the output must be
    /// in-order.
    OI,
}
 828  
/// Returns the first `size` powers of `root`:
/// `[1, root, root^2, ..., root^(size - 1)]`.
pub(crate) fn compute_powers_serial<F: Field>(size: usize, root: F) -> Vec<F> {
    compute_powers_and_mul_by_const_serial(size, root, F::one())
}
 832  
 833  pub(crate) fn compute_powers_and_mul_by_const_serial<F: Field>(size: usize, root: F, c: F) -> Vec<F> {
 834      let mut value = c;
 835      (0..size)
 836          .map(|_| {
 837              let old_value = value;
 838              value *= root;
 839              old_value
 840          })
 841          .collect()
 842  }
 843  
 844  #[allow(unused)]
 845  #[cfg(not(feature = "serial"))]
 846  pub(crate) fn compute_powers<F: Field>(size: usize, g: F) -> Vec<F> {
 847      if size < MIN_PARALLEL_CHUNK_SIZE {
 848          return compute_powers_serial(size, g);
 849      }
 850      // compute the number of threads we will be using.
 851      let num_cpus_available = max_available_threads();
 852      let num_elem_per_thread = core::cmp::max(size / num_cpus_available, MIN_PARALLEL_CHUNK_SIZE);
 853      let num_cpus_used = size / num_elem_per_thread;
 854  
 855      // Split up the powers to compute across each thread evenly.
 856      let res: Vec<F> = (0..num_cpus_used)
 857          .into_par_iter()
 858          .flat_map(|i| {
 859              let offset = g.pow([(i * num_elem_per_thread) as u64]);
 860              // Compute the size that this chunks' output should be
 861              // (num_elem_per_thread, unless there are less than num_elem_per_thread elements
 862              // remaining)
 863              let num_elements_to_compute = core::cmp::min(size - i * num_elem_per_thread, num_elem_per_thread);
 864              compute_powers_and_mul_by_const_serial(num_elements_to_compute, g, offset)
 865          })
 866          .collect();
 867      res
 868  }
 869  
/// An iterator over the elements of the domain.
#[derive(Clone)]
pub struct Elements<F: FftField> {
    /// The element to be yielded next; advanced by multiplying by `group_gen`.
    cur_elem: F,
    /// How many elements have been yielded so far; iteration stops at `domain.size`.
    cur_pow: u64,
    /// The domain being iterated over.
    domain: EvaluationDomain<F>,
}
 877  
 878  impl<F: FftField> Iterator for Elements<F> {
 879      type Item = F;
 880  
 881      fn next(&mut self) -> Option<F> {
 882          if self.cur_pow == self.domain.size {
 883              None
 884          } else {
 885              let cur_elem = self.cur_elem;
 886              self.cur_elem *= &self.domain.group_gen;
 887              self.cur_pow += 1;
 888              Some(cur_elem)
 889          }
 890      }
 891  }
 892  
/// Precomputed roots for performing FFTs over a fixed domain.
#[derive(Clone, Eq, PartialEq, Debug, CanonicalDeserialize, CanonicalSerialize)]
pub struct FFTPrecomputation<F: FftField> {
    /// The precomputed roots used by the FFT.
    roots: Vec<F>,
    /// The domain these roots were computed for.
    domain: EvaluationDomain<F>,
}
 899  
 900  impl<F: FftField> FFTPrecomputation<F> {
 901      pub fn to_ifft_precomputation(&self) -> IFFTPrecomputation<F> {
 902          let mut inverse_roots = self.roots.clone();
 903          alphavm_fields::batch_inversion(&mut inverse_roots);
 904          IFFTPrecomputation { inverse_roots, domain: self.domain }
 905      }
 906  
 907      pub fn precomputation_for_subdomain<'a>(&'a self, domain: &EvaluationDomain<F>) -> Option<Cow<'a, Self>> {
 908          if domain.size() == 1 {
 909              return Some(Cow::Owned(Self { roots: vec![], domain: *domain }));
 910          }
 911          if &self.domain == domain {
 912              Some(Cow::Borrowed(self))
 913          } else if domain.size() < self.domain.size() {
 914              let size_ratio = self.domain.size() / domain.size();
 915              let roots = self.roots.iter().step_by(size_ratio).copied().collect();
 916              Some(Cow::Owned(Self { roots, domain: *domain }))
 917          } else {
 918              None
 919          }
 920      }
 921  }
 922  
/// Precomputed inverse roots for performing IFFTs over a fixed domain.
#[derive(Clone, Eq, PartialEq, Debug, CanonicalSerialize, CanonicalDeserialize)]
pub struct IFFTPrecomputation<F: FftField> {
    /// The precomputed inverse roots used by the IFFT.
    inverse_roots: Vec<F>,
    /// The domain these inverse roots were computed for.
    domain: EvaluationDomain<F>,
}
 929  
 930  impl<F: FftField> IFFTPrecomputation<F> {
 931      pub fn precomputation_for_subdomain<'a>(&'a self, domain: &EvaluationDomain<F>) -> Option<Cow<'a, Self>> {
 932          if domain.size() == 1 {
 933              return Some(Cow::Owned(Self { inverse_roots: vec![], domain: *domain }));
 934          }
 935          if &self.domain == domain {
 936              Some(Cow::Borrowed(self))
 937          } else if domain.size() < self.domain.size() {
 938              let size_ratio = self.domain.size() / domain.size();
 939              let inverse_roots = self.inverse_roots.iter().step_by(size_ratio).copied().collect();
 940              Some(Cow::Owned(Self { inverse_roots, domain: *domain }))
 941          } else {
 942              None
 943          }
 944      }
 945  }
 946  
 947  #[cfg(test)]
 948  mod tests {
 949      #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
 950      use crate::fft::domain::FFTOrder;
 951      use crate::fft::{DensePolynomial, EvaluationDomain};
 952      use alphavm_curves::bls12_377::Fr;
 953      use alphavm_fields::{FftField, Field, One, Zero};
 954      use alphavm_utilities::{TestRng, Uniform};
 955      use rand::Rng;
 956  
 957      #[test]
 958      fn vanishing_polynomial_evaluation() {
 959          let rng = &mut TestRng::default();
 960          for coeffs in 0..10 {
 961              let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
 962              let z = domain.vanishing_polynomial();
 963              for _ in 0..100 {
 964                  let point = rng.r#gen();
 965                  assert_eq!(z.evaluate(point), domain.evaluate_vanishing_polynomial(point))
 966              }
 967          }
 968      }
 969  
 970      #[test]
 971      fn vanishing_polynomial_vanishes_on_domain() {
 972          for coeffs in 0..1000 {
 973              let domain = EvaluationDomain::<Fr>::new(coeffs).unwrap();
 974              let z = domain.vanishing_polynomial();
 975              for point in domain.elements() {
 976                  assert!(z.evaluate(point).is_zero())
 977              }
 978          }
 979      }
 980  
 981      #[test]
 982      fn size_of_elements() {
 983          for coeffs in 1..10 {
 984              let size = 1 << coeffs;
 985              let domain = EvaluationDomain::<Fr>::new(size).unwrap();
 986              let domain_size = domain.size();
 987              assert_eq!(domain_size, domain.elements().count());
 988          }
 989      }
 990  
 991      #[test]
 992      fn elements_contents() {
 993          for coeffs in 1..10 {
 994              let size = 1 << coeffs;
 995              let domain = EvaluationDomain::<Fr>::new(size).unwrap();
 996              for (i, element) in domain.elements().enumerate() {
 997                  assert_eq!(element, domain.group_gen.pow([i as u64]));
 998              }
 999          }
1000      }
1001  
1002      /// Test that lagrange interpolation for a random polynomial at a random
1003      /// point works.
1004      #[test]
1005      fn non_systematic_lagrange_coefficients_test() {
1006          let mut rng = TestRng::default();
1007          for domain_dimension in 1..10 {
1008              let domain_size = 1 << domain_dimension;
1009              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1010              // Get random point & lagrange coefficients
1011              let random_point = Fr::rand(&mut rng);
1012              let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(random_point);
1013  
1014              // Sample the random polynomial, evaluate it over the domain and the random
1015              // point.
1016              let random_polynomial = DensePolynomial::<Fr>::rand(domain_size - 1, &mut rng);
1017              let polynomial_evaluations = domain.fft(random_polynomial.coeffs());
1018              let actual_evaluations = random_polynomial.evaluate(random_point);
1019  
1020              // Do lagrange interpolation, and compare against the actual evaluation
1021              let mut interpolated_evaluation = Fr::zero();
1022              for i in 0..domain_size {
1023                  interpolated_evaluation += lagrange_coefficients[i] * polynomial_evaluations[i];
1024              }
1025              assert_eq!(actual_evaluations, interpolated_evaluation);
1026          }
1027      }
1028  
1029      /// Test that lagrange coefficients for a point in the domain is correct.
1030      #[test]
1031      fn systematic_lagrange_coefficients_test() {
1032          // This runs in time O(N^2) in the domain size, so keep the domain dimension
1033          // low. We generate lagrange coefficients for each element in the
1034          // domain.
1035          for domain_dimension in 1..5 {
1036              let domain_size = 1 << domain_dimension;
1037              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1038              let all_domain_elements: Vec<Fr> = domain.elements().collect();
1039              for (i, domain_element) in all_domain_elements.iter().enumerate().take(domain_size) {
1040                  let lagrange_coefficients = domain.evaluate_all_lagrange_coefficients(*domain_element);
1041                  for (j, lagrange_coefficient) in lagrange_coefficients.iter().enumerate().take(domain_size) {
1042                      // Lagrange coefficient for the evaluation point, which should be 1
1043                      if i == j {
1044                          assert_eq!(*lagrange_coefficient, Fr::one());
1045                      } else {
1046                          assert_eq!(*lagrange_coefficient, Fr::zero());
1047                      }
1048                  }
1049              }
1050          }
1051      }
1052  
1053      /// Tests that the roots of unity result is the same as domain.elements().
1054      #[test]
1055      fn test_roots_of_unity() {
1056          let max_degree = 10;
1057          for log_domain_size in 0..max_degree {
1058              let domain_size = 1 << log_domain_size;
1059              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1060              let actual_roots = domain.roots_of_unity(domain.group_gen);
1061              for &value in &actual_roots {
1062                  assert!(domain.evaluate_vanishing_polynomial(value).is_zero());
1063              }
1064              let expected_roots_elements = domain.elements();
1065              for (expected, &actual) in expected_roots_elements.zip(&actual_roots) {
1066                  assert_eq!(expected, actual);
1067              }
1068              assert_eq!(actual_roots.len(), domain_size / 2);
1069          }
1070      }
1071  
1072      /// Tests that the FFTs output the correct result.
1073      #[test]
1074      fn test_fft_correctness() {
1075          // This assumes a correct polynomial evaluation at point procedure.
1076          // It tests consistency of FFT/IFFT, and coset_fft/coset_ifft,
1077          // along with testing that each individual evaluation is correct.
1078  
1079          let mut rng = TestRng::default();
1080  
1081          // Runs in time O(degree^2)
1082          let log_degree = 5;
1083          let degree = 1 << log_degree;
1084          let random_polynomial = DensePolynomial::<Fr>::rand(degree - 1, &mut rng);
1085  
1086          for log_domain_size in log_degree..(log_degree + 2) {
1087              let domain_size = 1 << log_domain_size;
1088              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1089              let polynomial_evaluations = domain.fft(&random_polynomial.coeffs);
1090              let polynomial_coset_evaluations = domain.coset_fft(&random_polynomial.coeffs);
1091              for (i, x) in domain.elements().enumerate() {
1092                  let coset_x = Fr::multiplicative_generator() * x;
1093  
1094                  assert_eq!(polynomial_evaluations[i], random_polynomial.evaluate(x));
1095                  assert_eq!(polynomial_coset_evaluations[i], random_polynomial.evaluate(coset_x));
1096              }
1097  
1098              let randon_polynomial_from_subgroup =
1099                  DensePolynomial::from_coefficients_vec(domain.ifft(&polynomial_evaluations));
1100              let random_polynomial_from_coset =
1101                  DensePolynomial::from_coefficients_vec(domain.coset_ifft(&polynomial_coset_evaluations));
1102  
1103              assert_eq!(
1104                  random_polynomial, randon_polynomial_from_subgroup,
1105                  "degree = {degree}, domain size = {domain_size}"
1106              );
1107              assert_eq!(
1108                  random_polynomial, random_polynomial_from_coset,
1109                  "degree = {degree}, domain size = {domain_size}"
1110              );
1111          }
1112      }
1113  
1114      /// Tests that FFT precomputation is correctly subdomained
1115      #[test]
1116      fn test_fft_precomputation() {
1117          for i in 1..10 {
1118              let big_domain = EvaluationDomain::<Fr>::new(i).unwrap();
1119              let pc = big_domain.precompute_fft();
1120              for j in 1..i {
1121                  let small_domain = EvaluationDomain::<Fr>::new(j).unwrap();
1122                  let small_pc = small_domain.precompute_fft();
1123                  assert_eq!(pc.precomputation_for_subdomain(&small_domain).unwrap().as_ref(), &small_pc);
1124              }
1125          }
1126      }
1127  
1128      /// Tests that IFFT precomputation is correctly subdomained
1129      #[test]
1130      fn test_ifft_precomputation() {
1131          for i in 1..10 {
1132              let big_domain = EvaluationDomain::<Fr>::new(i).unwrap();
1133              let pc = big_domain.precompute_ifft();
1134              for j in 1..i {
1135                  let small_domain = EvaluationDomain::<Fr>::new(j).unwrap();
1136                  let small_pc = small_domain.precompute_ifft();
1137                  assert_eq!(pc.precomputation_for_subdomain(&small_domain).unwrap().as_ref(), &small_pc);
1138              }
1139          }
1140      }
1141  
1142      /// Tests that IFFT precomputation can be correctly computed from
1143      /// FFT precomputation
1144      #[test]
1145      fn test_ifft_precomputation_from_fft() {
1146          for i in 1..10 {
1147              let domain = EvaluationDomain::<Fr>::new(i).unwrap();
1148              let pc = domain.precompute_ifft();
1149              let fft_pc = domain.precompute_fft();
1150              assert_eq!(pc, fft_pc.to_ifft_precomputation())
1151          }
1152      }
1153  
1154      /// Tests that the FFTs output the correct result.
1155      #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
1156      #[test]
1157      fn test_fft_correctness_cuda() {
1158          let mut rng = TestRng::default();
1159          for log_domain in 2..20 {
1160              println!("Testing domain size {log_domain}");
1161              let domain_size = 1 << log_domain;
1162              let random_polynomial = DensePolynomial::<Fr>::rand(domain_size - 1, &mut rng);
1163              let mut polynomial_evaluations = random_polynomial.coeffs.clone();
1164              let mut polynomial_evaluations_cuda = random_polynomial.coeffs.clone();
1165  
1166              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1167              let pc = domain.precompute_fft();
1168              domain.fft_helper_in_place_with_pc(&mut polynomial_evaluations, FFTOrder::II, &pc);
1169  
1170              if alphavm_algorithms_cuda::NTT::<Fr>(
1171                  domain_size,
1172                  &mut polynomial_evaluations_cuda,
1173                  alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
1174                  alphavm_algorithms_cuda::NTTDirection::Forward,
1175                  alphavm_algorithms_cuda::NTTType::Standard,
1176              )
1177              .is_err()
1178              {
1179                  println!("cuda error!");
1180              }
1181  
1182              assert_eq!(polynomial_evaluations, polynomial_evaluations_cuda, "domain size = {domain_size}");
1183  
1184              // iNTT
1185              if alphavm_algorithms_cuda::NTT::<Fr>(
1186                  domain_size,
1187                  &mut polynomial_evaluations_cuda,
1188                  alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
1189                  alphavm_algorithms_cuda::NTTDirection::Inverse,
1190                  alphavm_algorithms_cuda::NTTType::Standard,
1191              )
1192              .is_err()
1193              {
1194                  println!("cuda error!");
1195              }
1196              assert_eq!(random_polynomial.coeffs, polynomial_evaluations_cuda, "domain size = {domain_size}");
1197  
1198              // Coset NTT
1199              polynomial_evaluations = random_polynomial.coeffs.clone();
1200              let domain = EvaluationDomain::<Fr>::new(domain_size).unwrap();
1201              let pc = domain.precompute_fft();
1202              EvaluationDomain::<Fr>::distribute_powers(&mut polynomial_evaluations, Fr::multiplicative_generator());
1203              domain.fft_helper_in_place_with_pc(&mut polynomial_evaluations, FFTOrder::II, &pc);
1204  
1205              if alphavm_algorithms_cuda::NTT::<Fr>(
1206                  domain_size,
1207                  &mut polynomial_evaluations_cuda,
1208                  alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
1209                  alphavm_algorithms_cuda::NTTDirection::Forward,
1210                  alphavm_algorithms_cuda::NTTType::Coset,
1211              )
1212              .is_err()
1213              {
1214                  println!("cuda error!");
1215              }
1216  
1217              assert_eq!(polynomial_evaluations, polynomial_evaluations_cuda, "domain size = {domain_size}");
1218  
1219              // Coset iNTT
1220              if alphavm_algorithms_cuda::NTT::<Fr>(
1221                  domain_size,
1222                  &mut polynomial_evaluations_cuda,
1223                  alphavm_algorithms_cuda::NTTInputOutputOrder::NN,
1224                  alphavm_algorithms_cuda::NTTDirection::Inverse,
1225                  alphavm_algorithms_cuda::NTTType::Coset,
1226              )
1227              .is_err()
1228              {
1229                  println!("cuda error!");
1230              }
1231              assert_eq!(random_polynomial.coeffs, polynomial_evaluations_cuda, "domain size = {domain_size}");
1232          }
1233      }
1234  }