// vulkan-shaders/mul_mat_vec_nc.comp
 1  #version 450
 2  
 3  #extension GL_EXT_control_flow_attributes : enable
 4  #extension GL_EXT_shader_16bit_storage : require
 5  
 6  #define BLOCK_SIZE 32
 7  #define FLOAT_TYPE float
 8  
 9  layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
10  
11  layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
12  layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
13  layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
14  
15  layout (push_constant) uniform parameter
16  {
17      uint ncols_x;
18      uint nrows_x;
19      uint row_stride_x;
20      uint channel_stride_x;
21      uint channel_x_divisor;
22      uint b_offset;
23      uint d_offset;
24  } p;
25  
26  shared FLOAT_TYPE tmp[BLOCK_SIZE];
27  
// Computes one output element per workgroup. Because local_size_y and
// local_size_z are 1, all BLOCK_SIZE invocations of a workgroup share the same
// (row_x, channel) and cooperatively form
//   dst[channel][row_x] = sum over col of A(channel_x, row_x, col) * B(channel, col)
// where A is addressed through row/channel strides and several consecutive B
// channels map onto one A channel (channel_x = channel / channel_x_divisor).
// The partial sums are reduced in shared memory and invocation 0 stores the result.
void main() {
    const uint tid       = gl_LocalInvocationID.x;
    const uint row_x     = gl_GlobalInvocationID.y;
    const uint channel   = gl_GlobalInvocationID.z;
    const uint channel_x = channel / p.channel_x_divisor;

    // B stores one contiguous vector of ncols_x elements per channel;
    // dst stores nrows_x results per channel.
    const uint nrows_y   = p.ncols_x;
    const uint nrows_dst = p.nrows_x;

    // Loop-invariant base indices. b_offset and d_offset come from the push
    // constants and are applied as element offsets into B and D.
    // NOTE(review): assumes the host passes them in elements, not bytes — confirm
    // against the dispatch code.
    const uint ix0  = channel_x*p.channel_stride_x + row_x*p.row_stride_x;
    const uint iy0  = p.b_offset + channel*nrows_y;
    const uint idst = p.d_offset + channel*nrows_dst + row_x;

    // Strided partial dot product kept in a register: invocation tid handles
    // columns tid, tid + BLOCK_SIZE, ... Shared memory is written once afterwards
    // instead of on every iteration; the summation order is unchanged.
    FLOAT_TYPE partial = FLOAT_TYPE(0.0f);
    for (uint col_x = tid; col_x < p.ncols_x; col_x += BLOCK_SIZE) {
        partial += FLOAT_TYPE(data_a[ix0 + col_x]) * FLOAT_TYPE(data_b[iy0 + col_x]);
    }
    tmp[tid] = partial;

    // Tree reduction of the BLOCK_SIZE partial sums. The loop bounds are
    // compile-time constants, so every invocation reaches every barrier().
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        barrier();
    }

    if (tid == 0) {
        dst[idst] = D_TYPE(tmp[0]);
    }
}