/ vulkan-shaders / mul_mat_vec_q6_k.comp
mul_mat_vec_q6_k.comp
 1  #version 450
 2  
 3  #include "mul_mat_vec_base.comp"
 4  
 5  layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
 6  
 7  shared FLOAT_TYPE tmp[32];
 8  
 9  void main() {
10      const uint row = gl_WorkGroupID.x;
11  
12      uint a_offset, b_offset, d_offset;
13      get_offsets(a_offset, b_offset, d_offset);
14  
15      const uint num_blocks_per_row = p.ncols / QUANT_K;
16      const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
17  
18      const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
19      const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
20  
21      const uint step = 16/K_QUANTS_PER_ITERATION;            // 16 or 8
22  
23      const uint v_im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
24      const uint v_in = tid - step*v_im;                      // 0...15 or 0...7
25  
26  #if K_QUANTS_PER_ITERATION == 1
27      const uint l0 = v_in;                                   // 0...15
28      const uint is = 0;
29  #else
30      const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28
31      const uint is = v_in / 4;
32  #endif
33  
34      const uint ql_offset = 64*v_im + l0;
35      const uint qh_offset = 32*v_im + l0;
36      const uint s_offset  =  8*v_im + is;
37      const uint y_offset = 128*v_im + l0;
38  
39      tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
40  
41      [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
42          const uint y_idx   = i * QUANT_K + y_offset;
43  
44          const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
45  
46  #if K_QUANTS_PER_ITERATION == 1
47          FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32)
48                         + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
49                         + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32)
50                         + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
51                         + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32)
52                         + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
53                         + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32)
54                         + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
55          tmp[16 * ix + tid] += sum;
56  #else
57          FLOAT_TYPE sum = FLOAT_TYPE(0.0);
58          [[unroll]] for (int l = 0; l < 4; ++l) {
59              sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
60                   + FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
61                   + FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
62                   + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
63          }
64          tmp[16 * ix + tid] += sum;
65  #endif
66      }
67  
68      // sum up partial sums and write back result
69      barrier();
70      [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
71          if (tid < s) {
72              tmp[tid] += tmp[tid + s];
73         }
74          barrier();
75      }
76      if (tid == 0) {
77          data_d[d_offset + row] = D_TYPE(tmp[0]);
78      }
79  }