Skip to content

Commit

Permalink
CUDA: fix illegal memory access
Browse files Browse the repository at this point in the history
  • Loading branch information
cebtenzzre committed Aug 3, 2023
1 parent da5f409 commit 2b61001
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3154,23 +3154,27 @@ static __device__ float rope_ntkv2_ramp(const float low, const float high, const
return 1.0f - min(1.0f, max(0.0f, y));
}

// Value-type wrapper around the four NTKv2 correction factors.
//
// A function parameter declared as `const float x[4]` is adjusted by the
// language to `const float *`, so only an address would be handed over.
// Wrapping the array in a struct makes the four floats travel by value.
//
// Layout as consumed by rope_ntkv2():
//   v[0], v[1] - low/high bounds of the ramp scaled by ntk_factor
//   v[2], v[3] - low/high bounds of the ramp scaled by ext_factor
struct rope_corr_factors {
    float v[4];
};

// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Blends three candidate rotation angles into a single angle for index i0.
//
//   theta_base    angle that receives the weight of the second blend
//   theta_linear  angle that receives the remaining weight of the first blend
//   theta_ntk     angle that receives the weight of the first blend
//   corr_factors  ramp bounds, passed by value: v[0]/v[1] drive the first
//                 blend, v[2]/v[3] drive the second
//   i0            index handed to rope_ntkv2_ramp() to evaluate each ramp
//   ntk_factor    multiplier applied to the first ramp value
//   ext_factor    multiplier applied to the second ramp value
//
// Returns the blended angle.
static __device__ float rope_ntkv2(
    const float theta_base,
    const float theta_linear,
    const float theta_ntk,
    const rope_corr_factors corr_factors,
    const int64_t i0,
    const float ntk_factor,
    const float ext_factor) {
    // First blend: interpolate between the linear and NTK angles.
    float ramp_mix = rope_ntkv2_ramp(corr_factors.v[0], corr_factors.v[1], i0) * ntk_factor;
    float theta    = theta_linear * (1.0f - ramp_mix) + theta_ntk * ramp_mix;

    // Second blend: interpolate the result above towards the base angle.
    ramp_mix = rope_ntkv2_ramp(corr_factors.v[2], corr_factors.v[3], i0) * ext_factor;
    theta    = theta * (1.0f - ramp_mix) + theta_base * ramp_mix;

    return theta;
}
Expand All @@ -3187,7 +3191,7 @@ static __global__ void rope_f32(
const float theta_ntk_scale,
const float p0,
const int p_delta_rows,
const float corr_factors[4]) {
const rope_corr_factors corr_factors) {
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);

if (col >= ncols) {
Expand Down Expand Up @@ -3817,7 +3821,7 @@ static void rope_f32_cuda(
const float theta_ntk_scale,
const float p0,
const int p_delta_rows,
const float corr_factors[4],
const rope_corr_factors corr_factors,
cudaStream_t stream) {
GGML_ASSERT(nrows % 2 == 0);
const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1);
Expand Down Expand Up @@ -4546,8 +4550,8 @@ inline void ggml_cuda_op_rope(
} else {
const float p0 = (mode & 1) == 0 ? n_past : 0;
const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims);
float corr_factors[4];
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors);
rope_corr_factors corr_factors;
ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors.v);

rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor, ext_factor, theta_scale,
theta_ntk_scale, p0, ne01, corr_factors, cudaStream_main);
Expand Down

0 comments on commit 2b61001

Please sign in to comment.