// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

//
//                   _   _          ___ _      _   ___
//                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
//                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
//                   \__|_|_||_\_, |___/____/_/ \_\___/
//                             |__/
//
//                    BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal is to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
//     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
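//
// A note on conventions, for illustration only: in the floating point
// case, every kernel in this file computes the same result as the scalar
// reference loop below, where lda, ldb and ldc are element strides rather
// than byte strides. The names m, n, k, A, B, C match the parameters of
// llamafile_sgemm() at the bottom of this file.
//
//     for (int64_t i = 0; i < m; ++i)        // rows of Aᵀ and C
//         for (int64_t j = 0; j < n; ++j) {  // columns of B and C
//             float sum = 0;
//             for (int64_t l = 0; l < k; ++l)
//                 sum += A[lda * i + l] * B[ldb * j + l];
//             C[ldc * j + i] = sum;
//         }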

#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"

#include "sgemm.h"
#include "ggml-impl.h"
#include "ggml-quants.h"

#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif

#if defined(__ARM_NEON) || defined(__AVX512F__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif

#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

namespace {

inline float unhalf(ggml_fp16_t d) {
    return GGML_FP16_TO_FP32(d);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif // __AVX__

#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif // __AVX512F__

#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD

/**
 * Computes a * b + c.
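 *
 * The generic template below evaluates this as two separately rounded
 * operations. Where the compiler advertises FMA support, the
 * specializations that follow use a hardware fused multiply-add, which
 * is both faster and rounds only once.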
 */
template <typename T, typename U>
inline U madd(T a, T b, U c) {
    return add(mul(a, b), c);
}

#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
    return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
    return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif

#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
    return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
    return vfmaq_f16(c, b, a);
}
#endif
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM

#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
    return vaddvq_f32(x);
}
#endif // __ARM_NEON

#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
    return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
                                vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
    x = _mm_add_ps(x, _mm_movehl_ps(x, x));
    x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
    __m128 t;
    t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
    x = _mm_add_ps(x, t);
    t = _mm_movehl_ps(t, x);
    x = _mm_add_ss(x, t);
#endif
    return _mm_cvtss_f32(x);
}
#endif

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
    return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
                           _mm256_castps256_ps128(x)));
}
#endif // __AVX__

#if defined(__AVX512F__)
inline float hsum(__m512 x) {
    return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING

template <typename T, typename U> T load(const U *);

#if defined(__ARM_NEON)
template <> inline float32x4_t load(const float *p) {
    return vld1q_f32(p);
}
#if !defined(_MSC_VER)
template <> inline float16x8_t load(const ggml_fp16_t *p) {
    return vld1q_f16((const float16_t *)p);
}
template <> inline float32x4_t load(const ggml_fp16_t *p) {
    return vcvt_f32_f16(vld1_f16((const float16_t *)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON

#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
    return _mm_loadu_ps(p);
}
#endif // __SSE__

#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const float *p) {
    return _mm256_loadu_ps(p);
}
#endif // __AVX__

#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
    return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif // __F16C__

#if defined(__AVX512F__)
template <> inline __m512 load(const float *p) {
    return _mm512_loadu_ps(p);
}
template <> inline __m512 load(const ggml_fp16_t *p) {
    return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
#endif // __AVX512F__

////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION

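/**
 * Vectorized register-blocked GEMM for float-like types.
 *
 * Summary of the template parameters as used below: KN is how many k
 * elements one vector step consumes (callers ensure k is a multiple of
 * KN); D is the vector type of the accumulators; V is the vector type
 * returned by load<>() for the inputs; TA, TB and TC are the element
 * types of A, B and C. Work is split across nth threads, and this
 * instance computes the share of thread ith.
 */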
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
  public:
    tinyBLAS(int64_t k,
             const TA *A, int64_t lda,
             const TB *B, int64_t ldb,
             TC *C, int64_t ldc,
             int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
        case 0x55:
            mc = 5;
            nc = 5;
            gemm<5, 5>(m0, m, n0, n);
            break;
        case 0x45:
            mc = 4;
            nc = 5;
            gemm<4, 5>(m0, m, n0, n);
            break;
        case 0x54:
            mc = 5;
            nc = 4;
            gemm<5, 4>(m0, m, n0, n);
            break;
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x53:
            mc = 5;
            nc = 3;
            gemm<5, 3>(m0, m, n0, n);
            break;
        case 0x35:
            mc = 3;
            nc = 5;
            gemm<3, 5>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
#else
        case 0x55:
        case 0x54:
        case 0x53:
        case 0x45:
        case 0x44:
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x35:
#endif
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x52:
            mc = 5;
            nc = 2;
            gemm<5, 2>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x25:
            mc = 2;
            nc = 5;
            gemm<2, 5>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x51:
            mc = 5;
            nc = 1;
            gemm<5, 1>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x15:
            mc = 1;
            nc = 5;
            gemm<1, 5>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

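    // gemm<RM, RN> computes output tiles of RM x RN elements whose accumulators
    // all stay resident in vector registers. mnpack() above picks the largest
    // kernel shape that still fits in what remains of the matrix, then recurses
    // on the fringe rows and columns the chosen kernel leaves over. Tiles are
    // numbered in row-major order and handed to the nth threads in contiguous
    // chunks of roughly equal size.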
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            D Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; l += KN)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
                                        load<V>(B + ldb * (jj + j) + l),
                                        Cv[j][i]);
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};

//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION

#if defined(__ARM_FEATURE_DOTPROD)
template <typename TA>
class tinyBLAS_Q0_ARM {
  public:
    tinyBLAS_Q0_ARM(int64_t k,
                    const TA *A, int64_t lda,
                    const block_q8_0 *B, int64_t ldb,
                    float *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            float32x4_t Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i)
                        Cv[j][i] = vmlaq_n_f32(Cv[j][i],
                                               vcvtq_f32_s32(vdotq_s32(
                                                   vdotq_s32(vdupq_n_s32(0),
                                                             load_lo(A + lda * (ii + i) + l),
                                                             load_lo(B + ldb * (jj + j) + l)),
                                                   load_hi(A + lda * (ii + i) + l),
                                                   load_hi(B + ldb * (jj + j) + l))),
                                               unhalf(A[lda * (ii + i) + l].d) *
                                               unhalf(B[ldb * (jj + j) + l].d));
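            // reduce each vector accumulator to a scalar and store the finished tile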
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline int8x16_t load_lo(const block_q8_0 *b) {
        return vld1q_s8(b->qs);
    }

    inline int8x16_t load_hi(const block_q8_0 *b) {
        return vld1q_s8(b->qs + 16);
    }

    inline int8x16_t load_lo(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
                                                     vdupq_n_u8(0x0f))),
                        vdupq_n_s8(0x8));
    }

    inline int8x16_t load_hi(const block_q4_0 *b) {
        return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
                        vdupq_n_s8(0x8));
    }

    const TA *const A;
    const block_q8_0 *const B;
    float *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __ARM_FEATURE_DOTPROD

#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX {
  public:
    tinyBLAS_Q0_AVX(int64_t k,
                    const TA *A, int64_t lda,
                    const TB *B, int64_t ldb,
                    TC *C, int64_t ldc,
                    int ith, int nth)
        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
    }

    void matmul(int64_t m, int64_t n, int task) {
        if (task == GGML_TASK_TYPE_COMPUTE)
            mnpack(0, m, 0, n);
    }

  private:
    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t mc, nc, mp, np;
        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
        case 0x44:
            mc = 4;
            nc = 4;
            gemm<4, 4>(m0, m, n0, n);
            break;
        case 0x43:
            mc = 4;
            nc = 3;
            gemm<4, 3>(m0, m, n0, n);
            break;
        case 0x34:
            mc = 3;
            nc = 4;
            gemm<3, 4>(m0, m, n0, n);
            break;
        case 0x33:
            mc = 3;
            nc = 3;
            gemm<3, 3>(m0, m, n0, n);
            break;
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
#else
        case 0x44:
        case 0x43:
        case 0x42:
            mc = 4;
            nc = 2;
            gemm<4, 2>(m0, m, n0, n);
            break;
        case 0x34:
        case 0x24:
            mc = 2;
            nc = 4;
            gemm<2, 4>(m0, m, n0, n);
            break;
        case 0x33:
#endif
        case 0x32:
            mc = 3;
            nc = 2;
            gemm<3, 2>(m0, m, n0, n);
            break;
        case 0x23:
            mc = 2;
            nc = 3;
            gemm<2, 3>(m0, m, n0, n);
            break;
        case 0x41:
            mc = 4;
            nc = 1;
            gemm<4, 1>(m0, m, n0, n);
            break;
        case 0x22:
            mc = 2;
            nc = 2;
            gemm<2, 2>(m0, m, n0, n);
            break;
        case 0x14:
            mc = 1;
            nc = 4;
            gemm<1, 4>(m0, m, n0, n);
            break;
        case 0x31:
            mc = 3;
            nc = 1;
            gemm<3, 1>(m0, m, n0, n);
            break;
        case 0x13:
            mc = 1;
            nc = 3;
            gemm<1, 3>(m0, m, n0, n);
            break;
        case 0x21:
            mc = 2;
            nc = 1;
            gemm<2, 1>(m0, m, n0, n);
            break;
        case 0x12:
            mc = 1;
            nc = 2;
            gemm<1, 2>(m0, m, n0, n);
            break;
        case 0x11:
            mc = 1;
            nc = 1;
            gemm<1, 1>(m0, m, n0, n);
            break;
        default:
            return;
        }
        mp = m0 + (m - m0) / mc * mc;
        np = n0 + (n - n0) / nc * nc;
        mnpack(mp, m, n0, np);
        mnpack(m0, m, np, n);
    }

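    // Each block of A and B packs 32 quantized values together with one fp16
    // scale d. The inner loop below forms the int32 dot product of two such
    // blocks and folds it into the float accumulator scaled by d_A * d_B.
    // Because _mm256_maddubs_epi16 multiplies unsigned bytes by signed bytes,
    // the kernel feeds it |a| and b-with-the-sign-of-a (via the sign
    // instructions), which leaves each product a*b unchanged.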
    template <int RM, int RN>
    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
        int64_t ytiles = (m - m0) / RM;
        int64_t xtiles = (n - n0) / RN;
        int64_t tiles = xtiles * ytiles;
        int64_t duty = (tiles + nth - 1) / nth;
        int64_t start = duty * ith;
        int64_t end = start + duty;
        if (end > tiles)
            end = tiles;
        for (int64_t job = start; job < end; ++job) {
            int64_t ii = m0 + job / xtiles * RM;
            int64_t jj = n0 + job % xtiles * RN;
            __m256 Cv[RN][RM] = {};
            for (int64_t l = 0; l < k; ++l)
                for (int64_t j = 0; j < RN; ++j)
                    for (int64_t i = 0; i < RM; ++i) {
#if defined(__AVX2__)
                        __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
                                                              load(A + lda * (ii + i) + l)),
                                             _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
                                                              load(A + lda * (ii + i) + l)));
#else
                        __m128i ali0 = load0(A + lda * (ii + i) + l);
                        __m128i ali1 = load1(A + lda * (ii + i) + l);
                        __m128i blj0 = load0(B + ldb * (jj + j) + l);
                        __m128i blj1 = load1(B + ldb * (jj + j) + l);

                        __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
                        __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
                        __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
                        __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);

                        // updot
                        const __m128i oneFill = _mm_set1_epi16(1);
                        __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
                        __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
                        __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
                        Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
                                                       unhalf(B[ldb * (jj + j) + l].d)),
                                        udTmp,
                                        Cv[j][i]);
                    }
            for (int64_t j = 0; j < RN; ++j)
                for (int64_t i = 0; i < RM; ++i)
                    C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
        }
    }

    inline __m256i load(const block_q8_0 *b) {
        return _mm256_loadu_si256((const __m256i *)b->qs);
    }

    inline __m128i load0(const block_q8_0 *b) {
        return _mm_loadu_si128((const __m128i *)b->qs);
    }

    inline __m128i load1(const block_q8_0 *b) {
        return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
    }

    inline __m256i load(const block_q4_0 *b) {
        return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
    }

    inline __m128i load0(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
    }

    inline __m128i load1(const block_q4_0 *b) {
        const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
        return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
    }

    inline __m256 updot(__m256i u, __m256i s) {
        __m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
        res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#else
        res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
        return _mm256_cvtepi32_ps(res);
    }

    static inline __m256i denibble(const uint8_t *p) {
        __m128i x = _mm_loadu_si128((const __m128i *)p);
        return _mm256_and_si256(_mm256_set1_epi8(15),
                                _mm256_insertf128_si256(_mm256_castsi128_si256(x),
                                                        _mm_srli_epi16(x, 4), 1));
    }

    const TA *const A;
    const TB *const B;
    TC *const C;
    const int64_t k;
    const int64_t lda;
    const int64_t ldb;
    const int64_t ldc;
    const int ith;
    const int nth;
};
#endif // __AVX__

} // namespace

/**
 * Performs optimized matrix multiplication on CPU.
 *
 * This subroutine may compute C = Aᵀ * B with column major ordering.
 * Despite its name, this isn't a generalized implementation. Work is
 * only performed when a handwritten kernel is written and available.
 * Otherwise the caller should fall back to a general matmul routine.
 *
 * For example, for single-threaded single-precision GEMM you can say
 *
 *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
 *                     0, 1, GGML_TASK_TYPE_COMPUTE,
 *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
 *
 * @param m is rows in `A` and `C`
 * @param n is cols in `B` and `C`
 * @param k is cols in `A` and rows in `B`
 * @param A is first input matrix (always transposed)
 * @param lda is row stride of `A`
 * @param B is second input matrix (never transposed)
 * @param ldb is row stride of `B`
 * @param C is input/output array of output matrices
 * @param ldc is row stride of `C`
 * @param ith is thread id (must be less than `nth`)
 * @param nth is number of threads (must be greater than zero)
 * @param task is GGML task type
 * @param Atype is GGML data type of `A`
 * @param Btype is GGML data type of `B`
 * @param Ctype is GGML data type of `C`
 * @return true if this function was able to service the matmul request
 */
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
                     int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {

    assert(m >= 0);
    assert(n >= 0);
    assert(k >= 0);
    assert(lda >= k);
    assert(ldb >= k);
    assert(ldc >= m);
    assert(nth > 0);
    assert(ith < nth);

    if (Ctype != GGML_TYPE_F32)
        return false;

    switch (Atype) {

    case GGML_TYPE_F32: {
        if (Btype != GGML_TYPE_F32)
            return false;
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        tinyBLAS<16, __m512, __m512, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__AVX__) || defined(__AVX2__)
        if (k % 8)
            return false;
        tinyBLAS<8, __m256, __m256, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON)
        if (n < 4)
            return false;
        if (k % 4)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
            k, (const float *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_F16: {
#if defined(__AVX512F__)
        if (k % 16)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
        if (n < 8)
            return false;
        if (k % 8)
            return false;
        if (Btype != GGML_TYPE_F16)
            return false;
        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const ggml_fp16_t *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
        if (k % 4)
            return false;
        if (Btype != GGML_TYPE_F32)
            return false;
        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
            k, (const ggml_fp16_t *)A, lda,
            (const float *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q8_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q8_0> tb{
            k, (const block_q8_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    case GGML_TYPE_Q4_0: {
        if (Btype != GGML_TYPE_Q8_0)
            return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#elif defined(__ARM_FEATURE_DOTPROD)
        tinyBLAS_Q0_ARM<block_q4_0> tb{
            k, (const block_q4_0 *)A, lda,
            (const block_q8_0 *)B, ldb,
            (float *)C, ldc,
            ith, nth};
        tb.matmul(m, n, task);
        return true;
#else
        return false;
#endif
    }

    default:
        return false;
    }

    (void)m;
    (void)n;
    (void)k;
    (void)A;
    (void)lda;
    (void)B;
    (void)ldb;
    (void)C;
    (void)ldc;
    (void)ith;
    (void)nth;
    (void)task;
    (void)Atype;
    (void)Btype;
    (void)Ctype;
}