// shaders/linear_bias.wgsl
 // Resource bindings for the shader, all in bind group 0.
 // Runtime-sized arrays are only permitted in the storage address space, so
 // each buffer is declared as a storage variable with an explicit access mode.
 // NOTE(review): the host-side bind group layout must use matching buffer
 // types (read-only storage for bindings 0-2, storage for binding 3) — confirm.
 @group(0) @binding(0) var<storage, read> input: array<f32>;
 @group(0) @binding(1) var<storage, read> kernel: array<f32>;
 @group(0) @binding(2) var<storage, read> bias: array<f32>;
 // The entry point writes its results here, so this one needs read_write.
 @group(0) @binding(3) var<storage, read_write> output: array<f32>;

 // Computes one output element per invocation:
 //
 //   output[j] = bias[j] + sum over i of input[i] * kernel[i * output_len + j]
 //
 // where j is the x component of the global invocation id and i ranges over
 // the whole input array.
 //
 // NOTE(review): assumes kernel is stored row-major with input_len rows and
 // output_len columns, and that bias has at least output_len elements —
 // confirm against the host code that uploads these buffers.
 @compute @workgroup_size(64)
 fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
     let glob_i = global_id.x;

     // arrayLength takes a pointer to the runtime-sized array and yields a u32.
     let input_len = arrayLength(&input);
     let output_len = arrayLength(&output);

     // Invocations past the end of output (e.g. when the dispatch size is
     // rounded up to a multiple of the workgroup size) have nothing to compute.
     if (glob_i >= output_len) {
         return;
     }

     // Sum into a local variable and store once, so the result does not
     // depend on whatever output[glob_i] held before this dispatch.
     var acc: f32 = 0.0;
     // The loop counter is u32 so it matches input_len and the index math.
     for (var i = 0u; i < input_len; i++) {
         acc += input[i] * kernel[i * output_len + glob_i];
     }

     // The bias is applied exactly once per output element.
     output[glob_i] = acc + bias[glob_i];
 }