/ src / examples / cxx / mdspan_test.cpp
mdspan_test.cpp
  1  // examples/cxx/mdspan_test.cpp
  2  //
  3  // mdspan verification using Kokkos reference implementation
  4  // (provides std::mdspan via <mdspan> header)
  5  //
  6  // verifies:
  7  //   - std::mdspan compiles and works
  8  //   - std::extents, std::dextents work
  9  //   - layout policies work
 10  //   - submdspan works
 11  
 12  #include <array>
 13  #include <cstdio>
 14  #include <mdspan>
 15  #include <numeric>
 16  #include <span>
 17  
 18  namespace straylight::examples {
 19  
 20  // ════════════════════════════════════════════════════════════════════════════════
 21  // basic mdspan test - 2d matrix view
 22  // ════════════════════════════════════════════════════════════════════════════════
 23  
 24  auto test_basic_mdspan() -> bool {
 25    // 3x4 matrix stored in row-major order
 26    std::array<float, 12> data{};
 27    std::iota(data.begin(), data.end(), 0.0f); // 0, 1, 2, ..., 11
 28  
 29    // create mdspan view
 30    std::mdspan matrix{data.data(), std::extents<std::size_t, 3, 4>{}};
 31  
 32    // verify dimensions
 33    if (matrix.extent(0) != 3 || matrix.extent(1) != 4) {
 34      std::printf("mdspan: extent mismatch\n");
 35      return false;
 36    }
 37  
 38    // verify element access - row 1, col 2 should be 1*4 + 2 = 6
 39    if (matrix[1, 2] != 6.0f) {
 40      std::printf("mdspan: element access failed, got %f expected 6.0\n",
 41                  static_cast<double>(matrix[1, 2]));
 42      return false;
 43    }
 44  
 45    // verify we can modify through the view
 46    matrix[2, 3] = 99.0f;
 47    if (data[11] != 99.0f) {
 48      std::printf("mdspan: modification failed\n");
 49      return false;
 50    }
 51  
 52    return true;
 53  }
 54  
 55  // ════════════════════════════════════════════════════════════════════════════════
 56  // dynamic extents test
 57  // ════════════════════════════════════════════════════════════════════════════════
 58  
 59  auto test_dynamic_extents() -> bool {
 60    std::array<int, 24> data{};
 61    std::iota(data.begin(), data.end(), 0);
 62  
 63    // 2x3x4 tensor with all dynamic extents
 64    std::mdspan tensor{data.data(), std::dextents<std::size_t, 3>{2, 3, 4}};
 65  
 66    if (tensor.extent(0) != 2 || tensor.extent(1) != 3 || tensor.extent(2) != 4) {
 67      std::printf("mdspan: dynamic extent mismatch\n");
 68      return false;
 69    }
 70  
 71    // element [1][2][3] = 1*12 + 2*4 + 3 = 23
 72    if (tensor[1, 2, 3] != 23) {
 73      std::printf("mdspan: 3d access failed, got %d expected 23\n", tensor[1, 2, 3]);
 74      return false;
 75    }
 76  
 77    return true;
 78  }
 79  
 80  // ════════════════════════════════════════════════════════════════════════════════
 81  // layout stride test - column major
 82  // ════════════════════════════════════════════════════════════════════════════════
 83  
 84  auto test_layout_stride() -> bool {
 85    std::array<float, 6> data{};
 86    std::iota(data.begin(), data.end(), 0.0f);
 87  
 88    // 2x3 matrix in column-major order
 89    using col_major = std::layout_left;
 90    std::mdspan<float, std::extents<std::size_t, 2, 3>, col_major> matrix{data.data()};
 91  
 92    // in column major, [1][0] should be element 1 (second element of first
 93    // column)
 94    if (matrix[1, 0] != 1.0f) {
 95      std::printf("mdspan: column major layout failed\n");
 96      return false;
 97    }
 98  
 99    // [0][1] should be element 2 (first element of second column)
100    if (matrix[0, 1] != 2.0f) {
101      std::printf("mdspan: column major [0,1] failed, got %f\n", static_cast<double>(matrix[0, 1]));
102      return false;
103    }
104  
105    return true;
106  }
107  
108  // ════════════════════════════════════════════════════════════════════════════════
109  // mixed static/dynamic extents
110  // ════════════════════════════════════════════════════════════════════════════════
111  
112  auto test_mixed_extents() -> bool {
113    std::array<double, 32> data{};
114    std::iota(data.begin(), data.end(), 0.0);
115  
116    // batch of 4x8 matrices where batch size is dynamic
117    // extents<size_t, dynamic_extent, 4, 8> means [?, 4, 8]
118    using batch_matrix_extents = std::extents<std::size_t, std::dynamic_extent, 4, 8>;
119    std::mdspan batch{data.data(), batch_matrix_extents{1}}; // 1 batch
120  
121    if (batch.extent(0) != 1 || batch.extent(1) != 4 || batch.extent(2) != 8) {
122      std::printf("mdspan: mixed extents failed\n");
123      return false;
124    }
125  
126    // [0][2][3] = 0*32 + 2*8 + 3 = 19
127    if (batch[0, 2, 3] != 19.0) {
128      std::printf("mdspan: mixed extent access failed\n");
129      return false;
130    }
131  
132    return true;
133  }
134  
135  // ════════════════════════════════════════════════════════════════════════════════
136  // submdspan test (c++26 feature, but gcc15 has it)
137  // ════════════════════════════════════════════════════════════════════════════════
138  
139  auto test_submdspan() -> bool {
140    std::array<int, 12> data{};
141    std::iota(data.begin(), data.end(), 0);
142  
143    // 3x4 matrix
144    std::mdspan matrix{data.data(), std::extents<std::size_t, 3, 4>{}};
145  
146    // extract row 1 as a 1d span
147    auto row1 = std::submdspan(matrix, 1, std::full_extent);
148  
149    if (row1.extent(0) != 4) {
150      std::printf("submdspan: row extent wrong\n");
151      return false;
152    }
153  
154    // row 1 starts at element 4
155    if (row1[0] != 4 || row1[3] != 7) {
156      std::printf("submdspan: row elements wrong\n");
157      return false;
158    }
159  
160    // extract column 2 as strided 1d span
161    auto col2 = std::submdspan(matrix, std::full_extent, 2);
162  
163    if (col2.extent(0) != 3) {
164      std::printf("submdspan: column extent wrong\n");
165      return false;
166    }
167  
168    // column 2 elements: 2, 6, 10
169    if (col2[0] != 2 || col2[1] != 6 || col2[2] != 10) {
170      std::printf("submdspan: column elements wrong: %d %d %d\n", col2[0], col2[1], col2[2]);
171      return false;
172    }
173  
174    return true;
175  }
176  
177  auto implementation() -> int {
178    int failures = 0;
179  
180    std::printf("mdspan tests (gcc15 libstdc++ c++23):\n");
181  
182    if (test_basic_mdspan()) {
183      std::printf("  basic_mdspan: pass\n");
184    } else {
185      std::printf("  basic_mdspan: FAIL\n");
186      failures++;
187    }
188  
189    if (test_dynamic_extents()) {
190      std::printf("  dynamic_extents: pass\n");
191    } else {
192      std::printf("  dynamic_extents: FAIL\n");
193      failures++;
194    }
195  
196    if (test_layout_stride()) {
197      std::printf("  layout_stride: pass\n");
198    } else {
199      std::printf("  layout_stride: FAIL\n");
200      failures++;
201    }
202  
203    if (test_mixed_extents()) {
204      std::printf("  mixed_extents: pass\n");
205    } else {
206      std::printf("  mixed_extents: FAIL\n");
207      failures++;
208    }
209  
210    if (test_submdspan()) {
211      std::printf("  submdspan: pass\n");
212    } else {
213      std::printf("  submdspan: FAIL\n");
214      failures++;
215    }
216  
217    if (failures == 0) {
218      std::printf("all mdspan tests passed\n");
219      return 0;
220    } else {
221      std::printf("%d mdspan tests FAILED\n", failures);
222      return 1;
223    }
224  }
225  
226  } // namespace straylight::examples
227  
228  auto main(int argc, char* argv[]) -> int {
229    return straylight::examples::implementation();
230  }