// examples/nv/mdspan_device_test.cpp
//
// std::mdspan on nvidia device - the holy grail
//
// verifies:
//   - mdspan works in device code
//   - C++23 multidimensional operator[] works on device
//   - matmul and reduction kernels produce correct results
//
// compiled with: clang++ -x cuda --cuda-path=... --cuda-gpu-arch=sm_90

#include <algorithm> // std::fill
#include <array>
#include <cstdio>
#include <cuda_runtime.h>
#include <numeric>

// use kokkos mdspan for device compatibility
// (std::mdspan not yet in cuda::std::)
#include <experimental/mdspan>

namespace stdex = std::experimental;

namespace straylight::examples {

// ════════════════════════════════════════════════════════════════════════════════
// device kernel using mdspan
// ════════════════════════════════════════════════════════════════════════════════

// matrix multiply kernel using mdspan for type-safe indexing
// A[M,K] * B[K,N] = C[M,N]
template <typename T>
__global__ void matmul_kernel(const T *__restrict__ a_data,
                              const T *__restrict__ b_data,
                              T *__restrict__ c_data, int M, int K, int N) {
  // create mdspan views inside kernel
  using matrix_t = stdex::mdspan<const T, stdex::dextents<int, 2>>;
  using out_matrix_t = stdex::mdspan<T, stdex::dextents<int, 2>>;

  matrix_t A{a_data, M, K};
  matrix_t B{b_data, K, N};
  out_matrix_t C{c_data, M, N};
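  // dextents = all extents chosen at runtime; the default layout_right
  // mapping is row-major, so A[i, j] reads a_data[i * K + j]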

  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;

  if (row < M && col < N) {
    T sum = 0;
    for (int k = 0; k < K; ++k) {
      // use operator[] for C++23 multidimensional subscript
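      // (pre-C++23 toolchains can't use multi-arg operator[]; the kokkos
      //  reference implementation offers an opt-in paren operator, A(row, k))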
      sum += A[row, k] * B[k, col];
    }
    C[row, col] = sum;
  }
}

// simple reduction kernel
template <typename T>
__global__ void reduce_sum_kernel(const T *data, T *result, int n) {
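  // note: assumes blockDim.x is a power of two and <= 256 (the shared array
  // size); the host launch below uses exactly 256 threads per block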
  __shared__ T shared_data[256];

  int tid = threadIdx.x;
  int idx = blockIdx.x * blockDim.x + threadIdx.x;

  shared_data[tid] = (idx < n) ? data[idx] : T{0};
  __syncthreads();

  // parallel reduction in shared memory
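  // each pass folds the upper half onto the lower half, e.g. with 256
  // threads: strides 128, 64, 32, ..., 1, leaving the block total in
  // shared_data[0]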
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (tid < stride) {
      shared_data[tid] += shared_data[tid + stride];
    }
    __syncthreads();
  }

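  // one atomicAdd per block accumulates the per-block partial sums into the
  // single *result value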
  if (tid == 0) {
    atomicAdd(result, shared_data[0]);
  }
}

// ════════════════════════════════════════════════════════════════════════════════
// host-side test runner
// ════════════════════════════════════════════════════════════════════════════════

auto check_cuda_error(cudaError_t error, const char *operation) -> bool {
  if (error != cudaSuccess) {
    std::printf("  %s failed: %s\n", operation, cudaGetErrorString(error));
    return false;
  }
  return true;
}

auto test_matmul() -> bool {
  constexpr int M = 4;
  constexpr int K = 3;
  constexpr int N = 2;

  // host data
  std::array<float, M * K> h_a{};
  std::array<float, K * N> h_b{};
  std::array<float, M * N> h_c{};

  // initialize A and B
  std::iota(h_a.begin(), h_a.end(), 1.0f); // 1, 2, 3, ...
  std::fill(h_b.begin(), h_b.end(), 1.0f); // all ones

  // device memory (free earlier allocations on any failure path)
  float *d_a = nullptr, *d_b = nullptr, *d_c = nullptr;

  if (!check_cuda_error(cudaMalloc(&d_a, sizeof(h_a)), "malloc A"))
    return false;
  if (!check_cuda_error(cudaMalloc(&d_b, sizeof(h_b)), "malloc B")) {
    cudaFree(d_a);
    return false;
  }
  if (!check_cuda_error(cudaMalloc(&d_c, sizeof(h_c)), "malloc C")) {
    cudaFree(d_a);
    cudaFree(d_b);
    return false;
  }

  // copy to device
  cudaMemcpy(d_a, h_a.data(), sizeof(h_a), cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), sizeof(h_b), cudaMemcpyHostToDevice);

  // launch kernel
  dim3 block(16, 16);
  dim3 grid((N + 15) / 16, (M + 15) / 16);
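  // ceil-division: even the 4x2 output here gets one full 16x16 block; the
  // bounds check inside the kernel masks off out-of-range threads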
  matmul_kernel<float><<<grid, block>>>(d_a, d_b, d_c, M, K, N);

  if (!check_cuda_error(cudaGetLastError(), "matmul kernel")) {
    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_c);
    return false;
  }

  cudaDeviceSynchronize();

  // copy back
  cudaMemcpy(h_c.data(), d_c, sizeof(h_c), cudaMemcpyDeviceToHost);

  // verify: each row of C should be the sum of that row of A (B is all ones)
  // row 0: 1+2+3 = 6, row 1: 4+5+6 = 15,
  // row 2: 7+8+9 = 24, row 3: 10+11+12 = 33
  std::array<float, M> expected_row_sums{6.0f, 15.0f, 24.0f, 33.0f};

  bool passed = true;
  for (int i = 0; i < M && passed; ++i) {
    for (int j = 0; j < N && passed; ++j) {
      if (h_c[i * N + j] != expected_row_sums[i]) {
        std::printf("  matmul C[%d,%d] = %f, expected %f\n", i, j,
                    h_c[i * N + j], expected_row_sums[i]);
        passed = false;
      }
    }
  }

  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);

  return passed;
}

auto test_reduction() -> bool {
  constexpr int N = 1000;

  std::array<float, N> h_data{};
  std::iota(h_data.begin(), h_data.end(), 1.0f); // 1, 2, 3, ..., 1000

  // expected sum: n*(n+1)/2 = 1000*1001/2 = 500500
  constexpr float expected_sum = 500500.0f;

  float *d_data = nullptr, *d_result = nullptr;
  float h_result = 0.0f;

  if (!check_cuda_error(cudaMalloc(&d_data, sizeof(h_data)), "malloc data"))
    return false;
  if (!check_cuda_error(cudaMalloc(&d_result, sizeof(float)),
                        "malloc result")) {
    cudaFree(d_data);
    return false;
  }

  cudaMemcpy(d_data, h_data.data(), sizeof(h_data), cudaMemcpyHostToDevice);
  cudaMemset(d_result, 0, sizeof(float));

  // launch reduction
  int threads = 256;
  int blocks = (N + threads - 1) / threads;
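  // ceil(1000 / 256) = 4 blocks; the tail block loads T{0} for idx >= n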
  reduce_sum_kernel<float><<<blocks, threads>>>(d_data, d_result, N);

  if (!check_cuda_error(cudaGetLastError(), "reduce kernel")) {
    cudaFree(d_data);
    cudaFree(d_result);
    return false;
  }

  cudaDeviceSynchronize();
  cudaMemcpy(&h_result, d_result, sizeof(float), cudaMemcpyDeviceToHost);

  cudaFree(d_data);
  cudaFree(d_result);

  // allow small floating point error
  float diff = h_result - expected_sum;
  if (diff < 0)
    diff = -diff;

  if (diff > 1.0f) {
    std::printf("  reduction got %f, expected %f\n", h_result, expected_sum);
    return false;
  }

  return true;
}

auto main_impl() -> int {
  // check for devices
  int device_count = 0;
  cudaError_t error = cudaGetDeviceCount(&device_count);

  if (error != cudaSuccess || device_count == 0) {
    std::printf("nv mdspan tests: no devices available\n");
    std::printf(
        "compilation succeeded - mdspan device code compiled correctly\n");
    return 0; // success - testing toolchain, not hardware
  }

  // get device info
  cudaDeviceProp props;
  cudaGetDeviceProperties(&props, 0);
  std::printf("nv mdspan tests on: %s (sm_%d%d)\n", props.name, props.major,
              props.minor);

  int failures = 0;

  if (test_matmul()) {
    std::printf("  matmul_mdspan: pass\n");
  } else {
    std::printf("  matmul_mdspan: FAIL\n");
    failures++;
  }

  if (test_reduction()) {
    std::printf("  reduction: pass\n");
  } else {
    std::printf("  reduction: FAIL\n");
    failures++;
  }

  if (failures == 0) {
    std::printf("all nv mdspan tests passed\n");
    return 0;
  } else {
    std::printf("%d nv mdspan tests FAILED\n", failures);
    return 1;
  }
}

} // namespace straylight::examples

auto main() -> int { return straylight::examples::main_impl(); }