// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

//
//                   _   _          ___ _      _   ___
//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
//                   \__|_|_||_\_, |___/____/_/ \_\___/
//                             |__/
//
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound on rounding
// errors, which grow with k. The goal is to exploit the hardware for
// maximum performance first, and then use whatever resources remain
// for improving numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].

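// For orientation, here is a scalar sketch of what these kernels
// compute (the helper name is hypothetical; this is illustrative and
// not part of the build). A is effectively transposed, so both inputs
// are traversed contiguously along k, and C is written in column
// major order:
//
//     void sgemm_ref(int64_t m, int64_t n, int64_t k,
//                    const float *A, int64_t lda,
//                    const float *B, int64_t ldb,
//                    float *C, int64_t ldc) {
//         for (int64_t j = 0; j < n; ++j)
//             for (int64_t i = 0; i < m; ++i) {
//                 float sum = 0;
//                 for (int64_t l = 0; l < k; ++l)
//                     sum += A[lda * i + l] * B[ldb * j + l];
//                 C[ldc * j + i] = sum;
//             }
//     }
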
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"

#include "sgemm.h"
#include "ggml-impl.h"
#include "ggml-quants.h"

#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

#if defined(__ARM_NEON) || defined(__AVX512F__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

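// Equivalent of _mm256_set_m128i(a, b): `a` becomes the high 128-bit
// lane and `b` the low lane. Written with _mm256_insertf128_si256()
// because some older compilers don't provide _mm256_set_m128i().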
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

namespace {

inline float unhalf(ggml_fp16_t d) {
    return GGML_FP16_TO_FP32(d);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif // __AVX__

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif // __AVX512F__

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD

/**
 * Computes a * b + c.
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}
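
// The generic template above rounds twice: once for the multiply and
// once for the add. The specializations below use hardware FMA where
// available, which is faster and rounds only once.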

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif

#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
    return vfmaq_f16(c, b, a);
}
#endif
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
    return vaddvq_f32(x);
}
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
                                vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
    __m128 t;
    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
    x = _mm_add_ps(x, t);
    t = _mm_movehl_ps(t, x);
    x = _mm_add_ss(x, t);
#endif
    return _mm_cvtss_f32(x);
}
#endif

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
                           _mm256_castps256_ps128(x)));
}
#endif // __AVX__

#if defined(__AVX512F__)
inline float hsum(__m512 x) {
    return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING

template <typename T, typename U> T load(const U *);
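// The caller picks the register type, e.g. load<__m256>(p) or
// load<float32x4_t>(p); the element type selects the specialization,
// including the ones that widen fp16 storage to fp32 lanes.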

#if defined(__ARM_NEON)
template <> inline float32x4_t load(const float *p) {
    return vld1q_f32(p);
}
#if !defined(_MSC_VER)
template <> inline float16x8_t load(const ggml_fp16_t *p) {
    return vld1q_f16((const float16_t *)p);
}
template <> inline float32x4_t load(const ggml_fp16_t *p) {
    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
}
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const float *p) {
    return _mm256_loadu_ps(p);
}
#endif // __AVX__

#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif // __F16C__

#if defined(__AVX512F__)
template <> inline __m512 load(const float *p) {
    return _mm512_loadu_ps(p);
}
template <> inline __m512 load(const ggml_fp16_t *p) {
    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
#endif // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
  public:
    tinyBLAS(int64_t k,
             const TA *A, int64_t lda,
             const TB *B, int64_t ldb,
             TC *C, int64_t ldc,
             int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
        case 0x55:
            mc = 5;
            nc = 5;
            gemm<5, 5>(m0, m, n0, n);
            break;
        case 0x45:
            mc = 4;
            nc = 5;
            gemm<4, 5>(m0, m, n0, n);
            break;
        case 0x54:
            mc = 5;
            nc = 4;
            gemm<5, 4>(m0, m, n0, n);
            break;
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x53:
            mc = 5;
            nc = 3;
            gemm<5, 3>(m0, m, n0, n);
            break;
        case 0x35:
            mc = 3;
            nc = 5;
            gemm<3, 5>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
#else
        case 0x55:
        case 0x54:
        case 0x53:
        case 0x45:
        case 0x44:
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x35:
#endif
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x52:
            mc = 5;
            nc = 2;
            gemm<5, 2>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x25:
            mc = 2;
            nc = 5;
            gemm<2, 5>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x51:
            mc = 5;
            nc = 1;
            gemm<5, 1>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x15:
            mc = 1;
            nc = 5;
            gemm<1, 5>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

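    // The switch in mnpack() picks the biggest RM×RN kernel that fits
    // the remaining submatrix, runs it over the region that divides
    // evenly, then recurses on the two remainder strips:
    //
    //     mp = m0 + (m - m0) / mc * mc;  // rows covered by whole tiles
    //     np = n0 + (n - n0) / nc * nc;  // cols covered by whole tiles
    //     mnpack(mp, m, n0, np);         // bottom remainder strip
    //     mnpack(m0, m, np, n);          // right remainder strip
    //
    // gemm() below then numbers the grid of whole tiles 0..tiles-1 and
    // deals those jobs out evenly across the nth threads.
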
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;  // tiles per thread, rounded up
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;  // top row of this tile
            int64_t jj = n0 + job % xtiles * RN;  // left column of this tile
            D Cv[RN][RM] = {};  // accumulators stay in vector registers
            for (int64_t l = 0; l < k; l += KN)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
                                        load<V>(B + ldb * (jj + j) + l),
                                        Cv[j][i]);
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};

//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION

#if defined(__ARM_FEATURE_DOTPROD)
template <typename TA>
class tinyBLAS_Q0_ARM {
  public:
    tinyBLAS_Q0_ARM(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            float32x4_t Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                               vcvtq_f32_s32(vdotq_s32(
                                                   vdotq_s32(vdupq_n_s32(0),
                                                             load_lo(A + lda * (ii + i) + l),
                                                             load_lo(B + ldb * (jj + j) + l)),
                                                   load_hi(A + lda * (ii + i) + l),
                                                   load_hi(B + ldb * (jj + j) + l))),
                                               unhalf(A[lda * (ii + i) + l].d) *
                                               unhalf(B[ldb * (jj + j) + l].d));
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline int8x16_t load_lo(const block_q8_0 *b) {
        return vld1q_s8(b->qs);
    }

    inline int8x16_t load_hi(const block_q8_0 *b) {
        return vld1q_s8(b->qs + 16);
    }

    // q4_0 packs two 4-bit quants per byte; unpack the low nibbles and
    // subtract the implicit offset of 8 to recenter around zero.
    inline int8x16_t load_lo(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
                                                     vdupq_n_u8(0x0f))),
                        vdupq_n_s8(0x8));
    }

    // Same for the high nibbles.
    inline int8x16_t load_hi(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                        vdupq_n_s8(0x8));
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __ARM_FEATURE_DOTPROD
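
// For reference, one block step of the quantized kernels in scalar
// form (a sketch, assuming ggml's usual "type 0" block layout of 32
// quants plus one fp16 scale; the helper name is hypothetical):
//
//     float dot_block_q8(const block_q8_0 *a, const block_q8_0 *b) {
//         int32_t sum = 0;
//         for (int i = 0; i < 32; ++i)
//             sum += a->qs[i] * b->qs[i];
//         return unhalf(a->d) * unhalf(b->d) * sum;
//     }
//
// The kernels above and below accumulate the int8 products with dot
// product instructions (vdotq_s32 on ARM, vpdpbusd/maddubs on x86) and
// then scale the int32 sum by the product of the two block deltas.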

#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX {
  public:
    tinyBLAS_Q0_AVX(int64_t k,
                    const TA *A, int64_t lda,
                    const TB *B, int64_t ldb,
                    TC *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
#else
        case 0x44:
        case 0x43:
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x34:
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x33:
#endif
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i) {
#if defined(__AVX2__)
                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                              load(A + lda * (ii + i) + l)),
                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                              load(A + lda * (ii + i) + l)));
#else
                        __m128i ali0 = load0(A + lda * (ii + i) + l);
                        __m128i ali1 = load1(A + lda * (ii + i) + l);
                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
                        __m128i blj1 = load1(B + ldb * (jj + j) + l);

                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);

                        // updot
                        const __m128i oneFill = _mm_set1_epi16(1);
                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                       unhalf(B[ldb * (jj + j) + l].d)),
                                        udTmp,
                                        Cv[j][i]);
                    }
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline __m256i load(const block_q8_0 *b) {
        return _mm256_loadu_si256((const __m256i *)b->qs);
    }

    inline __m128i load0(const block_q8_0 *b) {
        return _mm_loadu_si128((const __m128i *)b->qs);
    }

    inline __m128i load1(const block_q8_0 *b) {
        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
    }

    inline __m256i load(const block_q4_0 *b) {
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
    }

    inline __m128i load0(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
    }

    inline __m128i load1(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
    }

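    // _mm256_maddubs_epi16() and vpdpbusd multiply unsigned bytes by
    // signed bytes, so the gemm loop above passes |a| on the unsigned
    // side (_mm256_sign_epi8(a, a)) and moves a's sign onto b
    // (_mm256_sign_epi8(b, a)); dot(|a|, sign(b, a)) == dot(a, b).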
    inline __m256 updot(__m256i u, __m256i s) {
        __m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#else
        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
        return _mm256_cvtepi32_ps(res);
    }

    static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __AVX__

} // namespace

/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is available for the given
 * type combination and CPU. Otherwise the caller should fall back to
 * a general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
 *                     0, 1, GGML_TASK_TYPE_COMPUTE,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is output matrix
 * @param ldc is row stride of `C`
 * @param ith is thread id (must be less than `nth`)
 * @param nth is number of threads (must be greater than zero)
 * @param task is GGML task type
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {

    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(nth > 0);
    assert(ith < nth);

    if (Ctype != GGML_TYPE_F32)
        return false;

    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        tinyBLAS<16, __m512, __m512, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__AVX__) || defined(__AVX2__)
        if (k % 8)
            return false;
        tinyBLAS<8, __m256, __m256, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        if (k % 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F16)
            return false;
        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const ggml_fp16_t *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        if (k % 4)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    default:
        return false;
    }

    // Unreachable, but marks every parameter as used so that builds
    // where the SIMD paths are compiled out don't warn.
    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)ith;
    (void)nth;
    (void)task;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}