vulkan : implement YaRN RoPE scaling (ggerganov#2268)
The NeoX cur_rot part is different because I'm pretty sure my original
implementation was wrong.
cebtenzzre committed Nov 23, 2023
1 parent 1829f1d commit 208cd52
Showing 5 changed files with 123 additions and 69 deletions.
36 changes: 23 additions & 13 deletions ggml-vulkan.cpp
@@ -1195,8 +1195,8 @@ void ggml_vk_rope(
const std::shared_ptr<kp::Tensor>& inB,
const std::shared_ptr<kp::Tensor>& out,
uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
ggml_type src0t, int32_t n_dims, int32_t mode,
float freq_base, float freq_scale,
ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_orig_ctx,
float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
int32_t ne01, int32_t ne02, int32_t ne03,
uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
int32_t ne0,
@@ -1224,15 +1224,15 @@ void ggml_vk_rope(

struct PushConstants {
uint32_t inAOff, inBOff, outOff;
int32_t n_dims, mode;
float freq_base, freq_scale;
int32_t n_dims, mode, n_orig_ctx;
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
uint32_t nb00, nb01, nb02, nb03;
int32_t ne0;
uint32_t nb0, nb1, nb2, nb3;
} pushConsts {
safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
n_dims, mode,
freq_base, freq_scale,
n_dims, mode, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
nb00, nb01, nb02, nb03,
ne0,
nb0, nb1, nb2, nb3
@@ -1545,13 +1545,23 @@ void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph
GGML_ASSERT(ne10 == ne02);
GGML_ASSERT(src0t == dstt);
// const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
float freq_base;
float freq_scale;
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
ggml_vk_rope(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, freq_base, freq_scale, ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3);
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
// skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
ggml_vk_rope(
seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
);
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
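The host-side PushConstants struct in ggml_vk_rope has to stay field-for-field in sync with the push_constant block declared in rope_common.comp further down, because the struct is copied into the shader's push-constant range as raw bytes. A minimal sketch of that correspondence; the RopePushConstants name and the static_assert are illustrative, not from this commit, which nests an anonymous PushConstants struct inside ggml_vk_rope instead:

#include <cstdint>

// Host-side mirror of the push_constant block in rope_common.comp.
// Field order and types must match the GLSL declaration exactly.
struct RopePushConstants {
    uint32_t inAOff, inBOff, outOff;                          // uint  inAOff, inBOff, outOff
    int32_t  n_dims, mode, n_orig_ctx;                        // int   n_dims, mode, n_orig_ctx
    float    freq_base, freq_scale, ext_factor, attn_factor;  // float freq_base .. attn_factor
    float    beta_fast, beta_slow;                            // float beta_fast, beta_slow
    uint32_t nb00, nb01, nb02, nb03;                          // uint  nb00 .. nb03
    int32_t  ne0;                                             // int   ne0
    uint32_t nb0, nb1, nb2, nb3;                              // uint  nb0 .. nb3
};

// All 21 members are 4-byte scalars, so the struct is tightly packed with no padding.
static_assert(sizeof(RopePushConstants) == 21 * sizeof(uint32_t),
              "push-constant layout out of sync with rope_common.comp");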
1 change: 1 addition & 0 deletions kompute/common.comp
@@ -20,6 +20,7 @@

#define GELU_COEF_A 0.044715
#define SQRT_2_OVER_PI 0.79788456080286535587989211986876
#define TWOPI_F 6.283185307179586f

#define QK_K 256

40 changes: 12 additions & 28 deletions kompute/op_rope_f16.comp
@@ -8,50 +8,32 @@

#version 450

#include "common.comp"

// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;
#include "rope_common.comp"

layout(binding = 0) buffer restrict readonly tensorInA { float16_t inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };

layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int n_dims;
int mode;
float freq_base;
float freq_scale;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;

void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;

const bool is_neox = (pcs.mode & 2) != 0;

float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);

const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);

const int p = inB[pcs.inBOff + i2];

float theta = pcs.freq_scale * float(p);
float theta = float(p);

if (!is_neox) {
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

theta *= theta_scale;

@@ -68,8 +50,10 @@ void main() {
const float inv_ndims = -1.f/pcs.n_dims;
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
const uint cur_rot = ib * pcs.n_dims + ic;

float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

theta *= theta_scale;

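The NeoX branch is where the cur_rot note in the commit message applies: the index fed into the YaRN ramp is now computed per rotated pair as ib * pcs.n_dims + ic, instead of reusing the running position. A small standalone sketch of that index pattern (illustrative C++ only, with example sizes ne0 = 8 and n_dims = 4 that are not taken from the commit):

#include <cstdio>

// Prints the cur_rot value the NeoX branch passes to rope_yarn() for each
// rotated pair, using made-up example sizes.
int main() {
    const int ne0 = 8, n_dims = 4;
    for (int ib = 0; ib < ne0 / n_dims; ++ib) {
        for (int ic = 0; ic < n_dims; ic += 2) {
            const int cur_rot = ib * n_dims + ic;  // index driving the YaRN ramp
            std::printf("ib=%d ic=%d -> cur_rot=%d\n", ib, ic, cur_rot);
        }
    }
    return 0;
}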
40 changes: 12 additions & 28 deletions kompute/op_rope_f32.comp
@@ -8,50 +8,32 @@

#version 450

#include "common.comp"

// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;
#include "rope_common.comp"

layout(binding = 0) buffer restrict readonly tensorInA { float inA[]; };
layout(binding = 1) buffer restrict readonly tensorInB { int inB[]; };
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };

layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int n_dims;
int mode;
float freq_base;
float freq_scale;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;

void main() {
const uint i3 = gl_WorkGroupID.z;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_WorkGroupID.x;

const bool is_neox = (pcs.mode & 2) != 0;

float corr_dims[2];
rope_yarn_corr_dims(pcs.n_dims, pcs.n_orig_ctx, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);

const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);

const int p = inB[pcs.inBOff + i2];

float theta = pcs.freq_scale * float(p);
float theta = float(p);

if (!is_neox) {
for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

theta *= theta_scale;

@@ -68,8 +50,10 @@ void main() {
const float inv_ndims = -1.f/pcs.n_dims;
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);
const uint cur_rot = ib * pcs.n_dims + ic;

float cos_theta, sin_theta;
rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);

theta *= theta_scale;

75 changes: 75 additions & 0 deletions kompute/rope_common.comp
@@ -0,0 +1,75 @@
/**
* Copyright (c) 2023 Nomic, Inc. All rights reserved.
*
* This software is licensed under the terms of the Software for Open Models License (SOM),
* version 1.0, as detailed in the LICENSE_SOM.txt file. A copy of this license should accompany
* this software. Except as expressly granted in the SOM license, all rights are reserved by Nomic, Inc.
*/

#include "common.comp"

// TODO: use a local size of 32 or more (Metal uses 1024)
layout(local_size_x = 1) in;

layout (push_constant) uniform parameter {
uint inAOff;
uint inBOff;
uint outOff;
int n_dims;
int mode;
int n_orig_ctx;
float freq_base;
float freq_scale;
float ext_factor;
float attn_factor;
float beta_fast;
float beta_slow;
uint nb00;
uint nb01;
uint nb02;
uint nb03;
int ne0;
uint nb0;
uint nb1;
uint nb2;
uint nb3;
} pcs;

float rope_yarn_ramp(const float low, const float high, const float i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}

// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
void rope_yarn(
float theta_extrap, float freq_scale, float corr_dims[2], float i0, float ext_factor, float mscale,
out float cos_theta, out float sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
}
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}

// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) {
return n_dims * log(n_orig_ctx / (n_rot * TWOPI_F)) / (2 * log(base));
}

void rope_yarn_corr_dims(
int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, out float dims[2]
) {
// start and end correction dims
dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base)));
dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
}
