/ vulkan-shaders / mul_mat_vec_q4_k.comp
mul_mat_vec_q4_k.comp
  1  #version 450
  2  
  3  #include "mul_mat_vec_base.comp"
  4  
  5  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
  6  
  7  shared FLOAT_TYPE tmp[32];
  8  
  9  void main() {
 10      const uint row = gl_WorkGroupID.x;
 11  
 12      uint a_offset, b_offset, d_offset;
 13      get_offsets(a_offset, b_offset, d_offset);
 14  
 15      const uint num_blocks_per_row = p.ncols / QUANT_K;
 16      const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 17  
 18      const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
 19      const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
 20  
 21      const uint step = 8/K_QUANTS_PER_ITERATION;             // 8 or 4
 22  
 23      const uint il = tid/step;                               // 0...3
 24      const uint ir = tid - step*il;                          // 0...7 or 0...3
 25      const uint n =  2 * K_QUANTS_PER_ITERATION;             // 2 or 4
 26  
 27      const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
 28      const uint v_in = il % 2;
 29  
 30      const uint l0 = n * (2 * ir + v_in);            // 0...15
 31      const uint q_offset = 32*v_im + l0;
 32      const uint y_offset = 64*v_im + l0;
 33  
 34      tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
 35  
 36      [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 37          const uint y1_idx = i * QUANT_K + y_offset;
 38          const uint y2_idx = y1_idx + 128;
 39  
 40          const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
 41          const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
 42  
 43          const uint8_t sc0 = uint8_t(  data_a[ib0 + i].scales[v_im * 2    ]       & 0x3f);
 44          const uint8_t sc1 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 1]       & 0x3f);
 45          const uint8_t sc2 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 4]       & 0x3f);
 46          const uint8_t sc3 = uint8_t(  data_a[ib0 + i].scales[v_im * 2 + 5]       & 0x3f);
 47          const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2    ] & 0xc0) >> 2));
 48          const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9]       & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
 49          const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
 50          const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
 51  
 52  #if K_QUANTS_PER_ITERATION == 2
 53          const uint8_t q4_0  = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
 54          const uint8_t q4_1  = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
 55          const uint8_t q4_2  = uint8_t(data_a[ib0 + i].qs[q_offset +  2] & 0xf);
 56          const uint8_t q4_3  = uint8_t(data_a[ib0 + i].qs[q_offset +  3] & 0xf);
 57          const uint8_t q4_4  = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
 58          const uint8_t q4_5  = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
 59          const uint8_t q4_6  = uint8_t(data_a[ib0 + i].qs[q_offset +  2]  >> 4);
 60          const uint8_t q4_7  = uint8_t(data_a[ib0 + i].qs[q_offset +  3]  >> 4);
 61          const uint8_t q4_8  = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
 62          const uint8_t q4_9  = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
 63          const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
 64          const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
 65          const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
 66          const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
 67          const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66]  >> 4);
 68          const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67]  >> 4);
 69  
 70          const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1 + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3);
 71          const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_5 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7);
 72          const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx]) * q4_8 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_9 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * q4_10 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11);
 73          const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_12 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_13 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * q4_14 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15);
 74          const FLOAT_TYPE smin = FLOAT_TYPE(
 75              FLOAT_TYPE(data_b[b_offset + y1_idx    ]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx    ]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
 76            + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
 77            + FLOAT_TYPE(data_b[b_offset + y1_idx + 2]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 34]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 2]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 34]) * sc7
 78            + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
 79          );
 80          tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
 81  #else
 82          const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
 83          const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
 84          const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset     ]  >> 4);
 85          const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset +  1]  >> 4);
 86          const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
 87          const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
 88          const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64]  >> 4);
 89          const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65]  >> 4);
 90  
 91          const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * q4_0  + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
 92          const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2  + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
 93          const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * q4_4  + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
 94          const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
 95          const FLOAT_TYPE smin = FLOAT_TYPE(
 96              FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
 97            + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
 98          );
 99  
100          tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
101  #endif
102      }
103  
104      // sum up partial sums and write back result
105      barrier();
106      [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
107          if (tid < s) {
108              tmp[tid] += tmp[tid + s];
109          }
110          barrier();
111      }
112      if (tid == 0) {
113          data_d[d_offset + row] = D_TYPE(tmp[0]);
114      }
115  }