// kompute-shaders/rope_common.comp
#include "common.comp"

// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;

// Push constants shared by the shaders that include this file.
// NOTE(review): the nb* members look like strides and the *Off members like
// buffer offsets, but their units are not visible here -- confirm against the
// host-side dispatch code before relying on that.
layout (push_constant) uniform parameter {
    uint  inAOff;
    uint  inBOff;
    uint  outOff;
    int   n_dims;
    int   mode;
    int   n_ctx_orig;
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;
    uint  nb00;
    uint  nb01;
    uint  nb02;
    uint  nb03;
    int   ne0;
    uint  nb0;
    uint  nb1;
    uint  nb2;
    uint  nb3;
} pcs;

// Linear ramp over i0 / 2: returns 1 while i0 / 2 is at or below `low`, falls
// linearly, and reaches 0 once i0 / 2 is a full span above `low`.
// The span is (high - low), floored at 0.001 so the division stays finite.
float rope_yarn_ramp(const float low, const float high, const float i0) {
    const float span = max(0.001f, high - low);
    const float y    = (i0 / 2 - low) / span;
    return 1.0f - clamp(y, 0.0f, 1.0f);
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Writes the scaled cosine and sine of the rotation angle to cos_theta and
// sin_theta. With ext_factor == 0 the angle is simply freq_scale * theta_extrap
// and the magnitude is mscale. Otherwise the angle is a blend of the
// interpolated and extrapolated angles (weight = ramp * ext_factor) and the
// magnitude picks up an extra factor of 1 + 0.1 * log(1 / freq_scale).
void rope_yarn(
    float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
    out float cos_theta, out float sin_theta
) {
    // Get n-d rotational scaling corrected for extrapolation
    const float theta_interp = freq_scale * theta_extrap;

    float angle     = theta_interp;
    float magnitude = mscale;

    if (ext_factor != 0.0f) {
        const float blend = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        angle = theta_interp * (1.0f - blend) + theta_extrap * blend;

        // Get n-d magnitude scaling corrected for interpolation
        magnitude *= 1.0f + 0.1f * log(1.0f / freq_scale);
    }

    cos_theta = cos(angle) * magnitude;
    sin_theta = sin(angle) * magnitude;
}

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    const float numerator   = n_dims * log(n_ctx_orig / (n_rot * TWOPI_F));
    const float denominator = 2 * log(base);
    return numerator / denominator;
}

// Computes the start (dims[0]) and end (dims[1]) correction dims.
// The start is the floored correction factor for beta_fast, kept at or above 0;
// the end is the ceiled correction factor for beta_slow, kept at or below
// n_dims - 1.
void rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
    const float first = floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base));
    const float last  = ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base));

    dims[0] = max(0.0f, first);
    dims[1] = min(n_dims - 1.0f, last);
}