
Commit

initial CUDA implementation
cebtenzzre committed Jul 19, 2023
1 parent b43bfe8 commit ce59171
Showing 1 changed file with 73 additions and 7 deletions.
80 changes: 73 additions & 7 deletions ggml-cuda.cu
@@ -1875,8 +1875,53 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
    cpy_1(cx + x_offset, cdst + dst_offset);
}

static __device__ void ntkv2_ramp(const float low, const float high, const int i0, float *out) {
    // Clamped linear ramp: 1 for i0/2 <= low, 0 for i0/2 >= high; the max()
    // guards against division by zero when high == low.
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    *out = 1.0f - min(1.0f, max(0.0f, y));
}

// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static __device__ void compute_ntkv2(
    float theta_base,
    float theta_ntk,
    float dims_over_base,
    float freq_scale,
    int64_t i0,
    float ntk_factor,
    float extrapolation_factor,
    int n_dims,
    float *theta) {
    // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
    // Do not change unless there is a good reason for doing so!
    // These are precomputed because CUDA doesn't allow dynamic init of device constants
    static const float low_1p = 2.6135630f;
    static const float high_1p = 2.7817991f;
    static const float low_2p = 1.5070765f;
    static const float high_2p = 2.5467973f;

    // start and end correction factors
    const float low_1 = max(0.0f, floorf(low_1p * dims_over_base));
    const float high_1 = min(n_dims - 1.0f, ceilf(high_1p * dims_over_base));
    const float low_2 = max(0.0f, floorf(low_2p * dims_over_base));
    const float high_2 = min(n_dims - 1.0f, ceilf(high_2p * dims_over_base));

    float ramp_mix;

    const float theta_linear = freq_scale * theta_base;
    ntkv2_ramp(low_1, high_1, i0, &ramp_mix);
    ramp_mix *= ntk_factor;
    const float theta_mix = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
    ntkv2_ramp(low_2, high_2, i0, &ramp_mix);
    ramp_mix *= extrapolation_factor;
    *theta = theta_mix * (1 - ramp_mix) + theta_base * ramp_mix;
}

// rope == RoPE == rotary positional embedding
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) {
static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int n_dims, const float freq_base,
        const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
        const float theta_ntk_scale, const float dims_over_base, const float p) {

    const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);

    if (col >= ncols) {
@@ -1886,7 +1931,11 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
    const int row = blockDim.y*blockIdx.y + threadIdx.y;
    const int i = row*ncols + col;

    const float theta = p*powf(theta_scale, col/2);
    const float theta_base = p*powf(theta_scale, col/2);
    const float theta_ntk = p*powf(theta_ntk_scale, col/2);
    float theta;
    compute_ntkv2(theta_base, theta_ntk, dims_over_base,
        freq_scale, col, ntk_factor, extrapolation_factor, n_dims, &theta);
    const float sin_theta = sinf(theta);
    const float cos_theta = cosf(theta);
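
The remainder of the kernel body is elided by the diff viewer; reconstructed from the surrounding ggml RoPE kernel (a sketch, not shown in this hunk), it applies the standard rotation to the column pair:

    const float x0 = x[i + 0];
    const float x1 = x[i + 1];

    dst[i + 0] = x0*cos_theta - x1*sin_theta;
    dst[i + 1] = x0*sin_theta + x1*cos_theta;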

@@ -2365,12 +2414,17 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
    scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}

static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) {
static void rope_f32_cuda(
    const float * x, float * dst, const int ncols, const int nrows, const int n_dims, const float freq_base,
    const float freq_scale, const float ntk_factor, const float extrapolation_factor, const float theta_scale,
    const float theta_ntk_scale, const float dims_over_base, const float p, cudaStream_t stream) {

    GGML_ASSERT(nrows % 2 == 0);
    const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
    const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(num_blocks_x, nrows, 1);
    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, n_dims, freq_base, freq_scale, ntk_factor,
        extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p);
}
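
The launch geometry is unchanged by this commit: each thread handles one column pair, blocks are 2*CUDA_ROPE_BLOCK_SIZE threads wide, and the grid's y dimension spans rows. A worked example of the arithmetic, assuming CUDA_ROPE_BLOCK_SIZE = 256 (its value elsewhere in ggml-cuda.cu) and hypothetical tensor sizes:

#include <cstdio>

#define CUDA_ROPE_BLOCK_SIZE 256   // value defined earlier in ggml-cuda.cu

int main() {
    const int ncols = 4096, nrows = 32;                       // hypothetical tensor shape
    const int block_x = 2*CUDA_ROPE_BLOCK_SIZE;               // 512 threads, one per column pair
    const int num_blocks_x = (ncols + block_x - 1) / block_x; // = 8
    std::printf("grid = (%d, %d, 1), block = (%d, 1, 1)\n", num_blocks_x, nrows, block_x);
    return 0;
}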

static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
@@ -2947,12 +3001,23 @@ inline void ggml_cuda_op_rope(
    const int64_t ne00 = src0->ne[0];
    const int64_t i01_diff = i01_high - i01_low;

    float freq_base;
    float freq_scale;
    float ntk_factor;
    float extrapolation_factor;

    const int n_past = ((int32_t *) src1->data)[0];
    const int n_dims = ((int32_t *) src1->data)[1];
    const int mode = ((int32_t *) src1->data)[2];
    const int n_ctx = ((int32_t *) src1->data)[3];

    const float theta_scale = powf(10000.0, -2.0f/n_dims);
    memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
    memcpy(&ntk_factor, (int32_t *) src1->data + 6, sizeof(float));
    memcpy(&extrapolation_factor, (int32_t *) src1->data + 7, sizeof(float));

    const float theta_scale = powf(freq_base, -2.0f/n_dims);
    const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
    const float dims_over_base = n_dims / logf(freq_base);
    const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
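
These derived constants are computed once per op on the host: theta_scale is the per-pair frequency decay for the original base, theta_ntk_scale the decay for the NTK-adjusted base, and dims_over_base feeds the correction ranges in compute_ntkv2. A worked example with hypothetical LLaMA-style values (n_dims = 128, freq_base = 10000, freq_scale = 0.25, i.e. 4x interpolation):

#include <math.h>
#include <stdio.h>

int main(void) {
    // Hypothetical values, not read from a real src1 tensor here.
    const float n_dims = 128.0f, freq_base = 10000.0f, freq_scale = 0.25f;
    const float theta_scale     = powf(freq_base, -2.0f/n_dims);                 // ~0.86596
    const float theta_ntk_scale = powf(freq_base * powf(freq_scale,
                                       n_dims / (n_dims - 2.0f)), -2.0f/n_dims); // ~0.88524
    const float dims_over_base  = n_dims / logf(freq_base);                      // ~13.897
    printf("theta_scale=%f theta_ntk_scale=%f dims_over_base=%f\n",
           theta_scale, theta_ntk_scale, dims_over_base);
    return 0;
}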

    bool is_glm = mode & 4;
@@ -2963,7 +3028,8 @@
        const float block_p = max(p - (n_ctx - 2.f), 0.f);
        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
    } else {
        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, n_dims, freq_base, freq_scale, ntk_factor,
            extrapolation_factor, theta_scale, theta_ntk_scale, dims_over_base, p, cudaStream_main);
    }

    (void) dst;
