// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef __POLYNOMIAL_CUH__
#define __POLYNOMIAL_CUH__

#ifndef __CUDA_ARCH__

#include <util/exception.cuh>
#include <util/rusterror.h>
#include <util/gpu_t.cuh>

#endif

#ifndef __CUDA_ARCH__
// Debug helper: print a field element as one 256-bit hex value. Marked
// static so that including this header in multiple translation units does
// not produce duplicate definitions.
static void host_print_fr(fr_t f) {
    uint64_t val[4];
    f.from();               // convert out of Montgomery form
    f.store((limb_t*)val);
    printf("0x%016lx%016lx%016lx%016lx\n", val[3], val[2], val[1], val[0]);
}
#endif

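
#ifndef __CUDA_ARCH__
// Hedged debug sketch: dump a host-side buffer of field elements, mirroring
// the ad-hoc debug loops used while developing this file. dump_fr_buffer is
// illustrative only and not part of the library API.
static void dump_fr_buffer(const char* tag, fr_t* hbuf, size_t n) {
    printf("%s\n", tag);
    for (size_t i = 0; i < n; i++) {
        printf("  %5zu: ", i);
        host_print_fr(hbuf[i]);
    }
}
#endif
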
// Pointwise (Hadamard) product: out[i] = in0[i] * in1[i] over the domain
__global__
void polynomial_inner_multiply(size_t domain_size, fr_t* out, fr_t* in0, fr_t* in1) {
    index_t idx = threadIdx.x + blockDim.x * (index_t)blockIdx.x;
    if (idx >= domain_size)
        return;

    fr_t x = in0[idx];
    fr_t y = in1[idx];
    out[idx] = x * y;
}
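
#ifndef __CUDA_ARCH__
// Hedged sketch: how the kernel above is launched over a full domain. The
// 1024-thread block size matches the launch inside Polynomial::Mul below;
// launch_inner_multiply itself is illustrative and not part of the API.
static inline void launch_inner_multiply(stream_t& stream, size_t domain_size,
                                         fr_t* d_out, fr_t* d_in0, fr_t* d_in1) {
    polynomial_inner_multiply<<<(domain_size + 1023) / 1024, 1024, 0, stream>>>
        (domain_size, d_out, d_in0, d_in1);
}
#endif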

#ifndef __CUDA_ARCH__

class Polynomial : public NTT {
protected:
    // Polynomial multiplication uses the convolution theorem:
    //
    //     c = iNTT( NTT(a) * NTT(b) )        (pointwise product)
    //
    // Each input is forward-NTT'd into the evaluation domain and multiplied
    // pointwise into a running product with polynomial_inner_multiply; the
    // product is inverse-NTT'd back to coefficient form at the end. Mul()
    // below implements this pipeline, double-buffering the host-to-device
    // copies against the compute stream.

    // Stage a coefficient-form polynomial of length len for the GPU: copy it
    // into the pinned host buffer hmem, upload to dmem asynchronously, and
    // zero-pad the rest of the 2^lg_domain_size domain.
    static void mul_copy_poly(fr_t* hmem, fr_t* dmem, fr_t* poly, size_t len,
                              stream_t& stream, uint32_t lg_domain_size) {
        size_t domain_size = (size_t)1 << lg_domain_size;
        // Host-to-host copy into pinned memory so the upload can be async
        cudaMemcpyAsync(hmem, poly, sizeof(fr_t) * len, cudaMemcpyHostToHost, stream);
        stream.HtoD(dmem, hmem, len);
        cudaMemsetAsync(&dmem[len], 0, sizeof(fr_t) * (domain_size - len), stream);
    }
    // Stage an evaluation-form input for the GPU; evaluations already span
    // the full domain, so no zero-padding is needed.
    static void mul_copy_eval(fr_t* hmem, fr_t* dmem,
                              fr_t* eval, size_t len,
                              stream_t& stream, uint32_t lg_domain_size) {
        size_t domain_size = (size_t)1 << lg_domain_size;
        assert(len == domain_size);
        cudaMemcpyAsync(hmem, eval, sizeof(fr_t) * len, cudaMemcpyHostToHost, stream);
        stream.HtoD(dmem, hmem, len);
    }
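
    // Hedged sketch of the buffers the two helpers above expect: page-locked
    // host staging memory plus device memory, each sized to the full domain.
    // Allocation is the caller's job; the names below are illustrative only
    // (CUDA_OK comes from util/exception.cuh):
    //
    //     size_t domain_size = (size_t)1 << lg_domain_size;
    //     fr_t *hmem, *dmem;
    //     CUDA_OK(cudaHostAlloc((void**)&hmem, sizeof(fr_t) * domain_size,
    //                           cudaHostAllocDefault));
    //     CUDA_OK(cudaMalloc((void**)&dmem, sizeof(fr_t) * domain_size));
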
public:
    // Multiply together pcount coefficient-form polynomials and ecount
    // evaluation-form inputs (at least two in total). The product is written
    // back to hmem0 in coefficient form, spanning the full 2^lg_domain_size
    // domain, so the combined degree must fit within that domain.
    static RustError Mul(const gpu_t& gpu, stream_t& stream,
                         fr_t* hmem0, fr_t* hmem1, fr_t* hmem2,
                         fr_t* dmem0, fr_t* dmem1, fr_t* dmem2, fr_t* dmem3,
                         size_t pcount, fr_t** polynomials, size_t* plens,
                         size_t ecount, fr_t** evaluations, size_t* elens,
                         uint32_t lg_domain_size) {
        try {
            gpu.select();

            size_t domain_size = (size_t)1 << lg_domain_size;

            size_t pcur = 0;
            size_t ecur = 0;

            // Set up the first polynomial / evaluation in dmem0
            if (pcount > 0) {
                mul_copy_poly(hmem0, dmem0, polynomials[0], plens[0], stream, lg_domain_size);
                // Forward NTT into the evaluation domain (bit-reversed order)
                NTT_internal(dmem0, lg_domain_size, NTT::InputOutputOrder::NR,
                             NTT::Direction::forward, NTT::Type::standard, stream);
                pcur++;
            } else {
                // No polynomials, so there must be at least one evaluation
                assert(ecount > 0);
                mul_copy_eval(hmem0, dmem3, evaluations[0], elens[0], stream, lg_domain_size);
                // Bit-reverse to match the NR ordering of the NTT outputs
                bit_rev(dmem0, dmem3, lg_domain_size, stream);
                ecur++;
            }

            // Inputs consumed so far by the compute stream
            size_t pcomp = pcur;
            size_t ecomp = ecur;
            // Alternate stream (the GPU's default stream) for the data copies
            stream_t& alt = gpu;

            // Prefetch the next input into dmem2 on the alternate stream
            if (pcur < pcount) {
                mul_copy_poly(hmem2, dmem2, polynomials[pcur], plens[pcur], alt, lg_domain_size);
                pcur++;
            } else {
                assert(ecur < ecount);
                mul_copy_eval(hmem2, dmem2, evaluations[ecur], elens[ecur], alt, lg_domain_size);
                ecur++;
            }

            // From here on, compute and copy overlap through a double buffer:
            // the compute stream consumes the *_comp buffers while the
            // alternate stream fills *_copy, and the two swap each iteration.
            fr_t* hmem_comp = hmem2;
            fr_t* dmem_comp = dmem2;
            fr_t* hmem_copy = hmem1;
            fr_t* dmem_copy = dmem1;

            while (pcomp < pcount || ecomp < ecount) {
                // Sync the alternate stream so the staged input is on the GPU...
                alt.sync();
                // ...and the compute stream so its buffers can be overwritten
                stream.sync();

                // Transform the staged input for the pointwise multiply
                fr_t* mul_src = nullptr;
                if (pcomp < pcount) {
                    // Polynomial: forward NTT into the evaluation domain
                    NTT_internal(dmem_comp, lg_domain_size, NTT::InputOutputOrder::NR,
                                 NTT::Direction::forward, NTT::Type::standard, stream);
                    mul_src = dmem_comp;
                    pcomp++;
                } else {
                    // Evaluation: bit-reverse into the aux buffer dmem3 to
                    // match the NR ordering of the NTT outputs
                    bit_rev(dmem3, dmem_comp, lg_domain_size, stream);
                    mul_src = dmem3;
                    ecomp++;
                }

                // Multiply the running product (dmem0) by the new evaluations
                polynomial_inner_multiply<<<(domain_size + 1023) / 1024, 1024, 0, stream>>>
                    (domain_size, dmem0, dmem0, mul_src);

                // Start the next data copy concurrently (if any inputs remain)
                if (pcur < pcount) {
                    mul_copy_poly(hmem_copy, dmem_copy, polynomials[pcur], plens[pcur],
                                  alt, lg_domain_size);
                    pcur++;
                } else if (ecur < ecount) {
                    mul_copy_eval(hmem_copy, dmem_copy, evaluations[ecur], elens[ecur],
                                  alt, lg_domain_size);
                    ecur++;
                }

                // Swap the double buffers
                std::swap(hmem_comp, hmem_copy);
                std::swap(dmem_comp, dmem_copy);
            }

            // Inverse NTT brings the accumulated product back to coefficient form
            NTT_internal(dmem0, lg_domain_size,
                         NTT::InputOutputOrder::RN, NTT::Direction::inverse,
                         NTT::Type::standard, stream);

            // Copy the result back to the host
            stream.DtoH(hmem0, dmem0, domain_size);
            stream.sync();
        } catch (const cuda_error& e) {
            gpu.sync();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
            return RustError{e.code(), e.what()};
#else
            return RustError{e.code()};
#endif
        }

        return RustError{cudaSuccess};
    }
};
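
// Hedged usage sketch: multiplying two coefficient-form polynomials with
// Polynomial::Mul. The buffer setup, the select_gpu() helper, and reusing
// the GPU's default stream (which forfeits the copy/compute overlap) are
// assumptions for illustration; the real callers are the Rust FFI bindings,
// which supply pre-allocated pinned host and device buffers.
inline RustError example_mul_two_polys(fr_t* a, size_t alen,
                                       fr_t* b, size_t blen,
                                       fr_t* out, uint32_t lg_domain_size)
{
    const gpu_t& gpu = select_gpu();
    size_t domain_size = (size_t)1 << lg_domain_size;

    // One full domain per buffer: three pinned host buffers, four device ones
    fr_t *hmem[3], *dmem[4];
    for (auto*& p : hmem)
        CUDA_OK(cudaHostAlloc((void**)&p, sizeof(fr_t) * domain_size,
                              cudaHostAllocDefault));
    for (auto*& p : dmem)
        CUDA_OK(cudaMalloc((void**)&p, sizeof(fr_t) * domain_size));

    fr_t* polys[2] = {a, b};
    size_t plens[2] = {alen, blen};
    stream_t& stream = gpu;   // the GPU's default stream

    RustError err = Polynomial::Mul(gpu, stream,
                                    hmem[0], hmem[1], hmem[2],
                                    dmem[0], dmem[1], dmem[2], dmem[3],
                                    2, polys, plens,
                                    0, nullptr, nullptr,
                                    lg_domain_size);

    // On success the product, in coefficient form, is left in hmem[0]
    if (err.code == cudaSuccess)
        for (size_t i = 0; i < domain_size; i++)
            out[i] = hmem[0][i];

    for (auto* p : hmem) cudaFreeHost(p);
    for (auto* p : dmem) cudaFree(p);
    return err;
}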
#endif // __CUDA_ARCH__

#endif // __POLYNOMIAL_CUH__