/ algorithms / cuda / src / lib.rs
lib.rs
  1  // Copyright (c) 2019-2025 Alpha-Delta Network Inc.
  2  // This file is part of the deltavm library.
  3  
  4  // Licensed under the Apache License, Version 2.0 (the "License");
  5  // you may not use this file except in compliance with the License.
  6  // You may obtain a copy of the License at:
  7  
  8  // http://www.apache.org/licenses/LICENSE-2.0
  9  
 10  // Unless required by applicable law or agreed to in writing, software
 11  // distributed under the License is distributed on an "AS IS" BASIS,
 12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13  // See the License for the specific language governing permissions and
 14  // limitations under the License.
 15  
 16  #[allow(unused_imports)]
 17  use blst::*;
 18  
 19  use core::ffi::c_void;
// Expand sppark's `cuda_error!` macro. The rest of this file refers to a
// `cuda` module (`cuda::Error` as the FFI return type, and its `code` field);
// this invocation is presumably what brings that module into scope — confirm
// against the sppark crate.
sppark::cuda_error!();
 21  
 22  #[repr(C)]
 23  pub enum NTTInputOutputOrder {
 24      NN = 0,
 25      NR = 1,
 26      RN = 2,
 27      RR = 3,
 28  }
 29  
 30  #[repr(C)]
 31  pub enum NTTDirection {
 32      Forward = 0,
 33      Inverse = 1,
 34  }
 35  
 36  #[repr(C)]
 37  pub enum NTTType {
 38      Standard = 0,
 39      Coset = 1,
 40  }
 41  
 42  extern "C" {
 43      fn deltavm_ntt(
 44          inout: *mut core::ffi::c_void,
 45          lg_domain_size: u32,
 46          ntt_order: NTTInputOutputOrder,
 47          ntt_direction: NTTDirection,
 48          ntt_type: NTTType,
 49      ) -> cuda::Error;
 50  
 51      fn deltavm_polymul(
 52          out: *mut core::ffi::c_void,
 53          pcount: usize,
 54          polynomials: *const core::ffi::c_void,
 55          plens: *const core::ffi::c_void,
 56          ecount: usize,
 57          evaluations: *const core::ffi::c_void,
 58          elens: *const core::ffi::c_void,
 59          lg_domain_size: u32,
 60      ) -> cuda::Error;
 61  
 62      fn deltavm_msm(
 63          out: *mut c_void,
 64          points_with_infinity: *const c_void,
 65          npoints: usize,
 66          scalars: *const c_void,
 67          ffi_affine_sz: usize,
 68      ) -> cuda::Error;
 69  }
 70  
 71  ///////////////////////////////////////////////////////////////////////////////
 72  // Rust functions
 73  ///////////////////////////////////////////////////////////////////////////////
 74  
 75  /// Compute an in-place NTT on the input data.
 76  #[allow(non_snake_case)]
 77  pub fn NTT<T>(
 78      domain_size: usize,
 79      inout: &mut [T],
 80      ntt_order: NTTInputOutputOrder,
 81      ntt_direction: NTTDirection,
 82      ntt_type: NTTType,
 83  ) -> Result<(), cuda::Error> {
 84      if (domain_size & (domain_size - 1)) != 0 {
 85          panic!("domain_size is not power of 2");
 86      }
 87      let lg_domain_size = domain_size.trailing_zeros();
 88  
 89      let err = unsafe {
 90          deltavm_ntt(inout.as_mut_ptr() as *mut core::ffi::c_void, lg_domain_size, ntt_order, ntt_direction, ntt_type)
 91      };
 92  
 93      if err.code != 0 {
 94          return Err(err);
 95      }
 96      Ok(())
 97  }
 98  
 99  /// Compute a polynomial multiply
100  pub fn polymul<T: std::clone::Clone>(
101      domain: usize,
102      polynomials: &Vec<Vec<T>>,
103      evaluations: &Vec<Vec<T>>,
104      zero: &T,
105  ) -> Result<Vec<T>, cuda::Error> {
106      let initial_domain_size = domain;
107      if (initial_domain_size & (initial_domain_size - 1)) != 0 {
108          panic!("domain_size is not power of 2");
109      }
110  
111      let lg_domain_size = initial_domain_size.trailing_zeros();
112  
113      let mut pptrs = Vec::new();
114      let mut plens = Vec::new();
115      for polynomial in polynomials {
116          pptrs.push(polynomial.as_ptr() as *const core::ffi::c_void);
117          plens.push(polynomial.len());
118      }
119      let mut eptrs = Vec::new();
120      let mut elens = Vec::new();
121      for evaluation in evaluations {
122          eptrs.push(evaluation.as_ptr() as *const core::ffi::c_void);
123          elens.push(evaluation.len());
124      }
125  
126      let mut out = Vec::new();
127      out.resize(initial_domain_size, zero.clone());
128      let err = unsafe {
129          deltavm_polymul(
130              out.as_mut_ptr() as *mut core::ffi::c_void,
131              pptrs.len(),
132              pptrs.as_ptr() as *const core::ffi::c_void,
133              plens.as_ptr() as *const core::ffi::c_void,
134              eptrs.len(),
135              eptrs.as_ptr() as *const core::ffi::c_void,
136              elens.as_ptr() as *const core::ffi::c_void,
137              lg_domain_size,
138          )
139      };
140  
141      if err.code != 0 {
142          return Err(err);
143      }
144      Ok(out)
145  }
146  
147  /// Compute a multi-scalar multiplication
148  pub fn msm<Affine, Projective, Scalar>(points: &[Affine], scalars: &[Scalar]) -> Result<Projective, cuda::Error> {
149      let npoints = scalars.len();
150      if npoints > points.len() {
151          panic!("length mismatch {} points < {} scalars", npoints, scalars.len())
152      }
153      #[allow(clippy::uninit_assumed_init)]
154      let mut ret: Projective = unsafe { std::mem::MaybeUninit::uninit().assume_init() };
155      let err = unsafe {
156          deltavm_msm(
157              &mut ret as *mut _ as *mut c_void,
158              points as *const _ as *const c_void,
159              npoints,
160              scalars as *const _ as *const c_void,
161              std::mem::size_of::<Affine>(),
162          )
163      };
164      if err.code != 0 {
165          return Err(err);
166      }
167      Ok(ret)
168  }