mdspan_test.cpp
// examples/cxx/mdspan_test.cpp
//
// c++23 std::mdspan verification - the whole point of gcc15
//
// verifies:
//   - std::mdspan compiles and works
//   - std::extents, std::dextents work
//   - layout policies work
//   - submdspan works (c++26 but gcc15 has it)

#include <array>
#include <cstddef>
#include <cstdio>
#include <mdspan>
#include <numeric>
#include <span>

namespace straylight::examples {

// ════════════════════════════════════════════════════════════════════════════════
// basic mdspan test - 2d matrix view
// ════════════════════════════════════════════════════════════════════════════════

// views a 12-element buffer as a static 3x4 row-major matrix and checks
// extents, element reads, and writes through the view.
// returns true when every check passes; prints a diagnostic otherwise.
auto test_basic_mdspan() -> bool {
  // 3x4 matrix stored in row-major order
  std::array<float, 12> data{};
  std::iota(data.begin(), data.end(), 0.0f); // 0, 1, 2, ..., 11

  // create mdspan view
  std::mdspan matrix{data.data(), std::extents<std::size_t, 3, 4>{}};

  // verify dimensions
  if (matrix.extent(0) != 3 || matrix.extent(1) != 4) {
    std::printf("mdspan: extent mismatch\n");
    return false;
  }

  // verify element access - row 1, col 2 should be 1*4 + 2 = 6
  if (matrix[1, 2] != 6.0f) {
    std::printf("mdspan: element access failed, got %f expected 6.0\n",
                static_cast<double>(matrix[1, 2]));
    return false;
  }

  // verify we can modify through the view: [2, 3] is the last element (11)
  matrix[2, 3] = 99.0f;
  if (data[11] != 99.0f) {
    std::printf("mdspan: modification failed\n");
    return false;
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// dynamic extents test
// ════════════════════════════════════════════════════════════════════════════════

// views a 24-element buffer as a 2x3x4 tensor whose extents are all supplied
// at runtime, then checks the extents and one 3d element lookup.
// returns true when every check passes; prints a diagnostic otherwise.
auto test_dynamic_extents() -> bool {
  std::array<int, 24> data{};
  std::iota(data.begin(), data.end(), 0);

  // 2x3x4 tensor with all dynamic extents
  std::mdspan tensor{data.data(), std::dextents<std::size_t, 3>{2, 3, 4}};

  if (tensor.extent(0) != 2 || tensor.extent(1) != 3 ||
      tensor.extent(2) != 4) {
    std::printf("mdspan: dynamic extent mismatch\n");
    return false;
  }

  // element [1][2][3] = 1*12 + 2*4 + 3 = 23
  if (tensor[1, 2, 3] != 23) {
    std::printf("mdspan: 3d access failed, got %d expected 23\n",
                tensor[1, 2, 3]);
    return false;
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// layout stride test - column major
// ════════════════════════════════════════════════════════════════════════════════

// checks column-major addressing two ways: through std::layout_left, and
// through an explicit std::layout_stride mapping with the same strides.
// returns true when every check passes; prints a diagnostic otherwise.
auto test_layout_stride() -> bool {
  std::array<float, 6> data{};
  std::iota(data.begin(), data.end(), 0.0f);

  // 2x3 matrix in column-major order
  using col_major = std::layout_left;
  using extents_2x3 = std::extents<std::size_t, 2, 3>;
  std::mdspan<float, extents_2x3, col_major> matrix{data.data()};

  // in column major, [1][0] should be element 1 (second element of first
  // column)
  if (matrix[1, 0] != 1.0f) {
    std::printf("mdspan: column major layout failed\n");
    return false;
  }

  // [0][1] should be element 2 (first element of second column)
  if (matrix[0, 1] != 2.0f) {
    std::printf("mdspan: column major [0,1] failed, got %f\n",
                static_cast<double>(matrix[0, 1]));
    return false;
  }

  // the same walk spelled out with std::layout_stride: stepping a row moves
  // 1 element, stepping a column moves 2 (the height of one column)
  const std::layout_stride::mapping<extents_2x3> strided_map{
      extents_2x3{}, std::array<std::size_t, 2>{1, 2}};
  std::mdspan<float, extents_2x3, std::layout_stride> strided{data.data(),
                                                              strided_map};

  if (strided.stride(0) != 1 || strided.stride(1) != 2) {
    std::printf("mdspan: layout_stride strides wrong\n");
    return false;
  }

  // both views address the same buffer, so every index must agree exactly
  for (std::size_t row = 0; row < strided.extent(0); ++row) {
    for (std::size_t col = 0; col < strided.extent(1); ++col) {
      if (strided[row, col] != matrix[row, col]) {
        std::printf("mdspan: layout_stride mismatch at [%zu,%zu]\n", row, col);
        return false;
      }
    }
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// mixed static/dynamic extents
// ════════════════════════════════════════════════════════════════════════════════

// views a 32-element buffer as [batch, 4, 8] where only the batch extent is
// dynamic, then checks the extents and one element lookup.
// returns true when every check passes; prints a diagnostic otherwise.
auto test_mixed_extents() -> bool {
  std::array<double, 32> data{};
  std::iota(data.begin(), data.end(), 0.0);

  // batch of 4x8 matrices where batch size is dynamic
  // extents<size_t, dynamic_extent, 4, 8> means [?, 4, 8]
  using batch_matrix_extents =
      std::extents<std::size_t, std::dynamic_extent, 4, 8>;
  std::mdspan batch{data.data(), batch_matrix_extents{1}}; // 1 batch

  if (batch.extent(0) != 1 || batch.extent(1) != 4 || batch.extent(2) != 8) {
    std::printf("mdspan: mixed extents failed\n");
    return false;
  }

  // [0][2][3] = 0*32 + 2*8 + 3 = 19
  if (batch[0, 2, 3] != 19.0) {
    std::printf("mdspan: mixed extent access failed\n");
    return false;
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// submdspan test (c++26 feature, but gcc15 has it)
// ════════════════════════════════════════════════════════════════════════════════

// slices one row and one column out of a 3x4 row-major matrix with
// std::submdspan and checks the extent and contents of each slice.
// returns true when every check passes; prints a diagnostic otherwise.
auto test_submdspan() -> bool {
  std::array<int, 12> data{};
  std::iota(data.begin(), data.end(), 0);

  // 3x4 matrix
  std::mdspan matrix{data.data(), std::extents<std::size_t, 3, 4>{}};

  // extract row 1 as a 1d span
  auto row1 = std::submdspan(matrix, 1, std::full_extent);

  if (row1.extent(0) != 4) {
    std::printf("submdspan: row extent wrong\n");
    return false;
  }

  // row 1 starts at element 4
  if (row1[0] != 4 || row1[3] != 7) {
    std::printf("submdspan: row elements wrong\n");
    return false;
  }

  // extract column 2 as strided 1d span
  auto col2 = std::submdspan(matrix, std::full_extent, 2);

  if (col2.extent(0) != 3) {
    std::printf("submdspan: column extent wrong\n");
    return false;
  }

  // column 2 elements: 2, 6, 10
  if (col2[0] != 2 || col2[1] != 6 || col2[2] != 10) {
    std::printf("submdspan: column elements wrong: %d %d %d\n", col2[0],
                col2[1], col2[2]);
    return false;
  }

  return true;
}

// runs every mdspan test in order, printing one pass/FAIL line per test and a
// final summary. returns 0 when all tests pass and 1 otherwise, so the value
// can be used directly as the process exit code.
auto implementation() -> int {
  // label printed for the test, paired with the function that runs it
  struct test_case {
    const char *label;
    bool (*run)();
  };

  // table-driven so adding a test is a one-line change
  constexpr std::array cases{
      test_case{"basic_mdspan", &test_basic_mdspan},
      test_case{"dynamic_extents", &test_dynamic_extents},
      test_case{"layout_stride", &test_layout_stride},
      test_case{"mixed_extents", &test_mixed_extents},
      test_case{"submdspan", &test_submdspan},
  };

  std::printf("mdspan tests (gcc15 libstdc++ c++23):\n");

  int failures = 0;
  for (const auto &[label, run] : cases) {
    const bool passed = run();
    std::printf(" %s: %s\n", label, passed ? "pass" : "FAIL");
    if (!passed) {
      ++failures;
    }
  }

  if (failures == 0) {
    std::printf("all mdspan tests passed\n");
    return 0;
  }

  std::printf("%d mdspan tests FAILED\n", failures);
  return 1;
}

} // namespace straylight::examples

// no command-line arguments are consumed, so main takes none
auto main() -> int { return straylight::examples::implementation(); }