// polynomial.cuh
// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef __POLYNOMIAL_CUH__
#define __POLYNOMIAL_CUH__

#ifndef __CUDA_ARCH__

#include <util/exception.cuh>
#include <util/rusterror.h>
#include <util/gpu_t.cuh>

#endif

#ifndef __CUDA_ARCH__
// Debug helper: print one field element as a 256-bit big-endian hex string.
// f.from() converts the element out of its internal (presumably Montgomery —
// TODO confirm against fr_t) representation before the limbs are stored and
// printed most-significant limb first.
void host_print_fr(fr_t f) {
    uint64_t val[4];
    f.from();
    f.store((limb_t*)val);
    printf("0x%016lx%016lx%016lx%016lx\n", val[3], val[2], val[1], val[0]);
}
#endif

// Element-wise product over the evaluation domain:
//   out[idx] = in0[idx] * in1[idx]  for idx in [0, domain_size).
// One thread per element; threads past domain_size exit via the bounds
// check, so the grid may safely overshoot. `out` may alias `in0`/`in1`
// (the caller below passes out == in0) since each thread reads its
// operands before writing.
__global__
void polynomial_inner_multiply(size_t domain_size, fr_t* out, fr_t* in0, fr_t* in1) {
    index_t idx = threadIdx.x + blockDim.x * (index_t)blockIdx.x;
    if (idx >= domain_size)
        return;

    fr_t x = in0[idx];
    fr_t y = in1[idx];
    out[idx] = x * y;
}

#ifndef __CUDA_ARCH__

// Polynomial multiplication on top of the NTT base class: multiply many
// polynomials/evaluations together by transforming to the evaluation
// domain, taking element-wise products, and inverse-transforming once at
// the end.
class Polynomial : public NTT {
protected:
    // Retained single-pair reference implementation (superseded by the
    // multi-operand, copy/compute-overlapped Mul below):
    //
    // // out will be in lg_domain_size
    // static void MulDev(stream_t& stream,
    //                    fr_t* d_out, fr_t* d_in0, fr_t* d_in1, // Device pointers
    //                    uint32_t lg_domain_size) {
    //     size_t domain_size = (size_t)1 << lg_domain_size;
    //
    //     // Perform NTT on the input data
    //     NTT_internal(d_in0, lg_domain_size,
    //                  NTT::InputOutputOrder::NR,
    //                  NTT::Direction::forward,
    //                  NTT::Type::standard,
    //                  stream);
    //     NTT_internal(d_in1, lg_domain_size,
    //                  NTT::InputOutputOrder::NR,
    //                  NTT::Direction::forward,
    //                  NTT::Type::standard,
    //                  stream);
    //
    //     // Inner multiply
    //     polynomial_inner_multiply<<<(domain_size + 1023) / 1024, 1024, 0, stream>>>
    //         (domain_size, d_out, d_in0, d_in1);
    //
    //     // Perform iNTT on the result
    //     NTT_internal(d_out, lg_domain_size,
    //                  NTT::InputOutputOrder::RN,
    //                  NTT::Direction::inverse,
    //                  NTT::Type::standard,
    //                  stream);
    // }

    // Stage a coefficient-form polynomial of `len` coefficients into device
    // memory sized for the full 2^lg_domain_size domain:
    //   1. copy `poly` into the host staging buffer `hmem` (host-to-host,
    //      enqueued on `stream`; hmem is presumably pinned so the upload
    //      can be truly asynchronous — TODO confirm allocation site),
    //   2. upload hmem -> dmem,
    //   3. zero-pad dmem[len .. domain_size).
    // All three operations are stream-ordered and asynchronous; the caller
    // must sync `stream` before reusing hmem/poly or reading dmem.
    // NOTE(review): the cudaMemcpyAsync/cudaMemsetAsync return codes are
    // discarded here — errors only surface at a later throwing call.
    static void mul_copy_poly(fr_t* hmem, fr_t* dmem, fr_t* poly, size_t len,
                              stream_t& stream, uint32_t lg_domain_size) {
        size_t domain_size = (size_t)1 << lg_domain_size;
        // Copy the data to the GPU
        //memcpy(hmem, poly, sizeof(fr_t) * len);
        cudaMemcpyAsync(hmem, poly, sizeof(fr_t) * len, cudaMemcpyHostToHost, stream);
        //cudaMemsetAsync(dmem, 0, sizeof(fr_t) * domain_size, stream);
        stream.HtoD(dmem, hmem, len);
        // Zero the tail so the NTT sees a full-domain input. When
        // len == domain_size this is a zero-byte memset (harmless).
        cudaMemsetAsync(&dmem[len], 0, sizeof(fr_t) * (domain_size - len), stream);
    }

