lib.rs
1 // Copyright (c) 2019-2025 Alpha-Delta Network Inc. 2 // This file is part of the deltavm library. 3 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at: 7 8 // http://www.apache.org/licenses/LICENSE-2.0 9 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 #[allow(unused_imports)] 17 use blst::*; 18 19 use core::ffi::c_void; 20 sppark::cuda_error!(); 21 22 #[repr(C)] 23 pub enum NTTInputOutputOrder { 24 NN = 0, 25 NR = 1, 26 RN = 2, 27 RR = 3, 28 } 29 30 #[repr(C)] 31 pub enum NTTDirection { 32 Forward = 0, 33 Inverse = 1, 34 } 35 36 #[repr(C)] 37 pub enum NTTType { 38 Standard = 0, 39 Coset = 1, 40 } 41 42 extern "C" { 43 fn deltavm_ntt( 44 inout: *mut core::ffi::c_void, 45 lg_domain_size: u32, 46 ntt_order: NTTInputOutputOrder, 47 ntt_direction: NTTDirection, 48 ntt_type: NTTType, 49 ) -> cuda::Error; 50 51 fn deltavm_polymul( 52 out: *mut core::ffi::c_void, 53 pcount: usize, 54 polynomials: *const core::ffi::c_void, 55 plens: *const core::ffi::c_void, 56 ecount: usize, 57 evaluations: *const core::ffi::c_void, 58 elens: *const core::ffi::c_void, 59 lg_domain_size: u32, 60 ) -> cuda::Error; 61 62 fn deltavm_msm( 63 out: *mut c_void, 64 points_with_infinity: *const c_void, 65 npoints: usize, 66 scalars: *const c_void, 67 ffi_affine_sz: usize, 68 ) -> cuda::Error; 69 } 70 71 /////////////////////////////////////////////////////////////////////////////// 72 // Rust functions 73 /////////////////////////////////////////////////////////////////////////////// 74 75 /// Compute an in-place NTT on the input data. 76 #[allow(non_snake_case)] 77 pub fn NTT<T>( 78 domain_size: usize, 79 inout: &mut [T], 80 ntt_order: NTTInputOutputOrder, 81 ntt_direction: NTTDirection, 82 ntt_type: NTTType, 83 ) -> Result<(), cuda::Error> { 84 if (domain_size & (domain_size - 1)) != 0 { 85 panic!("domain_size is not power of 2"); 86 } 87 let lg_domain_size = domain_size.trailing_zeros(); 88 89 let err = unsafe { 90 deltavm_ntt(inout.as_mut_ptr() as *mut core::ffi::c_void, lg_domain_size, ntt_order, ntt_direction, ntt_type) 91 }; 92 93 if err.code != 0 { 94 return Err(err); 95 } 96 Ok(()) 97 } 98 99 /// Compute a polynomial multiply 100 pub fn polymul<T: std::clone::Clone>( 101 domain: usize, 102 polynomials: &Vec<Vec<T>>, 103 evaluations: &Vec<Vec<T>>, 104 zero: &T, 105 ) -> Result<Vec<T>, cuda::Error> { 106 let initial_domain_size = domain; 107 if (initial_domain_size & (initial_domain_size - 1)) != 0 { 108 panic!("domain_size is not power of 2"); 109 } 110 111 let lg_domain_size = initial_domain_size.trailing_zeros(); 112 113 let mut pptrs = Vec::new(); 114 let mut plens = Vec::new(); 115 for polynomial in polynomials { 116 pptrs.push(polynomial.as_ptr() as *const core::ffi::c_void); 117 plens.push(polynomial.len()); 118 } 119 let mut eptrs = Vec::new(); 120 let mut elens = Vec::new(); 121 for evaluation in evaluations { 122 eptrs.push(evaluation.as_ptr() as *const core::ffi::c_void); 123 elens.push(evaluation.len()); 124 } 125 126 let mut out = Vec::new(); 127 out.resize(initial_domain_size, zero.clone()); 128 let err = unsafe { 129 deltavm_polymul( 130 out.as_mut_ptr() as *mut core::ffi::c_void, 131 pptrs.len(), 132 pptrs.as_ptr() as *const core::ffi::c_void, 133 plens.as_ptr() as *const core::ffi::c_void, 134 eptrs.len(), 135 eptrs.as_ptr() as *const core::ffi::c_void, 136 elens.as_ptr() as *const core::ffi::c_void, 137 lg_domain_size, 138 ) 139 }; 140 141 if err.code != 0 { 142 return Err(err); 143 } 144 Ok(out) 145 } 146 147 /// Compute a multi-scalar multiplication 148 pub fn msm<Affine, Projective, Scalar>(points: &[Affine], scalars: &[Scalar]) -> Result<Projective, cuda::Error> { 149 let npoints = scalars.len(); 150 if npoints > points.len() { 151 panic!("length mismatch {} points < {} scalars", npoints, scalars.len()) 152 } 153 #[allow(clippy::uninit_assumed_init)] 154 let mut ret: Projective = unsafe { std::mem::MaybeUninit::uninit().assume_init() }; 155 let err = unsafe { 156 deltavm_msm( 157 &mut ret as *mut _ as *mut c_void, 158 points as *const _ as *const c_void, 159 npoints, 160 scalars as *const _ as *const c_void, 161 std::mem::size_of::<Affine>(), 162 ) 163 }; 164 if err.code != 0 { 165 return Err(err); 166 } 167 Ok(ret) 168 }