// kompute-shaders/rope_common.comp
 1  #include "common.comp"
 2  
 3  // TODO: use a local size of 32 or more (Metal uses 1024)
 4  layout(local_size_x = 1) in;
 5  
 6  layout (push_constant) uniform parameter {
 7      uint inAOff;
 8      uint inBOff;
 9      uint outOff;
10      int n_dims;
11      int mode;
12      int n_ctx_orig;
13      float freq_base;
14      float freq_scale;
15      float ext_factor;
16      float attn_factor;
17      float beta_fast;
18      float beta_slow;
19      uint nb00;
20      uint nb01;
21      uint nb02;
22      uint nb03;
23      int ne0;
24      uint nb0;
25      uint nb1;
26      uint nb2;
27      uint nb3;
28  } pcs;
29  
// Descending linear ramp over [low, high], evaluated at i0 / 2.
//   low, high : start and end of the ramp
//   i0        : index whose half (i0 / 2) is positioned on the ramp
// Returns 1 when i0 / 2 is at or below `low`, 0 when it is at or above `high`
// (for a range at least 0.001 wide), and a linearly decreasing value in between.
float rope_yarn_ramp(const float low, const float high, const float i0) {
    // Floor the range width at 0.001 so the division stays finite when high <= low.
    const float width = max(0.001f, high - low);
    const float pos = (i0 / 2 - low) / width;
    // Clamp the normalized position to [0, 1] before inverting it.
    const float pos_clamped = min(1.0f, max(0.0f, pos));
    return 1.0f - pos_clamped;
}
34  
35  // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
36  // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
// Computes the scaled cosine and sine of the rotation angle for one dimension pair.
//   theta_extrap : unscaled (extrapolated) angle
//   freq_scale   : factor applied to theta_extrap to obtain the interpolated angle
//   corr_dims    : ramp bounds passed to rope_yarn_ramp() as (low, high)
//   i0           : index passed to rope_yarn_ramp()
//   ext_factor   : weight of the extrapolated angle; 0 disables the blend and the magnitude correction
//   mscale       : base magnitude applied to both outputs
//   cos_theta, sin_theta : outputs, cos/sin of the final angle times the final magnitude
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    const float theta_interp = freq_scale * theta_extrap;

    float angle = theta_interp;
    float magnitude = mscale;

    if (ext_factor != 0.0f) {
        // Blend interpolated and extrapolated angles; the ramp decides how much extrapolation applies at i0.
        const float blend = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        angle = theta_interp * (1 - blend) + theta_extrap * blend;

        // Magnitude correction that depends only on freq_scale.
        magnitude *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }

    cos_theta = cos(angle) * magnitude;
    sin_theta = sin(angle) * magnitude;
}
54  
// Correction factor for a given rotation count:
//   corr_fac(n_rot) = n_dims * log(n_ctx_orig / (n_rot * 2pi)) / (2 * log(base))
// Equivalently, the returned x satisfies base^(2 * x / n_dims) = n_ctx_orig / (n_rot * 2pi).
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    const float numer = n_dims * log(n_ctx_orig / (n_rot * TWOPI_F));
    const float denom = 2 * log(base);
    return numer / denom;
}
60  
// Computes the start and end correction dimensions.
//   dims[0] : floor of the correction factor for beta_fast, clamped to be at least 0
//   dims[1] : ceil of the correction factor for beta_slow, clamped to be at most n_dims - 1
void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    const float first = floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base));
    const float last  = ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base));

    dims[0] = max(0.0f, first);
    dims[1] = min(n_dims - 1.0f, last);
}