// mdspan_test.cpp
// examples/cxx/mdspan_test.cpp
//
// mdspan verification program: exercises std::mdspan via the <mdspan> header
// and prints one pass/FAIL line per test, returning non-zero on any failure.
//
// NOTE(review): this file has been described both as using the Kokkos mdspan
// reference implementation and (in the runtime banner) as gcc15 libstdc++
// c++23 — confirm which one supplies <mdspan> and std::submdspan in this build.
//
// verifies:
//   - std::mdspan compiles and works
//   - std::extents, std::dextents work
//   - layout policies work (layout_left and an explicit layout_stride)
//   - submdspan works

#include <array>
#include <cstdio>
#include <mdspan>
#include <numeric>
#include <span>

namespace straylight::examples {

// ════════════════════════════════════════════════════════════════════════════════
// basic mdspan test - 2d matrix view
// ════════════════════════════════════════════════════════════════════════════════

/// Views a 12-element buffer as a static 3x4 row-major matrix and checks the
/// extents, element access, and that writes go through to the buffer.
/// @return true on success; prints a diagnostic and returns false otherwise.
auto test_basic_mdspan() -> bool {
  // 3x4 matrix stored in row-major order (mdspan's default layout_right)
  std::array<float, 12> data{};
  std::iota(data.begin(), data.end(), 0.0f);  // 0, 1, 2, ..., 11

  // create mdspan view
  std::mdspan matrix{data.data(), std::extents<std::size_t, 3, 4>{}};

  // verify dimensions
  if (matrix.extent(0) != 3 || matrix.extent(1) != 4) {
    std::printf("mdspan: extent mismatch\n");
    return false;
  }

  // verify element access - row 1, col 2 should be 1*4 + 2 = 6
  if (matrix[1, 2] != 6.0f) {
    std::printf("mdspan: element access failed, got %f expected 6.0\n",
                static_cast<double>(matrix[1, 2]));
    return false;
  }

  // verify we can modify through the view: [2, 3] maps to 2*4 + 3 = 11
  matrix[2, 3] = 99.0f;
  if (data[11] != 99.0f) {
    std::printf("mdspan: modification failed\n");
    return false;
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// dynamic extents test
// ════════════════════════════════════════════════════════════════════════════════

/// Views a 24-element buffer as a 2x3x4 tensor whose extents are all supplied
/// at runtime (std::dextents) and checks extents and 3d element access.
/// @return true on success; prints a diagnostic and returns false otherwise.
auto test_dynamic_extents() -> bool {
  std::array<int, 24> data{};
  std::iota(data.begin(), data.end(), 0);

  // 2x3x4 tensor with all dynamic extents
  std::mdspan tensor{data.data(), std::dextents<std::size_t, 3>{2, 3, 4}};

  if (tensor.extent(0) != 2 || tensor.extent(1) != 3 || tensor.extent(2) != 4) {
    std::printf("mdspan: dynamic extent mismatch\n");
    return false;
  }

  // element [1][2][3] = 1*12 + 2*4 + 3 = 23
  if (tensor[1, 2, 3] != 23) {
    std::printf("mdspan: 3d access failed, got %d expected 23\n", tensor[1, 2, 3]);
    return false;
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// layout test - column major (layout_left) and explicit strides (layout_stride)
// ════════════════════════════════════════════════════════════════════════════════

/// Checks index-to-offset mapping for two non-default layouts over the same
/// 6-element buffer:
///   - layout_left: a 2x3 column-major matrix, offset = i + 2*j
///   - layout_stride: a 2x2 window with strides {3, 1}, offset = 3*i + j
///     (the top-left block of the buffer read as a 2x3 row-major matrix)
/// @return true on success; prints a diagnostic and returns false otherwise.
auto test_layout_stride() -> bool {
  std::array<float, 6> data{};
  std::iota(data.begin(), data.end(), 0.0f);

  // 2x3 matrix in column-major order
  using col_major = std::layout_left;
  std::mdspan<float, std::extents<std::size_t, 2, 3>, col_major> matrix{data.data()};

  // in column major, [1][0] should be element 1 (second element of first column)
  if (matrix[1, 0] != 1.0f) {
    std::printf("mdspan: column major layout failed\n");
    return false;
  }

  // [0][1] should be element 2 (first element of second column)
  if (matrix[0, 1] != 2.0f) {
    std::printf("mdspan: column major [0,1] failed, got %f\n",
                static_cast<double>(matrix[0, 1]));
    return false;
  }

  // explicit layout_stride: strides {3, 1} skip the third column of each row,
  // so this view is non-exhaustive (elements 2 and 5 are never addressed)
  using window_extents = std::extents<std::size_t, 2, 2>;
  const std::layout_stride::mapping<window_extents> window_mapping{
      window_extents{}, std::array<std::size_t, 2>{3, 1}};
  std::mdspan<float, window_extents, std::layout_stride> window{data.data(), window_mapping};

  if (window.stride(0) != 3 || window.stride(1) != 1) {
    std::printf("mdspan: layout_stride strides wrong\n");
    return false;
  }

  // [1][1] = 1*3 + 1 = 4
  if (window[1, 1] != 4.0f) {
    std::printf("mdspan: layout_stride [1,1] failed, got %f expected 4.0\n",
                static_cast<double>(window[1, 1]));
    return false;
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// mixed static/dynamic extents
// ════════════════════════════════════════════════════════════════════════════════

/// Views a 32-element buffer as a [?, 4, 8] batch whose leading extent is
/// dynamic (set to 1 here) and checks extents and element access.
/// @return true on success; prints a diagnostic and returns false otherwise.
auto test_mixed_extents() -> bool {
  std::array<double, 32> data{};
  std::iota(data.begin(), data.end(), 0.0);

  // batch of 4x8 matrices where batch size is dynamic
  // extents<size_t, dynamic_extent, 4, 8> means [?, 4, 8]
  using batch_matrix_extents = std::extents<std::size_t, std::dynamic_extent, 4, 8>;
  std::mdspan batch{data.data(), batch_matrix_extents{1}};  // 1 batch

  if (batch.extent(0) != 1 || batch.extent(1) != 4 || batch.extent(2) != 8) {
    std::printf("mdspan: mixed extents failed\n");
    return false;
  }

  // [0][2][3] = 0*32 + 2*8 + 3 = 19
  if (batch[0, 2, 3] != 19.0) {
    std::printf("mdspan: mixed extent access failed\n");
    return false;
  }

  return true;
}

// ════════════════════════════════════════════════════════════════════════════════
// submdspan test (c++26 feature, but gcc15 has it)
// ════════════════════════════════════════════════════════════════════════════════

/// Slices a static 3x4 row-major matrix with std::submdspan into a single row
/// (contiguous) and a single column (stride 4), and checks both views.
/// NOTE(review): std::submdspan is a C++26 addition; confirm the toolchain in
/// use actually ships it.
/// @return true on success; prints a diagnostic and returns false otherwise.
auto test_submdspan() -> bool {
  std::array<int, 12> data{};
  std::iota(data.begin(), data.end(), 0);

  // 3x4 matrix
  std::mdspan matrix{data.data(), std::extents<std::size_t, 3, 4>{}};

  // extract row 1 as a 1d view
  auto row1 = std::submdspan(matrix, 1, std::full_extent);

  if (row1.extent(0) != 4) {
    std::printf("submdspan: row extent wrong\n");
    return false;
  }

  // row 1 starts at element 4
  if (row1[0] != 4 || row1[3] != 7) {
    std::printf("submdspan: row elements wrong\n");
    return false;
  }

  // extract column 2 as a strided 1d view
  auto col2 = std::submdspan(matrix, std::full_extent, 2);

  if (col2.extent(0) != 3) {
    std::printf("submdspan: column extent wrong\n");
    return false;
  }

  // column 2 elements: 2, 6, 10
  if (col2[0] != 2 || col2[1] != 6 || col2[2] != 10) {
    std::printf("submdspan: column elements wrong: %d %d %d\n", col2[0], col2[1], col2[2]);
    return false;
  }

  return true;
}

/// Runs every test in order, printing one pass/FAIL line per test followed by
/// a summary line.
/// @return 0 when every test passed, 1 otherwise (used as the exit code).
auto implementation() -> int {
  // one entry per test: the label printed in the report and the test itself
  struct named_test {
    const char* name;
    bool (*run)();
  };

  const std::array<named_test, 5> tests{{
      {"basic_mdspan", &test_basic_mdspan},
      {"dynamic_extents", &test_dynamic_extents},
      {"layout_stride", &test_layout_stride},
      {"mixed_extents", &test_mixed_extents},
      {"submdspan", &test_submdspan},
  }};

  int failures = 0;

  std::printf("mdspan tests (gcc15 libstdc++ c++23):\n");

  // keep going after a failure so the report lists every broken test
  for (const auto& [name, run] : tests) {
    if (run()) {
      std::printf(" %s: pass\n", name);
    } else {
      std::printf(" %s: FAIL\n", name);
      ++failures;
    }
  }

  if (failures != 0) {
    std::printf("%d mdspan tests FAILED\n", failures);
    return 1;
  }

  std::printf("all mdspan tests passed\n");
  return 0;
}

}  // namespace straylight::examples

// argc/argv were never used; the parameterless form avoids unused-parameter
// warnings and is an equally valid signature for main.
auto main() -> int {
  return straylight::examples::implementation();
}