// vulkan-shaders/gelu.comp
 1  #version 450
 2  
 3  #include "generic_head.comp"
 4  #include "types.comp"
 5  
 6  #extension GL_EXT_control_flow_attributes : enable
 7  
 8  layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 9  
10  layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11  layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12  
13  void main() {
14      const float GELU_COEF_A    = 0.044715f;
15      const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
16      const uint i = gl_GlobalInvocationID.x;
17  
18      if (i >= p.KX) {
19          return;
20      }
21  
22      const float xi = float(data_a[i]);
23      const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
24      data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
25  }