/ kompute-shaders / op_getrows_f32.comp
op_getrows_f32.comp
#version 450

#include "common.comp"

// Row gather for float data.
// Workgroup i reads a row index from inB and copies pcs.ne00 consecutive
// floats from inA into out_. The workgroup size is 1, so each workgroup
// performs its whole row copy in a single invocation.
layout(local_size_x = 1) in;

layout (binding = 0) readonly buffer tensorInA { float inA[]; };   // source values
layout (binding = 1) readonly buffer tensorInB { int inB[]; };     // row indices
layout (binding = 2) writeonly buffer tensorOut { float out_[]; }; // destination values

layout (push_constant) uniform parameter {
    uint inAOff;  // added to every index into inA
    uint inBOff;  // added to every index into inB
    uint outOff;  // added to every index into out_
    int ne00;     // number of elements copied per row
    int nb01;     // source row stride; divided by 4 before use
    int nb1;      // destination row stride; divided by 4 before use
    // NOTE(review): the /4 suggests nb01 and nb1 are byte strides of
    // 4-byte floats -- confirm against the host-side dispatch code.
} pcs;

// Copies k floats from inA, starting at element x, to out_, starting at
// element y.
//   x - start index into inA (based from inA, unaligned)
//   y - start index into out_ (based from out_)
//   k - element count; nothing is copied when k <= 0
void dequantize_row_f32(uint x, uint y, int k) {
    int j = 0;
    while (j < k) {
        out_[y + j] = inA[x + j];
        j++;
    }
}

void main() {
    // One workgroup per output row.
    const uint i = gl_WorkGroupID.x;

    // Row of inA selected for this output row.
    const int r = inB[i + pcs.inBOff];

    // The source offset is evaluated in signed int (r and nb01 are both int)
    // before the unsigned base offset is added; the destination offset is
    // evaluated in uint because i is uint.
    const uint src = r*pcs.nb01/4 + pcs.inAOff;
    const uint dst = i*pcs.nb1/4 + pcs.outOff;

    dequantize_row_f32(src, dst, pcs.ne00);
}