diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 320d61707e642a..3775e61cd05638 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3154,23 +3154,27 @@ static __device__ float rope_ntkv2_ramp(const float low, const float high, const return 1.0f - min(1.0f, max(0.0f, y)); } +struct rope_corr_factors { + float v[4]; +}; + // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static __device__ float rope_ntkv2( const float theta_base, const float theta_linear, const float theta_ntk, - const float corr_factors[4], + const rope_corr_factors corr_factors, const int64_t i0, const float ntk_factor, const float ext_factor) { float ramp_mix; float theta; - ramp_mix = rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor; + ramp_mix = rope_ntkv2_ramp(corr_factors.v[0], corr_factors.v[1], i0) * ntk_factor; theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix; - ramp_mix = rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * ext_factor; + ramp_mix = rope_ntkv2_ramp(corr_factors.v[2], corr_factors.v[3], i0) * ext_factor; theta = theta * (1 - ramp_mix) + theta_base * ramp_mix; return theta; } @@ -3187,7 +3191,7 @@ static __global__ void rope_f32( const float theta_ntk_scale, const float p0, const int p_delta_rows, - const float corr_factors[4]) { + const rope_corr_factors corr_factors) { const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); if (col >= ncols) { @@ -3817,7 +3821,7 @@ static void rope_f32_cuda( const float theta_ntk_scale, const float p0, const int p_delta_rows, - const float corr_factors[4], + const rope_corr_factors corr_factors, cudaStream_t stream) { GGML_ASSERT(nrows % 2 == 0); const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); @@ -4546,8 +4550,8 @@ inline void ggml_cuda_op_rope( } else { const float p0 = (mode & 1) == 0 ? n_past : 0; const float theta_ntk_scale = powf(freq_base * powf(freq_scale, (n_dims / (n_dims - 2.0f))), -2.0f/n_dims); - float corr_factors[4]; - ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors); + rope_corr_factors corr_factors; + ggml_rope_ntkv2_corr_factors(n_dims, freq_base, corr_factors.v); rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, freq_scale, ntk_factor, ext_factor, theta_scale, theta_ntk_scale, p0, ne01, corr_factors, cudaStream_main);