    // Stage an evaluation-form operand into device memory. Unlike
    // mul_copy_poly, evaluations must already cover the entire domain
    // (len == 2^lg_domain_size, asserted below), so no zero-padding is
    // needed. Same staging path and same asynchrony/error-reporting
    // caveats as mul_copy_poly.
    static void mul_copy_eval(fr_t* hmem, fr_t* dmem,
                              fr_t* eval, size_t len,
                              stream_t& stream, uint32_t lg_domain_size) {
        size_t domain_size = (size_t)1 << lg_domain_size;
        assert (len == domain_size);
        // Copy the data to the GPU
        cudaMemcpyAsync(hmem, eval, sizeof(fr_t) * len, cudaMemcpyHostToHost, stream);
        stream.HtoD(dmem, hmem, len);
        //cudaMemsetAsync(&dmem[len], 0, sizeof(fr_t) * (domain_size - len), stream);
    }

public:
    // Multiply pcount coefficient-form polynomials and ecount
    // evaluation-form operands into a single product, returned in
    // coefficient form in hmem0 (domain_size = 2^lg_domain_size elements).
    //
    // Strategy: keep a running product in dmem0 in bit-reversed evaluation
    // order (forward NTTs use order NR, evaluations are bit-reversed on
    // arrival), multiply each new operand in with
    // polynomial_inner_multiply, and perform one RN inverse NTT at the end.
    //
    // Buffer roles (all presumably sized for domain_size elements by the
    // caller — TODO confirm):
    //   hmem0/dmem0 — running product; hmem0 also receives the final result
    //   hmem1/dmem1, hmem2/dmem2 — double buffer: while `stream` computes
    //       on one half, the next operand is uploaded on `alt` into the
    //       other half
    //   dmem3 — aux buffer for bit-reversing evaluation operands
    //
    // Preconditions (assumed, not all checked): pcount + ecount >= 2;
    // if pcount == 0 then ecount >= 1; elens[i] == domain_size.
    //
    // Returns RustError{cudaSuccess} on success; on a cuda_error the GPU is
    // synced and the error code (and message, when
    // TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE is defined) is propagated.
    static RustError Mul(const gpu_t& gpu, stream_t& stream,
                         fr_t* hmem0, fr_t* hmem1, fr_t* hmem2,
                         fr_t* dmem0, fr_t* dmem1, fr_t* dmem2, fr_t* dmem3,
                         size_t pcount, fr_t** polynomials, size_t* plens,
                         size_t ecount, fr_t** evaluations, size_t* elens,
                         uint32_t lg_domain_size) {
        try {
            gpu.select();

            size_t domain_size = (size_t)1 << lg_domain_size;

            // Next operand to *copy* to the device.
            size_t pcur = 0;
            size_t ecur = 0;

            // Set up the first polynomial / evaluation in dmem0
            if (pcount > 0) {
                mul_copy_poly(hmem0, dmem0, polynomials[0], plens[0], stream, lg_domain_size);
                // Perform NTT on the input data (NR: natural in,
                // bit-reversed out — matches the running-product order).
                NTT_internal(dmem0, lg_domain_size, NTT::InputOutputOrder::NR,
                             NTT::Direction::forward, NTT::Type::standard, stream);
                pcur++;
            } else {
                // Evaluations arrive in natural order; stage via dmem3 and
                // bit-reverse into dmem0 to match the NR product order.
                mul_copy_eval(hmem0, dmem3, evaluations[0], elens[0], stream, lg_domain_size);
                // Bit reversal
                bit_rev(dmem0, dmem3, lg_domain_size, stream);
                ecur++;
            }

            // Next operand to *compute* (multiply into the product) —
            // trails pcur/ecur by the one operand in flight.
            size_t pcomp = pcur;
            size_t ecomp = ecur;
            // Alternate stream, borrowed from the gpu, used solely for
            // uploads so copies overlap compute on `stream`.
            stream_t& alt = gpu;

            // Start copying the next data into dmem on the alternate stream
            if (pcur < pcount) {
                mul_copy_poly(hmem2, dmem2, polynomials[pcur], plens[pcur], alt, lg_domain_size);
                pcur++;
            } else {
                assert (ecur < ecount);
                mul_copy_eval(hmem2, dmem2, evaluations[ecur], elens[ecur], alt, lg_domain_size);
                ecur++;
            }

            // From here we overlap compute and copy using a double buffer:
            // *_comp is the buffer whose upload just finished (to be
            // consumed this iteration), *_copy receives the next upload.
            fr_t* hmem_comp = hmem2;
            fr_t* dmem_comp = dmem2;
            fr_t* hmem_copy = hmem1;
            fr_t* dmem_copy = dmem1;
            bool comp_on_mem2 = true;

            while (pcomp < pcount || ecomp < ecount) {
                // Sync the alternate stream to ensure data is on the GPU
                alt.sync();
                // And the compute stream to ensure the compute buffers can
                // be overwritten by the next upload.
                stream.sync();
                //printf("SYNC\n");

                // Start the computation
                fr_t* mul_src = nullptr;
                if (pcomp < pcount) {
                    //printf("  poly comp mem2 %d, pcomp %ld\n", comp_on_mem2, pcomp);
                    // Perform NTT on the input data (in place in dmem_comp).
                    NTT_internal(dmem_comp, lg_domain_size, NTT::InputOutputOrder::NR,
                                 NTT::Direction::forward, NTT::Type::standard, stream);
                    mul_src = dmem_comp;
                    pcomp++;
                } else {
                    //printf("  eval comp mem2 %d, pcomp %ld\n", comp_on_mem2, pcomp);
                    // Bit reversal into an aux buffer (dmem3 is free after
                    // the initial setup above).
                    bit_rev(dmem3, dmem_comp, lg_domain_size, stream);
                    mul_src = dmem3;

                    // Debug dump of the bit-reversed operand:
                    // stream.DtoH(hmem0, dmem3, domain_size);
                    // cudaDeviceSynchronize();
                    // printf("After bit_rev aux\n");
                    // for (size_t i = 0; i < domain_size; i++) {
                    //     printf("  %5ld: ", i);
                    //     host_print_fr(hmem0[i]);
                    // }

                    // In-place alternative (kept for reference):
                    // bit_rev(dmem_comp, dmem_comp, lg_domain_size, stream);
                    // mul_src = dmem_comp;
                    //
                    // stream.DtoH(hmem0, dmem_comp, domain_size);
                    // cudaDeviceSynchronize();
                    // printf("After bit_rev\n");
                    // for (size_t i = 0; i < domain_size; i++) {
                    //     printf("  %5ld: ", i);
                    //     host_print_fr(hmem0[i]);
                    // }

                    ecomp++;
                }
                //cudaDeviceSynchronize();

                // Inner multiply: dmem0 *= mul_src, element-wise.
                polynomial_inner_multiply<<<(domain_size + 1023) / 1024, 1024, 0, stream>>>
                    (domain_size, dmem0, dmem0, mul_src);

                //alt.sync();
                //stream.sync();
                //cudaDeviceSynchronize();

                // Start the next data copy concurrently (if there is another to do)
                if (pcur < pcount) {
                    //printf("  poly copy pcur %ld, plen %ld\n", pcur, plens[pcur]);
                    mul_copy_poly(hmem_copy, dmem_copy, polynomials[pcur], plens[pcur],
                                  alt, lg_domain_size);
                    pcur++;
                } else if (ecur < ecount) {
                    //printf("  eval copy ecur %ld, elen %ld\n", ecur, elens[ecur]);
                    mul_copy_eval(hmem_copy, dmem_copy, evaluations[ecur], elens[ecur],
                                  alt, lg_domain_size);
                    ecur++;
                }

                //cudaDeviceSynchronize();

                // Swap buffers
                std::swap(hmem_comp, hmem_copy);
                std::swap(dmem_comp, dmem_copy);
                comp_on_mem2 = !comp_on_mem2;
            }
            //cudaDeviceSynchronize();

            // alt.sync();
            // stream.sync();

            // Perform iNTT on the result (RN: bit-reversed in, natural
            // out), yielding the coefficient-form product.
            NTT_internal(dmem0, lg_domain_size,
                         NTT::InputOutputOrder::RN, NTT::Direction::inverse,
                         NTT::Type::standard, stream);

            // Copy the output data back to the host and wait for it.
            stream.DtoH(hmem0, dmem0, domain_size);
            stream.sync();
        } catch (const cuda_error& e) {
            gpu.sync();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
            return RustError{e.code(), e.what()};
#else
            return RustError{e.code()};
#endif
        }

        return RustError{cudaSuccess};
    }
};
#endif // __CUDA_ARCH__

#endif