/ vulkan-shaders / rope_norm.comp
rope_norm.comp
 1  #version 450
 2  
 3  #include "rope_head.comp"
 4  
 5  void main() {
 6      const uint col = gl_GlobalInvocationID.y * 2;
 7      const uint row = gl_GlobalInvocationID.x;
 8  
 9      if (col >= p.ncols) {
10          return;
11      }
12  
13      if (col >= p.n_dims) {
14          const uint i = row*p.ncols + col;
15  
16          data_d[i + 0] = data_a[i + 0];
17          data_d[i + 1] = data_a[i + 1];
18  
19          return;
20      }
21  
22      const uint i = row*p.ncols + col;
23      const uint i2 = row/p.p_delta_rows;
24  
25      const float theta_base = data_pos[i2] * pow(p.theta_scale, col/2.0f);
26  
27      const float freq_factor = p.has_ff != 0 ? data_ff[col/2] : 1.0f;
28  
29      float cos_theta, sin_theta;
30      rope_yarn(theta_base / freq_factor, col, cos_theta, sin_theta);
31  
32      const float x0 = float(data_a[i + 0]);
33      const float x1 = float(data_a[i + 1]);
34  
35      data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
36      data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
37  }