/ src / examples / cxx / mdspan_test.cpp
mdspan_test.cpp
  1  // examples/cxx/mdspan_test.cpp
  2  //
  3  // c++23 std::mdspan verification - the whole point of gcc15
  4  //
  5  // verifies:
  6  //   - std::mdspan compiles and works
  7  //   - std::extents, std::dextents work
  8  //   - layout policies work
  9  //   - submdspan works (c++26 but gcc15 has it)
 10  
 11  #include <array>
 12  #include <cstdio>
 13  #include <mdspan>
 14  #include <numeric>
 15  #include <span>
 16  
 17  namespace straylight::examples {
 18  
 19  // ════════════════════════════════════════════════════════════════════════════════
 20  // basic mdspan test - 2d matrix view
 21  // ════════════════════════════════════════════════════════════════════════════════
 22  
 23  auto test_basic_mdspan() -> bool {
 24    // 3x4 matrix stored in row-major order
 25    std::array<float, 12> data{};
 26    std::iota(data.begin(), data.end(), 0.0f); // 0, 1, 2, ..., 11
 27  
 28    // create mdspan view
 29    std::mdspan matrix{data.data(), std::extents<std::size_t, 3, 4>{}};
 30  
 31    // verify dimensions
 32    if (matrix.extent(0) != 3 || matrix.extent(1) != 4) {
 33      std::printf("mdspan: extent mismatch\n");
 34      return false;
 35    }
 36  
 37    // verify element access - row 1, col 2 should be 1*4 + 2 = 6
 38    if (matrix[1, 2] != 6.0f) {
 39      std::printf("mdspan: element access failed, got %f expected 6.0\n",
 40                  static_cast<double>(matrix[1, 2]));
 41      return false;
 42    }
 43  
 44    // verify we can modify through the view
 45    matrix[2, 3] = 99.0f;
 46    if (data[11] != 99.0f) {
 47      std::printf("mdspan: modification failed\n");
 48      return false;
 49    }
 50  
 51    return true;
 52  }
 53  
 54  // ════════════════════════════════════════════════════════════════════════════════
 55  // dynamic extents test
 56  // ════════════════════════════════════════════════════════════════════════════════
 57  
 58  auto test_dynamic_extents() -> bool {
 59    std::array<int, 24> data{};
 60    std::iota(data.begin(), data.end(), 0);
 61  
 62    // 2x3x4 tensor with all dynamic extents
 63    std::mdspan tensor{data.data(), std::dextents<std::size_t, 3>{2, 3, 4}};
 64  
 65    if (tensor.extent(0) != 2 || tensor.extent(1) != 3 || tensor.extent(2) != 4) {
 66      std::printf("mdspan: dynamic extent mismatch\n");
 67      return false;
 68    }
 69  
 70    // element [1][2][3] = 1*12 + 2*4 + 3 = 23
 71    if (tensor[1, 2, 3] != 23) {
 72      std::printf("mdspan: 3d access failed, got %d expected 23\n",
 73                  tensor[1, 2, 3]);
 74      return false;
 75    }
 76  
 77    return true;
 78  }
 79  
 80  // ════════════════════════════════════════════════════════════════════════════════
 81  // layout stride test - column major
 82  // ════════════════════════════════════════════════════════════════════════════════
 83  
 84  auto test_layout_stride() -> bool {
 85    std::array<float, 6> data{};
 86    std::iota(data.begin(), data.end(), 0.0f);
 87  
 88    // 2x3 matrix in column-major order
 89    using col_major = std::layout_left;
 90    std::mdspan<float, std::extents<std::size_t, 2, 3>, col_major> matrix{
 91        data.data()};
 92  
 93    // in column major, [1][0] should be element 1 (second element of first
 94    // column)
 95    if (matrix[1, 0] != 1.0f) {
 96      std::printf("mdspan: column major layout failed\n");
 97      return false;
 98    }
 99  
100    // [0][1] should be element 2 (first element of second column)
101    if (matrix[0, 1] != 2.0f) {
102      std::printf("mdspan: column major [0,1] failed, got %f\n",
103                  static_cast<double>(matrix[0, 1]));
104      return false;
105    }
106  
107    return true;
108  }
109  
110  // ════════════════════════════════════════════════════════════════════════════════
111  // mixed static/dynamic extents
112  // ════════════════════════════════════════════════════════════════════════════════
113  
// Checks a rank-3 view that mixes one run-time extent (the batch count) with
// two compile-time extents (4 and 8), over 32 doubles holding 0..31.
// Prints a diagnostic and returns false on the first failed check.
auto test_mixed_extents() -> bool {
  std::array<double, 32> storage{};
  std::iota(storage.begin(), storage.end(), 0.0);

  // shape is [?, 4, 8]; only the leading batch dimension is dynamic
  using batch_matrix_extents =
      std::extents<std::size_t, std::dynamic_extent, 4, 8>;
  const batch_matrix_extents shape{1}; // a single batch fills all 32 slots
  std::mdspan stack{storage.data(), shape};

  const std::array<std::size_t, 3> expected{1, 4, 8};
  for (std::size_t dim = 0; dim < expected.size(); ++dim) {
    if (stack.extent(dim) != expected[dim]) {
      std::printf("mdspan: mixed extents failed\n");
      return false;
    }
  }

  // [0][2][3] maps to flat offset 0*32 + 2*8 + 3 = 19
  const bool access_ok = stack[0, 2, 3] == 19.0;
  if (!access_ok) {
    std::printf("mdspan: mixed extent access failed\n");
  }
  return access_ok;
}
137  
138  // ════════════════════════════════════════════════════════════════════════════════
139  // submdspan test (c++26 feature, but gcc15 has it)
140  // ════════════════════════════════════════════════════════════════════════════════
141  
// Checks std::submdspan slicing of a 3x4 row-major view over 12 ints holding
// 0..11: one full row (index 1) and one full column (index 2), each reduced
// to a rank-1 view. Prints a diagnostic and returns false on the first
// failed check.
auto test_submdspan() -> bool {
  std::array<int, 12> cells{};
  std::iota(cells.begin(), cells.end(), 0);

  using shape_3x4 = std::extents<std::size_t, 3, 4>;
  std::mdspan grid{cells.data(), shape_3x4{}};

  // --- row slice: fix the first index at 1, keep the whole second axis ---
  auto second_row = std::submdspan(grid, 1, std::full_extent);

  if (second_row.extent(0) != 4) {
    std::printf("submdspan: row extent wrong\n");
    return false;
  }

  // row 1 covers flat offsets 4..7
  const bool row_ok = second_row[0] == 4 && second_row[3] == 7;
  if (!row_ok) {
    std::printf("submdspan: row elements wrong\n");
    return false;
  }

  // --- column slice: keep the whole first axis, fix the second index at 2 ---
  auto third_col = std::submdspan(grid, std::full_extent, 2);

  if (third_col.extent(0) != 3) {
    std::printf("submdspan: column extent wrong\n");
    return false;
  }

  // column 2 visits flat offsets 2, 6, 10
  const int top = third_col[0];
  const int mid = third_col[1];
  const int bottom = third_col[2];
  if (top != 2 || mid != 6 || bottom != 10) {
    std::printf("submdspan: column elements wrong: %d %d %d\n", top, mid,
                bottom);
    return false;
  }

  return true;
}
180  
181  auto implementation() -> int {
182    int failures = 0;
183  
184    std::printf("mdspan tests (gcc15 libstdc++ c++23):\n");
185  
186    if (test_basic_mdspan()) {
187      std::printf("  basic_mdspan: pass\n");
188    } else {
189      std::printf("  basic_mdspan: FAIL\n");
190      failures++;
191    }
192  
193    if (test_dynamic_extents()) {
194      std::printf("  dynamic_extents: pass\n");
195    } else {
196      std::printf("  dynamic_extents: FAIL\n");
197      failures++;
198    }
199  
200    if (test_layout_stride()) {
201      std::printf("  layout_stride: pass\n");
202    } else {
203      std::printf("  layout_stride: FAIL\n");
204      failures++;
205    }
206  
207    if (test_mixed_extents()) {
208      std::printf("  mixed_extents: pass\n");
209    } else {
210      std::printf("  mixed_extents: FAIL\n");
211      failures++;
212    }
213  
214    if (test_submdspan()) {
215      std::printf("  submdspan: pass\n");
216    } else {
217      std::printf("  submdspan: FAIL\n");
218      failures++;
219    }
220  
221    if (failures == 0) {
222      std::printf("all mdspan tests passed\n");
223      return 0;
224    } else {
225      std::printf("%d mdspan tests FAILED\n", failures);
226      return 1;
227    }
228  }
229  
230  } // namespace straylight::examples
231  
232  auto main(int argc, char *argv[]) -> int {
233    return straylight::examples::implementation();
234  }