// tests/test-backend-ops.cpp
#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
#include <ggml-backend-impl.h>

#include <algorithm>
#include <array>
#include <cfloat>
#include <cstring>
#include <functional>
#include <memory>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <thread>
#include <vector>


static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float max = 1.0f) {
    // static RNG initialization (revisit if n_threads stops being constant)
    static const size_t n_threads = std::thread::hardware_concurrency();
    static std::vector<std::default_random_engine> generators = []() {
        std::random_device rd;
        std::vector<std::default_random_engine> vec;
        vec.reserve(n_threads);
        //for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(1234 + i); } // fixed seed
        for (size_t i = 0; i < n_threads; i++) { vec.emplace_back(rd()); }
        return vec;
    }();

    size_t size = ggml_nelements(tensor);
    std::vector<float> data(size);

    auto init_thread = [&](size_t ith, size_t start, size_t end) {
        std::uniform_real_distribution<float> distribution(min, max);
        for (size_t i = start; i < end; i++) {
            data[i] = distribution(generators[ith]);
        }
    };

    std::vector<std::thread> threads;
    threads.reserve(n_threads);
    for (size_t i = 0; i < n_threads; i++) {
        size_t start =     i*size/n_threads;
        size_t end   = (i+1)*size/n_threads;
        threads.emplace_back(init_thread, i, start, end);
    }
    for (auto & t : threads) {
        t.join();
    }

#if 0
    const char * val_str = getenv("GGML_TEST_EPS");
    float val = 1e-9f;
    if (val_str != nullptr) {
        val = std::stof(val_str);
        printf("GGML_TEST_EPS=%e\n", val);
    }

    // test quantization with very small values that may result in nan scales due to division by zero
    if (ggml_is_quantized(tensor->type)) {
        for (int i = 0; i < 256; i++) {
            data[i] = val;
        }
    }
#endif

    if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
        ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
    } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) {
        GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
        std::vector<uint8_t> dataq(ggml_row_size(tensor->type, size));
        std::vector<float> imatrix(tensor->ne[0], 1.0f); // dummy importance matrix
        const float * im = imatrix.data();
        if (!ggml_quantize_requires_imatrix(tensor->type)) {
            // when the imatrix is optional, we want to test both quantization with and without imatrix
            // use one of the random numbers to decide
            if (data[0] > 0.5f*(min + max)) {
                im = nullptr;
            }
        }
        ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size/tensor->ne[0], tensor->ne[0], im);
        GGML_ASSERT(ggml_validate_row_data(tensor->type, dataq.data(), dataq.size()));
        ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size());
    } else if (tensor->type == GGML_TYPE_I8 || tensor->type == GGML_TYPE_I16 || tensor->type == GGML_TYPE_I32) {
        // This is going to create some weird integers though.
        ggml_backend_tensor_set(tensor, data.data(), 0, ggml_nbytes(tensor));
    } else {
        GGML_ASSERT(false);
    }
}
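
// Usage sketch (illustrative, mirrors calls made by the test cases below):
//
//   init_tensor_uniform(t);                 // default range [-1, 1]
//   init_tensor_uniform(t, -150.f, 150.f);  // extended range, see test_unary
//
// For quantized types the random floats are first packed with ggml_quantize_chunk()
// and the raw quantized blocks are uploaded instead of the float data.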

static std::vector<float> tensor_to_float(const ggml_tensor * t) {
    std::vector<float> tv;
    tv.reserve(ggml_nelements(t));

    std::vector<uint8_t> buf(ggml_nbytes(t));
    ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));

    ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
    size_t bs = ggml_blck_size(t->type);
    std::vector<float> vq(ggml_blck_size(t->type));
    bool quantized = ggml_is_quantized(t->type);

    // access elements by index to avoid gaps in views
    for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
        for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
            for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
                for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
                    size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
                    if (t->type == GGML_TYPE_F16) {
                        tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
                    } else if (t->type == GGML_TYPE_BF16) {
                        tv.push_back(ggml_bf16_to_fp32(*(ggml_bf16_t*)&buf[i]));
                    } else if (t->type == GGML_TYPE_F32) {
                        tv.push_back(*(float *) &buf[i]);
                    } else if (t->type == GGML_TYPE_I32) {
                        tv.push_back((float)*(int32_t *) &buf[i]);
                    } else if (t->type == GGML_TYPE_I16) {
                        tv.push_back((float)*(int16_t *) &buf[i]);
                    } else if (t->type == GGML_TYPE_I8) {
                        tv.push_back((float)*(int8_t *) &buf[i]);
                    } else if (quantized) {
                        tt.to_float(&buf[i], vq.data(), bs);
                        tv.insert(tv.end(), vq.begin(), vq.end());
                    } else {
                        GGML_ASSERT(false);
                    }
                }
            }
        }
    }

    return tv;
}
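
// Note: the loops above walk elements through the byte strides nb[] instead of
// reading the buffer linearly, so views with gaps are handled correctly. For a
// contiguous F32 tensor, element (i0, i1, i2, i3) lives at byte offset
//   i3*nb[3] + i2*nb[2] + i1*nb[1] + i0*nb[0];
// for quantized types, i0 advances one block (ggml_blck_size) at a time.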

/*
static double cosine_similarity(const float * v1, const float * v2, size_t n) {
    double dot = 0.0;
    double mag1 = 0.0;
    double mag2 = 0.0;

    for (size_t i = 0; i < n; i++) {
        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
            return -1.0f;
        }
        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
            continue;
        }
        dot  += v1[i]*v2[i];
        mag1 += v1[i]*v1[i];
        mag2 += v2[i]*v2[i];
    }

    return dot/sqrt(mag1*mag2);
}

static float distance(const float * v1, const float * v2, size_t n) {
    double d = 0.0;

    for (size_t i = 0; i < n; i++) {
        if (std::isnan(v1[i]) || std::isnan(v2[i])) {
            return INFINITY;
        }
        if (std::isinf(v1[i]) && std::isinf(v2[i])) {
            continue;
        }
        d += (v1[i] - v2[i])*(v1[i] - v2[i]);
    }

    return sqrt(d);
}

static float vec_len(const float * v, size_t n) {
    double d = 0.0;

    for (size_t i = 0; i < n; i++) {
        if (std::isnan(v[i])) {
            return INFINITY;
        }
        if (std::isinf(v[i])) {
            continue;
        }
        d += v[i]*v[i];
    }

    return sqrt(d);
}
*/

// normalized mean squared error = mse(a, b) / mse(a, 0)
static double nmse(const float * a, const float * b, size_t n) {
    double mse_a_b = 0.0;
    double mse_a_0 = 0.0;

    for (size_t i = 0; i < n; i++) {
        float a_i = a[i];
        float b_i = b[i];

        mse_a_b += (a_i - b_i) * (a_i - b_i);
        mse_a_0 += a_i * a_i;
    }

    return mse_a_b / mse_a_0;
}
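
// Worked example (illustrative numbers): a = {1, 2}, b = {1, 3} gives
// sum((a-b)^2) = 1 and sum(a^2) = 5, so nmse(a, b, 2) = 0.2. The 1/n factors of
// the two MSEs cancel in the ratio, which is why plain sums suffice, and the
// normalization makes max_nmse_err() independent of the magnitude of the data.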

// utils for printing the variables of the test cases
#define VAR_TO_STR(x) (#x "=" + var_to_str(x))

template<typename T>
static std::string var_to_str(const T & x) {
    return std::to_string(x);
}

template<typename T, size_t N>
static std::string var_to_str(const T (&x)[N]) {
    std::string s = "[";
    for (size_t i = 0; i < N; i++) {
        if (i > 0) {
            s += ",";
        }
        s += var_to_str(x[i]);
    }
    s += "]";
    return s;
}

template<typename T, size_t N>
static std::string var_to_str(const std::array<T, N> & x) {
    std::string s = "[";
    for (size_t i = 0; i < N; i++) {
        if (i > 0) {
            s += ",";
        }
        s += var_to_str(x[i]);
    }
    s += "]";
    return s;
}

//static std::string var_to_str(ggml_unary_op unary_op) {
//    return ggml_unary_op_name(unary_op);
//}

static std::string var_to_str(ggml_type type) {
    return ggml_type_name(type);
}

static std::string var_to_str(ggml_op_pool pool) {
    switch (pool) {
        case GGML_OP_POOL_AVG:  return "avg";
        case GGML_OP_POOL_MAX:  return "max";
        default:                return std::to_string(pool);
    }
}

#define VARS_TO_STR1(a) VAR_TO_STR(a)
#define VARS_TO_STR2(a, b) VAR_TO_STR(a) + "," + VAR_TO_STR(b)
#define VARS_TO_STR3(a, b, c) VAR_TO_STR(a) + "," + VARS_TO_STR2(b, c)
#define VARS_TO_STR4(a, b, c, d) VAR_TO_STR(a) + "," + VARS_TO_STR3(b, c, d)
#define VARS_TO_STR5(a, b, c, d, e) VAR_TO_STR(a) + "," + VARS_TO_STR4(b, c, d, e)
#define VARS_TO_STR6(a, b, c, d, e, f) VAR_TO_STR(a) + "," + VARS_TO_STR5(b, c, d, e, f)
#define VARS_TO_STR7(a, b, c, d, e, f, g) VAR_TO_STR(a) + "," + VARS_TO_STR6(b, c, d, e, f, g)
#define VARS_TO_STR8(a, b, c, d, e, f, g, h) VAR_TO_STR(a) + "," + VARS_TO_STR7(b, c, d, e, f, g, h)
#define VARS_TO_STR9(a, b, c, d, e, f, g, h, i) VAR_TO_STR(a) + "," + VARS_TO_STR8(b, c, d, e, f, g, h, i)
#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
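
// Usage sketch: a test case's vars() typically expands to something like
//
//   return VARS_TO_STR3(type, ne, nr);
//
// where VAR_TO_STR stringizes each variable name and appends its value, yielding
// e.g. "type=f32,ne=[10,10,10,10],nr=[2,2,2,2]" (values illustrative).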

#ifdef GGML_USE_SYCL
static bool inline _isinf(float f) {
    return (*(uint32_t *)&f & 0x7fffffff) == 0x7f800000;
}
#else
static bool inline _isinf(float f) { return std::isinf(f); }
#endif

// accept FLT_MAX as infinity
static bool isinf_or_max(float f) {
    return _isinf(f) || f == FLT_MAX || f == -FLT_MAX;
}

static bool ggml_is_view_op(enum ggml_op op) {
    return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
}

enum test_mode {
    MODE_TEST,
    MODE_PERF,
};

struct test_case {
    virtual ~test_case() {}

    virtual std::string op_desc(ggml_tensor * t) {
        return ggml_op_desc(t);
    }

    virtual std::string vars() {
        return "";
    }

    virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;

    virtual double max_nmse_err() {
        return 1e-7;
    }

    virtual void initialize_tensors(ggml_context * ctx) {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) {
            init_tensor_uniform(t);
        }
    }

    virtual size_t op_size(ggml_tensor * t) {
        size_t size = ggml_nbytes(t);
        // add source tensors
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (t->src[i] != NULL) {
                size += ggml_nbytes(t->src[i]);
            }
        }
        return size;
    }

    ggml_cgraph * gf = nullptr;

    static const int sentinel_size = 1024;

    test_mode mode;

    std::vector<ggml_tensor *> sentinels;

    void add_sentinel(ggml_context * ctx) {
        if (mode == MODE_PERF) {
            return;
        }
        ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size);
        ggml_format_name(sentinel, "sent_%zu", sentinels.size());
        sentinels.push_back(sentinel);
    }

    // hijack ggml_new_tensor to add sentinels after each tensor to check for overflows in the backend

    ggml_tensor * ggml_new_tensor(ggml_context * ctx, ggml_type type, int n_dims, const int64_t * ne) {
        ggml_tensor * t = ::ggml_new_tensor(ctx, type, n_dims, ne);
        add_sentinel(ctx);
        return t;
    }

    ggml_tensor * ggml_new_tensor_1d(ggml_context * ctx, ggml_type type, int64_t ne0) {
        ggml_tensor * t = ::ggml_new_tensor_1d(ctx, type, ne0);
        add_sentinel(ctx);
        return t;
    }

    ggml_tensor * ggml_new_tensor_2d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1) {
        ggml_tensor * t = ::ggml_new_tensor_2d(ctx, type, ne0, ne1);
        add_sentinel(ctx);
        return t;
    }

    ggml_tensor * ggml_new_tensor_3d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2) {
        ggml_tensor * t = ::ggml_new_tensor_3d(ctx, type, ne0, ne1, ne2);
        add_sentinel(ctx);
        return t;
    }

    ggml_tensor * ggml_new_tensor_4d(ggml_context * ctx, ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
        ggml_tensor * t = ::ggml_new_tensor_4d(ctx, type, ne0, ne1, ne2, ne3);
        add_sentinel(ctx);
        return t;
    }
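
    // Because the test cases derive from test_case, unqualified ggml_new_tensor_*()
    // calls inside build_graph() resolve to the member functions above, which
    // forward to the global functions (note the :: qualification) and then append
    // a sentinel tensor. An out-of-bounds write by the op under test lands in a
    // sentinel, which the comparison callback reports as a "sentinel mismatch".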

    bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) {
        mode = MODE_TEST;

        ggml_init_params params = {
            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
            /* .mem_base = */ NULL,
            /* .no_alloc = */ true,
        };
        ggml_context * ctx = ggml_init(params);

        gf = ggml_new_graph(ctx);

        // pre-graph sentinel
        add_sentinel(ctx);

        ggml_tensor * out = build_graph(ctx);

        if (op_name != nullptr && op_desc(out) != op_name) {
            //printf("  %s: skipping\n", op_desc(out).c_str());
            ggml_free(ctx);
            return true;
        }

        printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
        fflush(stdout);

        // check if the backends support the ops
        bool supported = true;
        for (ggml_backend_t backend : {backend1, backend2}) {
            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
                if (!ggml_backend_supports_op(backend, t)) {
                    printf("not supported [%s] ", ggml_backend_name(backend));
                    supported = false;
                    break;
                }
            }
        }
        if (!supported) {
            printf("\n");
            ggml_free(ctx);
            return true;
        }

        // post-graph sentinel
        add_sentinel(ctx);

        // allocate
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
        if (buf == NULL) {
            printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
            ggml_free(ctx);
            return false;
        }

        // build graph
        ggml_build_forward_expand(gf, out);

        // add sentinels as graph nodes so that they are checked in the callback
        for (ggml_tensor * sentinel : sentinels) {
            gf->nodes[gf->n_nodes++] = sentinel;
        }

        // randomize tensors
        initialize_tensors(ctx);

        // compare
        struct callback_userdata {
            bool   ok;
            double max_err;
            ggml_backend_t backend1;
            ggml_backend_t backend2;
        };

        callback_userdata ud {
            true,
            max_nmse_err(),
            backend1,
            backend2
        };

        auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool {
            callback_userdata * ud = (callback_userdata *) user_data;
            const char * bn1 = ggml_backend_name(ud->backend1);
            const char * bn2 = ggml_backend_name(ud->backend2);

            if (t1->op == GGML_OP_NONE) {
                // sentinels must be unchanged
                std::vector<uint8_t> t1_data(ggml_nbytes(t1));
                std::vector<uint8_t> t2_data(ggml_nbytes(t2));
                ggml_backend_tensor_get(t1, t1_data.data(), 0, ggml_nbytes(t1));
                ggml_backend_tensor_get(t2, t2_data.data(), 0, ggml_nbytes(t2));

                if (memcmp(t1_data.data(), t2_data.data(), ggml_nbytes(t1)) != 0) {
                    printf("sentinel mismatch: %s ", t1->name);
                    ud->ok = false;
                    return true;
                }
            }

            std::vector<float> f1 = tensor_to_float(t1);
            std::vector<float> f2 = tensor_to_float(t2);

            for (size_t i = 0; i < f1.size(); i++) {
                // check for nans
                if (std::isnan(f1[i]) || std::isnan(f2[i])) {
                    printf("[%s] NaN at index %zu (%s=%f %s=%f) ", ggml_op_desc(t1), i, bn1, f1[i], bn2, f2[i]);
                    ud->ok = false;
                    return true;
                }
                // check for infs: both must be inf of the same sign, or both must be finite
                if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
                    if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
                        if (std::signbit(f1[i]) != std::signbit(f2[i])) {
                            printf("[%s] inf sign mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
                            ud->ok = false;
                            return true;
                        }
                    } else {
                        printf("[%s] inf mismatch: %s=%f %s=%f ", ggml_op_desc(t1), bn1, f1[i], bn2, f2[i]);
                        ud->ok = false;
                        return true;
                    }
                }
            }

            double err = nmse(f1.data(), f2.data(), f1.size());
            if (err > ud->max_err) {
                printf("[%s] NMSE = %.9f > %.9f ", ggml_op_desc(t1), err, ud->max_err);
                //for (int i = 0; i < (int) f1.size(); i++) {
                //    printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
                //}
                //printf("\n");
                //exit(1);
                ud->ok = false;
            }
            return true;

            GGML_UNUSED(index);
        };

        const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);

        if (!cmp_ok) {
            printf("compare failed ");
        }

        ggml_backend_buffer_free(buf);

        ggml_free(ctx);

        if (ud.ok && cmp_ok) {
            printf("\033[1;32mOK\033[0m\n");
            return true;
        }

        printf("\033[1;31mFAIL\033[0m\n");
        return false;
    }
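
    // Usage sketch (hypothetical driver code, not part of this file): run one
    // test case on two backends and compare the outputs element by element:
    //
    //   test_unary tc(GGML_UNARY_OP_GELU);
    //   bool ok = tc.eval(backend_cpu, backend_gpu, nullptr); // nullptr: no op filter
    //
    // When an op name is passed instead of nullptr, non-matching cases skip themselves.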

    bool eval_perf(ggml_backend_t backend, const char * op_name) {
        mode = MODE_PERF;

        static const size_t graph_nodes = 8192;

        ggml_init_params params = {
            /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false),
            /* .mem_base = */ NULL,
            /* .no_alloc = */ true,
        };
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * out = build_graph(ctx);

        if (op_name != nullptr && op_desc(out) != op_name) {
            //printf("  %s: skipping\n", op_desc(out).c_str());
            ggml_free(ctx);
            return true;
        }

        int len = printf("  %s(%s): ", op_desc(out).c_str(), vars().c_str());
        fflush(stdout);

        // check if backends support op
        if (!ggml_backend_supports_op(backend, out)) {
            printf("not supported\n");
            ggml_free(ctx);
            return true;
        }

        // align while also leaving some margin for variations in parameters
        int align = 20;
        int last = (len + align - 1) / align * align;
        if (last - len < 5) {
            last += align;
        }
        last = std::max(last, 60);
        printf("%*s", last - len, "");

        // allocate
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
        if (buf == NULL) {
            printf("failed to allocate tensors\n");
            ggml_free(ctx);
            return false;
        }

        // randomize tensors
        initialize_tensors(ctx);

        // build graph
        ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false);
        ggml_build_forward_expand(gf, out);

        // warmup run
        ggml_backend_graph_compute(backend, gf);

        // duplicate the op
        size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
        int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
        for (int i = 1; i < n_runs; i++) {
            gf->nodes[gf->n_nodes++] = out;
        }

        // calculate memory
        size_t mem = n_runs * op_size(out);
        auto tensor_op_size = [](ggml_tensor * t) {
            size_t size = ggml_nbytes(t);
            // add source tensors
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (t->src[i] != NULL) {
                    size += ggml_nbytes(t->src[i]);
                }
            }
            return size;
        };
        for (int i = 0; i < gf->n_nodes; i++) {
            if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
                continue;
            }
            mem += tensor_op_size(gf->nodes[i]);
        }

        // run
        ggml_backend_synchronize(backend);

        int64_t start_time = ggml_time_us();
        ggml_backend_graph_compute(backend, gf);
        ggml_backend_synchronize(backend);
        int64_t end_time = ggml_time_us();
        double time_us = end_time - start_time;

        printf("    %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n",
            n_runs,
            time_us / n_runs,
            op_size(out) / 1024,
            mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0);

        ggml_backend_buffer_free(buf);

        ggml_free(ctx);

        return true;
    }
};
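
// A new operator test only needs a constructor, build_graph(), and (optionally)
// vars() for the log line; allocation, randomization, comparison and perf come
// from test_case. Minimal sketch for a hypothetical GGML_OP_NEG test:
//
//   struct test_neg : public test_case {
//       const std::array<int64_t, 4> ne = {10, 10, 10, 10};
//       ggml_tensor * build_graph(ggml_context * ctx) override {
//           ggml_tensor * a = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
//           return ggml_neg(ctx, a);
//       }
//   };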

// GGML_OP_UNARY
struct test_unary : public test_case {
    const ggml_unary_op op;
    const ggml_type type;
    const std::array<int64_t, 4> ne_a;
    int v; // view (1 : non-contiguous a)

    std::string vars() override {
        return VARS_TO_STR3(type, ne_a, v);
    }

    test_unary(ggml_unary_op op,
            ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {128, 10, 10, 10},
            int v = 0)
        : op(op), type(type), ne_a(ne_a), v(v) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a;
        if (v & 1) {
            auto ne = ne_a; ne[0] *= 3;
            a = ggml_new_tensor(ctx, type, 4, ne.data());
            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
        } else {
            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
        }
        ggml_tensor * out = ggml_unary(ctx, a, op);
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            // test extended range of values to check for NaNs in GELU
            init_tensor_uniform(t, -150.f, 150.f);
        }
    }
};

// GGML_OP_GET_ROWS
struct test_get_rows : public test_case {
    const ggml_type type;
    const int n; // cols
    const int m; // rows
    const int r; // rows to get
    const int b; // batch size
    const bool v; // view (non-contiguous src1)

    std::string vars() override {
        return VARS_TO_STR6(type, n, m, r, b, v);
    }

    test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
        : type(type), n(n), m(m), r(r), b(b), v(v) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
        ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
        if (v) {
            rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
        }
        ggml_tensor * out = ggml_get_rows(ctx, in, rows);
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->type == GGML_TYPE_I32) {
                if (ggml_is_view_op(t->op)) { continue; }
                // rows
                std::vector<int> data(r*b);
                for (int i = 0; i < r*b; i++) {
                    data[i] = rand() % m;
                }
                ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
            } else {
                init_tensor_uniform(t);
            }
        }
    }
};
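
// get_rows semantics: for each batch and each index j in `rows`, output row j is
// the input row rows[j], so initialize_tensors() must keep the generated indices
// within [0, m). The v=true variant shrinks `rows` to a view of r/2 indices to
// exercise non-contiguous index tensors.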

// GGML_OP_REPEAT
struct test_repeat : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const std::array<int, 4> nr;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, nr);
    }

    size_t op_size(ggml_tensor * t) override {
        return ggml_nbytes(t) * 2;
    }

    test_repeat(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10},
            std::array<int, 4> nr = {2, 2, 2, 2})
        : type(type), ne(ne), nr(nr) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * target = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_repeat(ctx, src, target);
        return out;
    }
};

// GGML_OP_DUP
struct test_dup : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const std::array<int64_t, 4> permute;
    bool _use_permute;

    std::string vars() override {
        std::string v = VARS_TO_STR2(type, ne);
        if (_use_permute) v += "," + VAR_TO_STR(permute);
        return v;
    }

    test_dup(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 1},
            std::array<int64_t, 4> permute = {0, 0, 0, 0})
        : type(type), ne(ne), permute(permute),
            _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
        if (_use_permute) {
            src = ggml_permute(ctx, src, permute[0], permute[1], permute[2], permute[3]);
        }
        ggml_tensor * out = ggml_dup(ctx, src);
        return out;
    }
};

// GGML_OP_CPY
struct test_cpy : public test_case {
    const ggml_type type_src;
    const ggml_type type_dst;
    const std::array<int64_t, 4> ne;

    std::string vars() override {
        return VARS_TO_STR3(type_src, type_dst, ne);
    }

    size_t op_size(ggml_tensor * t) override {
        return ggml_nbytes(t) + ggml_nbytes(t->src[0]);
    }

    test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 1})
        : type_src(type_src), type_dst(type_dst), ne(ne) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
        ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne.data());
        ggml_tensor * out = ggml_cpy(ctx, src, dst);
        return out;
    }
};

// GGML_OP_CONT
struct test_cont : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;

    std::string vars() override {
        return VARS_TO_STR2(type, ne);
    }

    test_cont(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 1})
        : type(type), ne(ne) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
        src = ggml_transpose(ctx, src);
        ggml_tensor * out = ggml_cont(ctx, src);

        return out;
    }
};

// GGML_OP_ADD
// GGML_OP_MUL
// GGML_OP_DIV
struct test_bin_bcast : public test_case {
    using op_t = ggml_tensor * (*) (ggml_context *, ggml_tensor *, ggml_tensor *);
    op_t op;
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const std::array<int, 4> nr;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, nr);
    }

    size_t op_size(ggml_tensor * t) override {
        return ggml_nbytes(t) * 3;
    }

    test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 1, 1},
            std::array<int, 4> nr = {1, 2, 1, 1})
        : op(op), type(type), ne(ne), nr(nr) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = op(ctx, a, b);
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (op == ggml_div) {
                // avoid division by zero
                init_tensor_uniform(t, 1.0f, 2.0f);
            } else {
                init_tensor_uniform(t);
            }
        }
    }
};
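
// Broadcast layout: `a` has shape ne*nr while `b` has shape ne, so each element of
// b is reused nr[0]*nr[1]*nr[2]*nr[3] times across a. With the defaults above
// (ne = {10, 10, 1, 1}, nr = {1, 2, 1, 1}), a is 10x20 and b is 10x10, and every
// element of b is applied to two elements of a.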

// GGML_OP_SCALE
struct test_scale : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    float scale;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, scale);
    }

    test_scale(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10},
            float scale = 2.0f)
        : type(type), ne(ne), scale(scale) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_scale(ctx, a, scale);
        return out;
    }
};

// GGML_OP_NORM
struct test_norm : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    float eps;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, eps);
    }

    test_norm(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {64, 10, 10, 10},
            float eps = 1e-6f)
        : type(type), ne(ne), eps(eps) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_norm(ctx, a, eps);
        return out;
    }
};

// GGML_OP_RMS_NORM
struct test_rms_norm : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    float eps;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, eps);
    }

    test_rms_norm(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {64, 10, 10, 10},
            float eps = 1e-6f)
        : type(type), ne(ne), eps(eps) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
        return out;
    }
};

// GGML_OP_MUL_MAT
struct test_mul_mat : public test_case {
    const ggml_type type_a;
    const ggml_type type_b;
    const int64_t m;
    const int64_t n;
    const int64_t k;
    const std::array<int64_t, 2> bs; // dims 3 and 4
    const std::array<int64_t, 2> nr; // repeat in dims 3 and 4

    std::string vars() override {
        return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr);
    }

    double max_nmse_err() override {
        return 5e-4;
    }

    size_t op_size(ggml_tensor * t) override {
        size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1];
        size_t b = ggml_nbytes(t->src[1]) * m;
        size_t c  = ggml_nbytes(t);
        return a + b + c;

        GGML_UNUSED(t);
    }

    test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
            int64_t m = 32, int64_t n = 32, int64_t k = 32,
            std::array<int64_t, 2> bs = {10, 10},
            std::array<int64_t, 2> nr = {2, 2})
        : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), nr(nr) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
        ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]      , bs[1]);
        ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
        ggml_tensor * out = ggml_mul_mat(ctx, a, b);
        return out;
    }
};
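
// Layout note: ggml's ne[0] is the innermost (contiguous) dimension, so a (k, m)
// and b (k, n) share the contraction dimension k in ne[0] and the product has
// shape (m, n), matching the "C^T = A * B^T" comment in build_graph(). The batch
// dimensions of b are bs*nr, which exercises broadcasting of a over dims 3 and 4.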

// GGML_OP_MUL_MAT_ID
struct test_mul_mat_id : public test_case {
    const ggml_type type_a;
    const ggml_type type_b;
    const int n_mats;
    const int n_used;
    const bool b; // broadcast b matrix
    const int64_t m;
    const int64_t n;
    const int64_t k;

    std::string vars() override {
        return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
    }

    double max_nmse_err() override {
        return 5e-4;
    }

    size_t op_size(ggml_tensor * t) override {
        size_t a = ggml_nbytes(t->src[2]) * n;
        size_t b = ggml_nbytes(t->src[1]) * m;
        size_t c  = ggml_nbytes(t);
        return a + b + c;

        GGML_UNUSED(t);
    }

    test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
            int n_mats = 8, int n_used = 2, bool b = false,
            int64_t m = 32, int64_t n = 32, int64_t k = 32)
        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
            m(m), n(n), k(k) {
            GGML_ASSERT(n_used <= n_mats);
        }

    ggml_tensor * build_graph(ggml_context * ctx) override {
        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
        if (n_used != n_mats) {
            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
        }
        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        std::random_device rd;
        std::default_random_engine rng(rd());
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->type == GGML_TYPE_I32) {
                if (ggml_is_view_op(t->op)) { continue; }
                // ids
                for (int64_t r = 0; r < ggml_nrows(t); r++) {
                    std::vector<int32_t> data(t->ne[0]);
                    for (int i = 0; i < t->ne[0]; i++) {
                        data[i] = i % n_mats;
                    }
                    std::shuffle(data.begin(), data.end(), rng);
                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
                }
            } else {
                init_tensor_uniform(t);
            }
        }
    }
};
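
// The ids tensor drives expert selection, as in mixture-of-experts layers: each of
// the n columns of b selects n_used of the n_mats matrices in `as`.
// initialize_tensors() fills every row of ids with a shuffled permutation of
// 0..n_mats-1, so each expert index is valid and appears at most once per row.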

// GGML_OP_SQR
struct test_sqr : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;

    std::string vars() override {
        return VARS_TO_STR2(type, ne);
    }

    test_sqr(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10})
        : type(type), ne(ne) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_sqr(ctx, a);
        return out;
    }
};

// GGML_OP_SQRT
struct test_sqrt : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;

    std::string vars() override {
        return VARS_TO_STR2(type, ne);
    }

    test_sqrt(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10})
        : type(type), ne(ne) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_sqrt(ctx, a);
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        // fill with positive values
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            init_tensor_uniform(t, 0.0f, 100.0f);
        }
    }
};

// GGML_OP_CLAMP
struct test_clamp : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    float min;
    float max;

    std::string vars() override {
        return VARS_TO_STR4(type, ne, min, max);
    }

    test_clamp(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10},
            float min = -0.5f, float max = 0.5f)
        : type(type), ne(ne), min(min), max(max) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_clamp(ctx, a, min, max);
        return out;
    }
};

// GGML_OP_DIAG_MASK_INF
struct test_diag_mask_inf : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const int n_past;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, n_past);
    }

    test_diag_mask_inf(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10},
            int n_past = 5)
        : type(type), ne(ne), n_past(n_past) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_diag_mask_inf(ctx, a, n_past);
        return out;
    }
};

// GGML_OP_SOFT_MAX
struct test_soft_max : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const bool mask;
    const float scale;
    const float max_bias;

    std::string vars() override {
        return VARS_TO_STR5(type, ne, mask, scale, max_bias);
    }

    // the 1024 test with bias occasionally fails:
    // SOFT_MAX(type=f32,ne=[1024,16,1,1],mask=1,scale=1.000000,max_bias=8.000000): [SOFT_MAX] NMSE = 0.000000103 > 0.000000100 FAIL
    virtual double max_nmse_err() override {
        return 1e-6;
    }

    test_soft_max(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10},
            bool mask = false,
            float scale = 1.0f,
            float max_bias = 0.0f)
        : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * mask = nullptr;
        if (this->mask) {
            mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
        }
        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
        return out;
    }
};

// GGML_OP_ROPE
struct test_rope : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne_a;
    int n_dims;
    int mode;
    int n_ctx; // used to generate positions
    float fs; // freq_scale
    float ef; // ext_factor
    float af; // attn_factor
    bool ff; // freq_factors
    int v; // view (1 : non-contiguous a)

    std::string vars() override {
        return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v);
    }

    test_rope(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0)
        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a;
        if (v & 1) {
            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
            a = ggml_new_tensor(ctx, type, 4, ne.data());
            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
        } else {
            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
        }
        ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
        return out;
    }

    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->type == GGML_TYPE_I32) {
                // pos
                std::vector<int> data(ne_a[2]);
                for (int i = 0; i < ne_a[2]; i++) {
                    data[i] = rand() % n_ctx;
                }
                ggml_backend_tensor_set(t, data.data(), 0, ne_a[2] * sizeof(int));
            } else {
                if (t->ne[0] == n_dims/2) {
                    // frequency factors in the range [0.9f, 1.1f]
                    init_tensor_uniform(t, 0.9f, 1.1f);
                } else {
                    init_tensor_uniform(t);
                }
            }
        }
    }
};

// GGML_OP_POOL2D
struct test_pool2d : public test_case {
    enum ggml_op_pool pool_type;
    const ggml_type type_input;
    const std::array<int64_t, 4> ne_input;
    // kernel size
    const int k0;
    const int k1;
    // stride
    const int s0;
    const int s1;
    // padding
    const int p0;
    const int p1;

    std::string vars() override {
        return VARS_TO_STR9(pool_type, type_input, ne_input, k0, k1, s0, s1, p0, p1);
    }

    test_pool2d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
            ggml_type type_input = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
            int k0 = 3, int k1 = 3,
            int s0 = 1, int s1 = 1,
            int p0 = 1, int p1 = 1)
        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), k1(k1), s0(s0), s1(s1), p0(p0), p1(p1) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
        ggml_tensor * out = ggml_pool_2d(ctx, input, pool_type, k0, k1, s0, s1, p0, p1);
        return out;
    }
};

// GGML_OP_IM2COL
struct test_im2col : public test_case {
    const ggml_type type_input;
    const ggml_type type_kernel;
    const ggml_type dst_type;
    const std::array<int64_t, 4> ne_input;
    const std::array<int64_t, 4> ne_kernel;
    // stride
    const int s0;
    const int s1;
    // padding
    const int p0;
    const int p1;
    // dilation
    const int d0;
    const int d1;
    // mode
    const bool is_2D;

    std::string vars() override {
        return VARS_TO_STR12(type_input, type_kernel, dst_type, ne_input, ne_kernel, s0, s1, p0, p1, d0, d1, is_2D);
    }

    test_im2col(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_input = {10, 10, 3, 1}, // [input_width, input_height, input_channels, 1]
            std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [kernel_width, kernel_height, input_channels, 1]
            int s0 = 1, int s1 = 1,
            int p0 = 1, int p1 = 1,
            int d0 = 1, int d1 = 1,
            bool is_2D = true)
        : type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), is_2D(is_2D) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
        ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel.data());
        ggml_tensor * out = ggml_im2col(ctx, kernel, input, s0, s1, p0, p1, d0, d1, is_2D, dst_type);
        return out;
    }
};
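
// im2col unrolls kernel-sized patches of `input` into columns so that the
// convolution can then be computed as a single matrix multiplication. With the
// defaults above (10x10x3 input, 3x3x3 kernel, stride 1, padding 1, dilation 1),
// each output position yields one 3*3*3 = 27-element column.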
1305  
1306  // GGML_OP_CONCAT
1307  struct test_concat : public test_case {
1308      const ggml_type type;
1309      const std::array<int64_t, 4> ne_a;
1310      const int64_t ne_b_d;
1311      const int dim;
1312      const int v; // view (1 << 0: non-cont a, 1 << 1: non-cont b)
1313  
1314      std::string vars() override {
1315          return VARS_TO_STR5(type, ne_a, ne_b_d, dim, v);
1316      }
1317  
1318      test_concat(ggml_type type = GGML_TYPE_F32,
1319              std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
1320              int64_t ne_b_d = 10,
1321              int dim = 2, int v = 0)
1322          : type(type), ne_a(ne_a), ne_b_d(ne_b_d), dim(dim), v(v) {}
1323  
1324      ggml_tensor * build_graph(ggml_context * ctx) override {
1325          auto ne_b = ne_a;
1326          ne_b[dim] = ne_b_d;
1327          ggml_tensor * a;
1328          if (v & 1) {
1329              auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
1330              a = ggml_new_tensor(ctx, type, 4, ne.data());
1331              a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
1332          } else {
1333              a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1334          }
1335          ggml_tensor * b;
1336          if (v & 2) {
1337              auto ne = ne_b; ne[0] *= 3; ne[1] *= 2; ne[2] *= 4;
1338              b = ggml_new_tensor(ctx, type, 4, ne.data());
1339              b = ggml_view_4d(ctx, b, ne_b[0], ne_b[1], ne_b[2], ne_b[3], b->nb[1], b->nb[2], b->nb[3], 0);
1340          } else {
1341              b = ggml_new_tensor(ctx, type, 4, ne_b.data());
1342          }
1343          ggml_tensor * out = ggml_concat(ctx, a, b, dim);
1344          return out;
1345      }
1346  };
1347  
1348  // GGML_OP_ARGSORT
1349  struct test_argsort : public test_case {
1350      const ggml_type type;
1351      const std::array<int64_t, 4> ne;
1352      ggml_sort_order order;
1353  
1354      std::string vars() override {
1355          return VARS_TO_STR3(type, ne, order);
1356      }
1357  
1358      test_argsort(ggml_type type = GGML_TYPE_F32,
1359              std::array<int64_t, 4> ne = {16, 10, 10, 10},
1360              ggml_sort_order order = GGML_SORT_ORDER_ASC)
1361          : type(type), ne(ne), order(order) {}
1362  
1363      ggml_tensor * build_graph(ggml_context * ctx) override {
1364          ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
1365          ggml_tensor * out = ggml_argsort(ctx, a, order);
1366          return out;
1367      }
1368  
1369      void initialize_tensors(ggml_context * ctx) override {
1370          std::random_device rd;
1371          std::default_random_engine rng(rd());
1372          for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1373              if (t->type == GGML_TYPE_I32) {
1374                  // indices
1375                  std::vector<int> data(ggml_nelements(t));
1376                  for (int i = 0; i < ggml_nelements(t); i++) {
1377                      data[i] = rand();
1378                  }
1379                  std::shuffle(data.begin(), data.end(), rng);
1380                  ggml_backend_tensor_set(t, data.data(), 0, ne[0]*ne[1]*ne[2]*ne[3] * sizeof(int));
1381              } else if (t->type == GGML_TYPE_F32) {
1382                  // initialize with unique values to avoid ties
1383                  for (int64_t r = 0; r < ggml_nrows(t); r++) {
1384                      std::vector<float> data(t->ne[0]);
1385                      for (int i = 0; i < t->ne[0]; i++) {
1386                          data[i] = i;
1387                      }
1388                      std::shuffle(data.begin(), data.end(), rng);
1389                      ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(float));
1390                  }
1391              } else {
1392                  GGML_ASSERT(false);
1393              }
1394          }
1395      }
1396  };

// GGML_OP_SUM_ROWS
struct test_sum_rows : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;

    std::string vars() override {
        return VARS_TO_STR2(type, ne);
    }

    test_sum_rows(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 10, 10, 10})
        : type(type), ne(ne) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_sum_rows(ctx, a);
        return out;
    }
};

// GGML_OP_UPSCALE
struct test_upscale : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const int32_t scale_factor;
    const bool transpose;

    std::string vars() override {
        return VARS_TO_STR4(type, ne, scale_factor, transpose);
    }

    test_upscale(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {512, 512, 3, 1},
            int32_t scale_factor = 2, bool transpose = false)
        : type(type), ne(ne), scale_factor(scale_factor), transpose(transpose) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        if (transpose) {
            a = ggml_transpose(ctx, a);
        }
        ggml_tensor * out = ggml_upscale(ctx, a, scale_factor);
        return out;
    }
};

// GGML_OP_UPSCALE (ext)
struct test_upscale_ext : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const std::array<int64_t, 4> ne_tgt;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, ne_tgt);
    }

    test_upscale_ext(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne     = {2, 5,  7, 11},
            std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13})
        : type(type), ne(ne), ne_tgt(ne_tgt) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1], ne_tgt[2], ne_tgt[3]);
        return out;
    }
};

// GGML_OP_GROUP_NORM
struct test_group_norm : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne;
    const int32_t num_groups;

    std::string vars() override {
        return VARS_TO_STR3(type, ne, num_groups);
    }

    test_group_norm(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {64, 64, 320, 1},
            int32_t num_groups = 32)
        : type(type), ne(ne), num_groups(num_groups) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_tensor * out = ggml_group_norm(ctx, a, num_groups);
        return out;
    }
};

// GGML_OP_ACC
struct test_acc : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne_a;
    const std::array<int64_t, 4> ne_b;

    std::string vars() override {
        return VARS_TO_STR3(type, ne_a, ne_b);
    }

    test_acc(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {1024, 577, 1, 1},
            std::array<int64_t, 4> ne_b = {1024, 576, 1, 1})
        : type(type), ne_a(ne_a), ne_b(ne_b) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
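        // the strides of the accumulation view are taken from a; the b->nb[1] byte offset starts b at the second row of a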
        ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]);
        return out;
    }
};

// GGML_OP_PAD
struct test_pad : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne_a;
    const int pad_0;
    const int pad_1;

    std::string vars() override {
        return VARS_TO_STR4(type, ne_a, pad_0, pad_1);
    }

    test_pad(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {512, 512, 1, 1},
            int pad_0 = 1, int pad_1 = 1)
        : type(type), ne_a(ne_a), pad_0(pad_0), pad_1(pad_1) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
        ggml_tensor * out = ggml_pad(ctx, a, pad_0, pad_1, 0, 0);
        return out;
    }
};

// GGML_OP_ARANGE
struct test_arange : public test_case {
    const ggml_type type;
    const float start;
    const float stop;
    const float step;

    std::string vars() override {
        return VARS_TO_STR4(type, start, stop, step);
    }

    test_arange(ggml_type type = GGML_TYPE_F32,
            float start = 0.f, float stop = 10.f, float step = 1.f)
        : type(type), start(start), stop(stop), step(step) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * out = ggml_arange(ctx, start, stop, step);
        return out;
    }
};

// GGML_OP_TIMESTEP_EMBEDDING
struct test_timestep_embedding : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne_a;
    const int dim;
    const int max_period;

    std::string vars() override {
        return VARS_TO_STR4(type, ne_a, dim, max_period);
    }

    test_timestep_embedding(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {2, 1, 1, 1},
            int dim = 320, int max_period = 10000)
        : type(type), ne_a(ne_a), dim(dim), max_period(max_period) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
        ggml_tensor * out = ggml_timestep_embedding(ctx, a, dim, max_period);
        return out;
    }
};

// GGML_OP_LEAKY_RELU
struct test_leaky_relu : public test_case {
    const ggml_type type;
    const std::array<int64_t, 4> ne_a;
    const float negative_slope;

    std::string vars() override {
        return VARS_TO_STR3(type, ne_a, negative_slope);
    }

    test_leaky_relu(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne_a = {10, 10, 10, 10},
            float negative_slope = 0.1f)
        : type(type), ne_a(ne_a), negative_slope(negative_slope) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
        ggml_tensor * out = ggml_leaky_relu(ctx, a, negative_slope, true);
        return out;
    }
};

// GGML_OP_FLASH_ATTN_EXT
struct test_flash_attn_ext : public test_case {
    const int64_t hs; // head size
    const int64_t nh; // num heads
    const int64_t kv; // kv size
    const int64_t nb; // batch size

    const bool mask; // use mask

    const float max_bias; // ALiBi

    const ggml_type type_KV;

    std::string vars() override {
        return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
    }

    double max_nmse_err() override {
        return 5e-4;
    }

    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
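        // pad the head size to a multiple of the KV type's block size so quantized K/V rows are well-formed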
        const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));

        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs_padded, nb, nh, 1);
        ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
        ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
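        // ggml_flash_attn_ext requires the mask's second dimension to be padded to a multiple of GGML_KQ_MASK_PAD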
        ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
        return out;
    }
};

enum llm_norm_type {
    LLM_NORM,
    LLM_NORM_RMS,
};

struct llama_hparams {
    uint32_t n_vocab;
    uint32_t n_embd;
    uint32_t n_head;
    uint32_t n_head_kv;
    static constexpr uint32_t n_layer = 1;
    uint32_t n_rot;
    uint32_t n_embd_head; // dimension of values (d_v)
    uint32_t n_ff;

    float f_norm_eps;
    float f_norm_rms_eps;

    // cparams
    static constexpr uint32_t n_ctx = 512; // user-specified context size
    static constexpr uint32_t n_ctx_orig = n_ctx;

    // batch
    int32_t n_tokens;

    // llm_build_context
    static constexpr int32_t n_kv    = 32; // size of KV cache to consider (n_kv <= n_ctx)
    static constexpr int32_t kv_head = 1;  // index of where we store new KV data in the cache

    uint32_t n_embd_gqa() const { // dimension of key embeddings across all k-v heads
        return n_embd_head * n_head_kv;
    }
};

// LLM base class
struct test_llm : public test_case {
    llama_hparams hp;

protected:
    test_llm(llama_hparams hp)
        : hp(std::move(hp)) {
    }

public:
    struct ggml_tensor * llm_build_norm(
            struct ggml_context * ctx,
             struct ggml_tensor * cur,
             struct ggml_tensor * mw,
             struct ggml_tensor * mb,
                  llm_norm_type   type) {
        switch (type) {
            case LLM_NORM:     cur = ggml_norm    (ctx, cur, hp.f_norm_eps); break;
            case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hp.f_norm_rms_eps); break;
        }
        cur = ggml_mul(ctx, cur, mw);
        if (mb) {
            cur = ggml_add(ctx, cur, mb);
        }
        return cur;
    }

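    // write the current K and V into the cache buffers at position kv_head;
    // V is stored transposed so that llm_build_kqv can read it with a stride of n_ctx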
    void llm_build_kv_store(
            struct ggml_context * ctx,
             struct ggml_tensor * k_l,
             struct ggml_tensor * v_l,
             struct ggml_tensor * k_cur,
             struct ggml_tensor * v_cur) {
        // compute the transposed [n_tokens, n_embd] V matrix
        struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, hp.n_embd_gqa(), hp.n_tokens));

        struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, k_l, hp.n_tokens*hp.n_embd_gqa(),
                (ggml_row_size(k_l->type, hp.n_embd_gqa()))*hp.kv_head);

        struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, v_l, hp.n_tokens, hp.n_embd_gqa(),
                (  hp.n_ctx)*ggml_element_size(v_l),
                (hp.kv_head)*ggml_element_size(v_l));

        // important: storing RoPE-ed version of K in the KV cache!
        ggml_cpy(ctx, k_cur,   k_cache_view);
        ggml_cpy(ctx, v_cur_t, v_cache_view);
    }

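    // attention over the cached K/V: KQ scores -> masked, scaled softmax ->
    // weighted sum with V -> merge heads -> output projection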
    struct ggml_tensor * llm_build_kqv(
            struct ggml_context * ctx,
             struct ggml_tensor * k_l,
             struct ggml_tensor * v_l,
             struct ggml_tensor * q_cur,
             struct ggml_tensor * kq_mask,
                        float     kq_scale) {
        struct ggml_tensor * q = ggml_permute(ctx, q_cur, 0, 2, 1, 3);

        struct ggml_tensor * k =
            ggml_view_3d(ctx, k_l,
                    hp.n_embd_head, hp.n_kv, hp.n_head_kv,
                    ggml_row_size(k_l->type, hp.n_embd_gqa()),
                    ggml_row_size(k_l->type, hp.n_embd_head),
                    0);

        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);

        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);

        // split cached v into n_head heads
        struct ggml_tensor * v =
            ggml_view_3d(ctx, v_l,
                    hp.n_kv, hp.n_embd_head, hp.n_head_kv,
                    ggml_element_size(v_l)*hp.n_ctx,
                    ggml_element_size(v_l)*hp.n_ctx*hp.n_embd_head,
                    0);

        struct ggml_tensor * kqv = ggml_mul_mat(ctx, v, kq);

        struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);

        struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, hp.n_embd_head*hp.n_head, hp.n_tokens);

        struct ggml_tensor * wo = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
        cur = ggml_mul_mat(ctx, wo, cur);

        return cur;
    }

    void initialize_tensors(ggml_context * ctx) override {
        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
            if (t->type == GGML_TYPE_I32) {
                // pos
                std::vector<int> data(hp.n_tokens);
                for (int i = 0; i < hp.n_tokens; i++) {
                    data[i] = rand() % hp.n_ctx;
                }
                ggml_backend_tensor_set(t, data.data(), 0, hp.n_tokens * sizeof(int));
            } else {
                init_tensor_uniform(t);
            }
        }
    }
};

// Llama
struct test_llama : public test_llm {
    static constexpr float freq_base = 10000.0f;
    static constexpr float freq_scale = 1.0f;
    static constexpr float ext_factor = 0.0f;
    static constexpr float attn_factor = 1.0f;
    static constexpr float beta_fast = 32.0f;
    static constexpr float beta_slow = 1.0f;

    std::string op_desc(ggml_tensor * t) override {
        GGML_UNUSED(t);
        return "LLAMA";
    }

    std::string vars() override {
        auto n_tokens = hp.n_tokens;
        return VARS_TO_STR1(n_tokens);
    }

    double max_nmse_err() override {
        return 2e-3;
    }

    test_llama(int n_tokens = 1)
        : test_llm({
            /*n_vocab        =*/ 32000,
            /*n_embd         =*/ 3200,
            /*n_head         =*/ 32,
            /*n_head_kv      =*/ 32,
            /*n_rot          =*/ 100,
            /*n_embd_head    =*/ 100,
            /*n_ff           =*/ 8640,
            /*f_norm_eps     =*/ 0.f,
            /*f_norm_rms_eps =*/ 1e-5f,
            /*n_tokens       =*/ n_tokens,
        }) {
    }

    ggml_tensor * build_graph(ggml_context * ctx) override {
        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;

        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);

        // inp_pos - contains the positions
        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);

        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);

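        // per-layer KV cache: 1638400 = n_ctx * n_embd_gqa() = 512 * 3200 elements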
        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);

        for (uint32_t il = 0; il < hp.n_layer; ++il) {
            struct ggml_tensor * inpSA = inpL;

            // norm
            ggml_tensor * attn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
            cur = llm_build_norm(ctx, inpL, attn_norm, nullptr, LLM_NORM_RMS);

            // self-attention
            {
                ggml_tensor * wq = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd);
                ggml_tensor * wk = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());
                ggml_tensor * wv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd_gqa());

                // compute Q and K and RoPE them
                struct ggml_tensor * Qcur = ggml_mul_mat(ctx, wq, cur);
                struct ggml_tensor * Kcur = ggml_mul_mat(ctx, wk, cur);
                struct ggml_tensor * Vcur = ggml_mul_mat(ctx, wv, cur);

                Qcur = ggml_rope_ext(
                    ctx, ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens), inp_pos, nullptr,
                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );

                Kcur = ggml_rope_ext(
                    ctx, ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens), inp_pos, nullptr,
                    hp.n_rot, 0, hp.n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow
                );

                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);

                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
            }

            struct ggml_tensor * ffn_inp = ggml_add(ctx, cur, inpSA);

            // feed-forward network
            ggml_tensor * ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
            cur = llm_build_norm(ctx, ffn_inp, ffn_norm, nullptr, LLM_NORM_RMS);

            ggml_tensor * ffn_gate = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
            ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff,   hp.n_embd);
            ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
            struct ggml_tensor * tmp = ggml_mul_mat(ctx, ffn_up, cur);
            cur = ggml_mul_mat(ctx, ffn_gate, cur);
            cur = ggml_silu(ctx, cur);
            cur = ggml_mul(ctx, cur, tmp);
            cur = ggml_mul_mat(ctx, ffn_down, cur);

            cur = ggml_add(ctx, cur, ffn_inp);

            // input for next layer
            inpL = cur;
        }

        cur = inpL;

        ggml_tensor * output_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
        cur = llm_build_norm(ctx, cur, output_norm, nullptr, LLM_NORM_RMS);

        // lm_head
        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_vocab);
        cur = ggml_mul_mat(ctx, output, cur);

        return cur;
    }
};

// Falcon
struct test_falcon : public test_llm {
    static constexpr float freq_base = 10000.0f;
    static constexpr float freq_scale = 1.0f;
    static constexpr float ext_factor = 0.0f;
    static constexpr float attn_factor = 1.0f;
    static constexpr float beta_fast = 32.0f;
    static constexpr float beta_slow = 1.0f;

    std::string op_desc(ggml_tensor * t) override {
        GGML_UNUSED(t);
        return "FALCON";
    }

    std::string vars() override {
        auto n_tokens = hp.n_tokens;
        return VARS_TO_STR1(n_tokens);
    }

    double max_nmse_err() override {
        return 2e-3;
    }

    test_falcon(int n_tokens = 1)
        : test_llm({
            /*n_vocab        =*/ 32000,
            /*n_embd         =*/ 3200,
            /*n_head         =*/ 50,
            /*n_head_kv      =*/ 1,
            /*n_rot          =*/ 64,
            /*n_embd_head    =*/ 64,
            /*n_ff           =*/ 8640,
            /*f_norm_eps     =*/ 1e-5f,
            /*f_norm_rms_eps =*/ 0.f,
            /*n_tokens       =*/ n_tokens,
        }) {
    }

    ggml_tensor * build_graph(ggml_context * ctx) override {
        struct ggml_tensor * cur;
        struct ggml_tensor * inpL;

        inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens);

        // inp_pos - contains the positions
        struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);

        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);

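        // same KV cache size as the llama test; only n_ctx * n_embd_gqa() = 512 * 64 elements are used by the views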
        ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
        ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);

        for (uint32_t il = 0; il < hp.n_layer; ++il) {
            // norm
            ggml_tensor * attn_norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
            ggml_tensor * attn_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
            ggml_tensor * attn_norm = llm_build_norm(ctx, inpL, attn_norm_w, attn_norm_b, LLM_NORM);

            // self-attention
            {
                cur = attn_norm;

                ggml_tensor * wqkv = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_embd + 2*hp.n_embd_gqa());

                cur = ggml_mul_mat(ctx, wqkv, cur);

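                // split the fused QKV result: each row holds Q (n_embd values), then K (n_embd_gqa()), then V (n_embd_gqa())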
                struct ggml_tensor * Qcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd,     hp.n_tokens, cur->nb[1], 0*sizeof(float)*(hp.n_embd)));
                struct ggml_tensor * Kcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd)));
                struct ggml_tensor * Vcur = ggml_cont(ctx, ggml_view_2d(ctx, cur, hp.n_embd_gqa(), hp.n_tokens, cur->nb[1], 1*sizeof(float)*(hp.n_embd + hp.n_embd_gqa())));

                Qcur = ggml_reshape_3d(ctx, Qcur, hp.n_embd_head, hp.n_head,    hp.n_tokens);
                Kcur = ggml_reshape_3d(ctx, Kcur, hp.n_embd_head, hp.n_head_kv, hp.n_tokens);

                // mode = 2 selects NeoX-style RoPE
                Qcur = ggml_rope_ext(
                    ctx, Qcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                );

                Kcur = ggml_rope_ext(
                    ctx, Kcur, inp_pos, nullptr, hp.n_rot, 2, hp.n_ctx_orig,
                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
                );

                llm_build_kv_store(ctx, k_l, v_l, Kcur, Vcur);

                cur = llm_build_kqv(ctx, k_l, v_l, Qcur, KQ_mask, 1.0f/sqrtf(float(hp.n_embd_head)));
            }

            struct ggml_tensor * ffn_inp = cur;

            // feed forward
            {
                ggml_tensor * ffn_up   = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_embd, hp.n_ff);
                ggml_tensor * ffn_down = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, hp.n_ff, hp.n_embd);
                cur = attn_norm;
                cur = ggml_mul_mat(ctx, ffn_up, cur);
                cur = ggml_gelu(ctx, cur);
                cur = ggml_mul_mat(ctx, ffn_down, cur);
            }

            cur = ggml_add(ctx, cur, ffn_inp);

            cur = ggml_add(ctx, cur, inpL);

            // input for next layer
            inpL = cur;
        }

        cur = inpL;

        ggml_tensor * output_norm   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
        ggml_tensor * output_norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_embd);
        cur = llm_build_norm(ctx, cur, output_norm, output_norm_b, LLM_NORM);

        // lm_head
        ggml_tensor * output = ggml_new_tensor_2d(ctx, GGML_TYPE_Q8_0, hp.n_embd, hp.n_vocab);
        cur = ggml_mul_mat(ctx, output, cur);

        return cur;
    }
};

static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
    std::vector<std::unique_ptr<test_case>> test_cases;
    std::default_random_engine rng(0);

    const ggml_type all_types[] = {
        GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16,
        GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
        GGML_TYPE_Q8_0,
        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
        GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
        GGML_TYPE_Q6_K,
        GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
    };

    const ggml_type base_types[] = {
        GGML_TYPE_F32, GGML_TYPE_F16,
        GGML_TYPE_Q4_0,
        GGML_TYPE_Q4_K,
        GGML_TYPE_IQ2_XXS
    };

    const ggml_type other_types[] = {
        GGML_TYPE_Q4_1,
        GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
        GGML_TYPE_Q8_0,
        GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
        GGML_TYPE_Q5_K,
        GGML_TYPE_Q6_K,
        GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
        GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
        GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
    };

    // unary ops
    for (int v : {0, 1}) {
        for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 10, 10, 10 }, v));
            test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }, v));
        }
    }

    test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
    for (ggml_type type : all_types) {
        for (int b : {1, 7}) {
            for (bool v : {false, true}) {
                test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
            }
        }
    }
    for (int b : {1, 7}) {
        for (bool v : {false, true}) {
            test_cases.emplace_back(new test_get_rows(GGML_TYPE_I32, 256, 5, 4, b, v));
        }
    }

    for (ggml_type type_input : {GGML_TYPE_F32}) {
        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
            for (int k0 : {1, 3}) {
                for (int k1 : {1, 3}) {
                    for (int s0 : {1, 2}) {
                        for (int s1 : {1, 2}) {
                            for (int p0 : {0, 1}) {
                                for (int p1 : {0, 1}) {
                                    test_cases.emplace_back(new test_pool2d(pool_type, type_input, {10, 10, 3, 1}, k0, k1, s0, s1, p0, p1));
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
    test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));

    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));
    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1}));
    test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
    test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 10, 10, 10}, {2, 1, 1, 1}));
    test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 10, 10, 10}, {1, 1, 1, 2}));

    test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
    test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
    test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
    test_cases.emplace_back(new test_dup(GGML_TYPE_I16));
    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3}));

    for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
        for (ggml_type type_dst : all_types) {
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
        }
    }

    test_cases.emplace_back(new test_cont());

    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
        }
    };

    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2});

    // stable diffusion
    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
    add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
    //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});

    test_cases.emplace_back(new test_scale());

    for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) {
        test_cases.emplace_back(new test_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
    }

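    // test_mul_mat arguments: m, n, k, batch dims {bs[0], bs[1]}, broadcast repeats {nr[0], nr[1]}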
    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10,  1}, {2, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2}));

            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1,  1}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {2, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
        }
    }

    for (ggml_type type_a : other_types) {
        for (ggml_type type_b : {GGML_TYPE_F32}) {
            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1,  1}, {1, 1}));
        }
    }

    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,  128, { 8,  1}, {1, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,  128, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 2,   64, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  83, 2,   64, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32,  64, 45, 128, { 8,  1}, {4, 1}));
    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45,  64, { 8,  1}, {4, 1}));

    for (ggml_type type_a : base_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
            for (int n_mats : {4, 8}) {
                for (int n_used : {1, 2, 4}) {
                    for (bool b : {false, true}) {
                        for (int n : {1, 32}) {
                            int m = 512;
                            int k = 256;
                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
                        }
                    }
                }
            }
        }
    }

    for (ggml_type type_a : other_types) {
        for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
            for (int n_mats : {4}) {
                for (int n_used : {2}) {
                    for (bool b : {false}) {
                        for (int n : {1}) {
                            int m = 512;
                            int k = 256;
                            test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
                        }
                    }
                }
            }
        }
    }

    test_cases.emplace_back(new test_sqr());
    test_cases.emplace_back(new test_sqrt());
    test_cases.emplace_back(new test_clamp());

    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));
    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10,  1}, 5));
    test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));

#if 0
    std::uniform_int_distribution<> dist_ne1(1, 50);
    int exponent = 1;
    while (exponent < (1 << 17)) {
        std::uniform_int_distribution<> dist_ne0(exponent, 2*exponent);

        for (int n = 0; n < 10; ++n) {
            int64_t ne0 = dist_ne0(rng);
            int64_t ne1 = dist_ne1(rng);
            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
        }

        exponent <<= 1;
    }
#endif
    for (bool mask : {false, true}) {
        for (float max_bias : {0.0f, 8.0f}) {
            if (!mask && max_bias > 0.0f) continue;
            for (float scale : {1.0f, 0.1f}) {
                for (int64_t ne0 : {16, 1024}) {
                    for (int64_t ne1 : {16, 1024}) {
                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, scale, max_bias));
                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
                    }
                }
            }
        }
    }

    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));

    {
        bool all = true;
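        // the full set of model shapes is tested only while 'all' is set; it is cleared after the first parameter combination to limit runtime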

        for (float v : { 0, 1 }) {
            for (float fs : { 1.0f, 1.4245f }) {
                for (float ef : { 0.0f, 0.7465f }) {
                    for (float af : { 1.0f, 1.4245f }) {
                        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
                            for (bool ff : {false, true}) { // freq_factors
                                test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B

                                if (all) {
                                    test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
                                    test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
                                    test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
                                }

                                if (all) {
                                    test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
                                    test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
                                    test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, fs, ef, af, ff, v)); // neox (stablelm)
                                    test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, fs, ef, af, ff, v)); // neox (phi-2)
                                }

                                test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 40B)
                            }
                        }

                        all = false;
                    }
                }
            }
        }
    }

    for (int v : { 0, 1, 2, 3 }) {
        for (int dim : { 0, 1, 2, 3, }) {
            test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v));
            test_cases.emplace_back(new test_concat(GGML_TYPE_I32, {11, 12, 13, 14}, 7, dim, v));
        }
    }

    for (ggml_sort_order order : {GGML_SORT_ORDER_ASC, GGML_SORT_ORDER_DESC}) {
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
        test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen
    }

    test_cases.emplace_back(new test_sum_rows());
    test_cases.emplace_back(new test_upscale());
    test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
    test_cases.emplace_back(new test_upscale_ext());
    test_cases.emplace_back(new test_group_norm());
    test_cases.emplace_back(new test_acc());
    test_cases.emplace_back(new test_pad());
    test_cases.emplace_back(new test_arange());
    test_cases.emplace_back(new test_timestep_embedding());
    test_cases.emplace_back(new test_leaky_relu());

    for (int hs : { 64, 80, 128, 256, }) {
        for (bool mask : { true, false }) {
            for (float max_bias : { 0.0f, 8.0f }) {
                if (!mask && max_bias > 0.0f) continue;
                for (int nh : { 32, }) {
                    for (int kv : { 512, 1024, }) {
                        for (int nb : { 1, 2, 4, 8, }) {
                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
                            }
                        }
                    }
                }
            }
        }
    }

    // these tests are disabled to save execution time, but they can be handy for debugging
#if 0
    test_cases.emplace_back(new test_llama(1));
    test_cases.emplace_back(new test_llama(2));
    test_cases.emplace_back(new test_falcon(1));
    test_cases.emplace_back(new test_falcon(2));
#endif

    // run tests
    if (mode == MODE_TEST) {
        ggml_backend_t backend_cpu = ggml_backend_cpu_init();

        size_t n_ok = 0;
        for (auto & test : test_cases) {
            if (test->eval(backend, backend_cpu, op_name)) {
                n_ok++;
            }
        }
        printf("  %zu/%zu tests passed\n", n_ok, test_cases.size());

        ggml_backend_free(backend_cpu);

        return n_ok == test_cases.size();
    }

    if (mode == MODE_PERF) {
        for (auto & test : test_cases) {
            test->eval_perf(backend, op_name);
        }
        return true;
    }

    GGML_ASSERT(false);
    return false;
}

static void usage(char ** argv) {
    printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]);
    printf("  valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation)\n");
    printf("  op names are as given by ggml_op_desc(), backend names as given by ggml_backend_reg_get_name()\n");
}

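// example invocations (backend names depend on the build; e.g. a CUDA build typically registers "CUDA0"):
//   test-backend-ops test -o MUL_MAT
//   test-backend-ops perf -b CUDA0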
int main(int argc, char ** argv) {
    test_mode mode = MODE_TEST;
    const char * op_name_filter = NULL;
    const char * backend_filter = NULL;

    for (int i = 1; i < argc; i++) {
        if (strcmp(argv[i], "test") == 0) {
            mode = MODE_TEST;
        } else if (strcmp(argv[i], "perf") == 0) {
            mode = MODE_PERF;
        } else if (strcmp(argv[i], "-o") == 0) {
            if (i + 1 < argc) {
                op_name_filter = argv[++i];
            } else {
                usage(argv);
                return 1;
            }
        } else if (strcmp(argv[i], "-b") == 0) {
            if (i + 1 < argc) {
                backend_filter = argv[++i];
            } else {
                usage(argv);
                return 1;
            }
        } else {
            usage(argv);
            return 1;
        }
    }

    // enumerate backends
    printf("Testing %zu backends\n\n", ggml_backend_reg_get_count());

    size_t n_ok = 0;

    for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) {
        printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i));

        if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_reg_get_name(i)) != 0) {
            printf("  Skipping\n");
            n_ok++;
            continue;
        }

        ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL);
        GGML_ASSERT(backend != NULL);

        if (backend_filter == NULL && ggml_backend_is_cpu(backend)) {
            printf("  Skipping CPU backend\n");
            ggml_backend_free(backend);
            n_ok++;
            continue;
        }

        printf("  Backend name: %s\n", ggml_backend_name(backend));

        bool ok = test_backend(backend, mode, op_name_filter);

        printf("  Backend %s: ", ggml_backend_name(backend));
        if (ok) {
            printf("\033[1;32mOK\033[0m\n");
            n_ok++;
        } else {
            printf("\033[1;31mFAIL\033[0m\n");
        }

        printf("\n");

        ggml_backend_free(backend);
    }

    printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());

    // free the internal quantization buffers on both exit paths
    ggml_quantize_free();

    if (n_ok != ggml_backend_reg_get_count()) {
        printf("\033[1;31mFAIL\033[0m\n");
        return 1;
    }

    printf("\033[1;32mOK\033[0m\n");
    return 0;
}