Add support for Qwen2VL #10361

Merged: 35 commits, Dec 14, 2024
Changes from 1 commit

Commits (35)
c17546f
Barebone Qwen2VL LLM convertor
HimariO Sep 21, 2024
7c6f793
Add Qwen2VL cli entrypoint
HimariO Sep 22, 2024
b24bd89
[WIP] add qwen2vl arch
HimariO Sep 25, 2024
3541196
Verify m-rope output
HimariO Sep 29, 2024
9d389a0
Add vl-rope/2d-rope support for qwen2vl ViT
HimariO Sep 30, 2024
f661483
update qwen2vl cli tool
HimariO Oct 1, 2024
3c3691e
update 5D tensor op workaround
HimariO Oct 2, 2024
c13edfe
[WIP] qwen2vl vision model
HimariO Oct 10, 2024
7e9fc72
make batch and clip utils compatible with qwen2vl
HimariO Oct 18, 2024
bcd49f5
[WIP] create inference workflow, gguf convert script but fix
HimariO Oct 18, 2024
023f007
correcting vision-rope behavior, add the missing last layer back to ViT
HimariO Oct 20, 2024
3d19dd4
add arg parser to qwen2vl_surgery
HimariO Oct 20, 2024
53480d2
replace variable size array with vector
HimariO Oct 21, 2024
0882f57
cuda-gdb cmake preset
HimariO Oct 27, 2024
3237bb4
add fp32 mrope, vision rope kernel
HimariO Oct 28, 2024
201f704
add fp16 support for qwen2vl and m-rope
HimariO Oct 30, 2024
f1fa60f
add `GGML_ROPE_TYPE_MROPE`, `GGML_ROPE_TYPE_VISION`
HimariO Oct 30, 2024
241bb45
fix rope op mode switching, out dated func args
HimariO Nov 4, 2024
07553cf
update `llama_hparams`
HimariO Nov 10, 2024
fac0345
update to keep up stream changes
HimariO Nov 11, 2024
cbd08b4
resolve linter, test errors
HimariO Nov 29, 2024
6c39aa3
add makefile entry, update speical image padding token
HimariO Dec 7, 2024
ac2089c
add mrope unit test, fix few compiler warnings
HimariO Dec 7, 2024
12f17f7
rename `mrope` related function, params
HimariO Dec 7, 2024
3ba7664
minor updates on debug util, bug fixs
HimariO Dec 9, 2024
b24ab86
add `m-rope` testcase to `test-backend-ops`
HimariO Dec 9, 2024
d7edc55
Merge branch 'master' into qwen2-vl
HimariO Dec 11, 2024
9abb252
Apply suggestions from code review
HimariO Dec 13, 2024
c292bf1
Merge branch 'ggerganov:master' into qwen2-vl
HimariO Dec 13, 2024
e9748e4
fix traililng whitespce
HimariO Dec 13, 2024
ef7f74b
store `llama_hparams.rope_sections` with fixed size array
HimariO Dec 13, 2024
e2e9a6c
update position id tensor size check in GGML_OP_ROPE
HimariO Dec 13, 2024
a02a190
minor updates
HimariO Dec 13, 2024
19aba1d
update `ggml_backend_*_supports_op` of unsupported backends
HimariO Dec 13, 2024
f96909e
remote old `rope_section` compare operator
HimariO Dec 14, 2024
add fp32 mrope, vision rope kernel
HimariO committed Nov 29, 2024

Verified: this commit was created on GitHub.com and signed with GitHub's verified signature (the key has since expired).
commit 3237bb46144646a19278d11918eee499f6046656
223 changes: 220 additions & 3 deletions ggml/src/ggml-cuda/rope.cu
@@ -4,6 +4,11 @@ struct rope_corr_dims {
float v[2];
};


struct mrope_sections {
int v[4];
};

static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
@@ -108,6 +113,114 @@ static __global__ void rope_neox(
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}

template<typename T, bool has_ff>
static __global__ void rope_mrope(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
return;
}

const int row = blockDim.x*blockIdx.x + threadIdx.x;

if (i0 >= n_dims) {
const int i = row*ne0 + i0;

dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];

return;
}

const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows;

int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
int sec_w = sections.v[1] + sections.v[0];
int sector = (i0 / 2) % sect_dims;

float theta_base = 0.0;
if (sector < sections.v[0]) {
theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne2 * 3]*powf(theta_scale, i0/2.0f);
}

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

float cos_theta;
float sin_theta;

rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];

dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
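
As a rough illustration of the section arithmetic above, the following host-side C++ sketch reproduces how each rotary pair (index i0/2) is assigned to one of the four position streams (temporal, height, width, extra); the section widths {16, 24, 24, 0} are assumed example values, not taken from this PR:

#include <cstdio>

// Illustrative sketch only: mirrors the sector selection in rope_mrope.
// sections[k] is the number of rotary pairs driven by position stream k.
static int mrope_stream_for_pair(int pair /* i0/2 */, const int sections[4]) {
    const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
    const int sec_w     = sections[0] + sections[1];
    const int sector    = pair % sect_dims;

    if (sector < sections[0])         return 0; // uses pos[i2 + ne2*0]
    if (sector < sec_w)               return 1; // uses pos[i2 + ne2*1]
    if (sector < sec_w + sections[2]) return 2; // uses pos[i2 + ne2*2]
    return 3;                                   // uses pos[i2 + ne2*3]
}

int main() {
    const int sections[4] = {16, 24, 24, 0}; // assumed example; sums to n_dims/2
    for (int pair = 0; pair < 64; pair += 8) {
        printf("pair %2d -> position stream %d\n", pair, mrope_stream_for_pair(pair, sections));
    }
    return 0;
}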

template<typename T, bool has_ff>
static __global__ void rope_vision(
const T * x, T * dst, int ne0, int ne2, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors, mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

if (i0 >= ne0) {
return;
}

const int row = blockDim.x*blockIdx.x + threadIdx.x;

// if (i0 >= n_dims) {
// const int i = row*ne0 + i0;

// dst[i + 0] = x[i + 0];
// dst[i + 1] = x[i + 1];

// return;
// }

const int i = row*ne0 + i0/2;
const int i2 = row/p_delta_rows; // i2-th tokens

int sect_dims = sections.v[0] + sections.v[1];
int sec_w = sections.v[1] + sections.v[0];
int sector = (i0 / 2) % sect_dims;

float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[i2]*powf(theta_scale, p);
}
else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[i2 + ne2]*powf(theta_scale, p);
}

const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

float cos_theta;
float sin_theta;

rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + n_dims];

dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims] = x0*sin_theta + x1*cos_theta;
}
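
One detail worth noting: rope_mrope rotates element pairs (i, i + n_dims/2), while rope_vision rotates (i, i + n_dims); combined with the GGML_ASSERT(n_dims == ne00/2) added further down, the vision kernel pairs the first and second halves of the full head dimension. A minimal sketch of the resulting index pairing for one row, using an assumed head dim of 16:

#include <cstdio>

// Illustrative sketch only: index pairing used by rope_vision for one row.
int main() {
    const int ne0    = 16;      // assumed head dim, illustration only
    const int n_dims = ne0 / 2; // vision path asserts n_dims == ne0/2

    for (int i0 = 0; i0 < ne0; i0 += 2) {
        const int i = i0 / 2;   // covers 0 .. n_dims-1 exactly once
        printf("rotate pair (x[%d], x[%d])\n", i, i + n_dims);
    }
    return 0;
}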

template<typename T>
static void rope_norm_cuda(
const T * x, T * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
@@ -156,6 +269,56 @@ static void rope_neox_cuda(
}
}

template<typename T>
static void rope_mrope_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);

const float theta_scale = powf(freq_base, -2.0f/n_dims);

if (freq_factors == nullptr) {
rope_mrope<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
} else {
rope_mrope<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
}
}

template<typename T>
static void rope_vision_cuda(
const T * x, T * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);

const float theta_scale = powf(freq_base, -2.0f/n_dims);

if (freq_factors == nullptr) {
rope_vision<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
} else {
rope_vision<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne2, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, freq_factors, sections
);
}
}
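
The launch geometry is shared by both launchers and matches the comment inside rope_vision_cuda: each block covers one row with CUDA_ROPE_BLOCK_SIZE threads along y, and the grid spans all rows times however many blocks are needed along the head dimension. A host-side sketch of that arithmetic, with an assumed tensor shape of (head_dim=128, heads=16, seq=512) and CUDA_ROPE_BLOCK_SIZE assumed to be 256:

#include <cstdio>

// Illustrative sketch only: grid/block shape computed by rope_*_cuda above.
int main() {
    const int CUDA_ROPE_BLOCK_SIZE = 256; // assumed value from ggml-cuda
    const int ne0 = 128;                  // head dim (assumed)
    const int nr  = 16 * 512;             // rows = heads * seq (assumed)

    // Each y-thread handles one pair of elements, so a block covers
    // 2*CUDA_ROPE_BLOCK_SIZE elements of the head dimension.
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);

    printf("block = (1, %d, 1), grid = (%d, %d, 1)\n", CUDA_ROPE_BLOCK_SIZE, nr, n_blocks_x);
    return 0;
}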

static void rope_norm_cuda_f16(
const half * x, half * dst, int ne0, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
@@ -185,6 +348,22 @@ static void rope_neox_cuda_f32(
rope_neox_cuda<float>(x, dst, ne0, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
}

static void rope_mrope_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {

rope_mrope_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}

static void rope_vision_cuda_f32(
const float * x, float * dst, int ne0, int ne2, int n_dims, int nr, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, mrope_sections sections, cudaStream_t stream
) {

rope_vision_cuda<float>(x, dst, ne0, ne2, n_dims, nr, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
}

void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -201,15 +380,18 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);

const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne00 = src0->ne[0]; // head dims
const int64_t ne01 = src0->ne[1]; // num heads
const int64_t ne02 = src0->ne[2]; // num heads
const int64_t nr = ggml_nrows(src0);

//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
//const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
// int sections[4];
mrope_sections sections;

// RoPE alteration for extended context
float freq_base;
@@ -225,8 +407,15 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);

const bool is_mrope = sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0;
const bool is_vision = is_mrope && sections.v[3] > 0;
const bool is_neox = (mode & GGML_ROPE_TYPE_NEOX) & !(is_mrope || is_vision); // TODO: fix this with new rope type

const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
if (is_vision) {
GGML_ASSERT(n_dims == ne00/2);
}
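
As an illustration of the dispatch flags introduced above, a small host-side sketch of how the section table classifies the rope variant; the section values below are assumptions for the example, GGML_ROPE_TYPE_NEOX is assumed to be 2 as in ggml.h, and the sketch uses a logical && where the diff keeps a bitwise & behind a TODO:

#include <cstdio>

// Illustrative sketch only: mirrors the flag derivation in ggml_cuda_op_rope.
constexpr int GGML_ROPE_TYPE_NEOX = 2; // assumed to match ggml.h

static void classify(int mode, const int s[4]) {
    const bool is_mrope  = s[0] > 0 || s[1] > 0 || s[2] > 0;
    const bool is_vision = is_mrope && s[3] > 0;
    const bool is_neox   = (mode & GGML_ROPE_TYPE_NEOX) && !(is_mrope || is_vision);
    printf("sections={%d,%d,%d,%d} -> neox=%d mrope=%d vision=%d\n",
           s[0], s[1], s[2], s[3], is_neox, is_mrope, is_vision);
}

int main() {
    const int none[4]   = {0, 0, 0, 0};    // plain / neox rope
    const int mrope[4]  = {16, 24, 24, 0}; // assumed text M-RoPE sections
    const int vision[4] = {16, 24, 24, 1}; // assumed: non-zero 4th entry selects the vision path
    classify(GGML_ROPE_TYPE_NEOX, none);
    classify(GGML_ROPE_TYPE_NEOX, mrope);
    classify(GGML_ROPE_TYPE_NEOX, vision);
    return 0;
}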

const int32_t * pos = (const int32_t *) src1_d;

@@ -253,6 +442,34 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
} else {
GGML_ABORT("fatal error");
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_mrope_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
} else if (src0->type == GGML_TYPE_F16 && false) {
// rope_mrope_cuda_f16(
// (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
// attn_factor, corr_dims, freq_factors, stream
// );
} else {
GGML_ABORT("fatal error");
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, ne02, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, freq_factors, sections, stream
);
} else if (src0->type == GGML_TYPE_F16 && false) {
// rope_vision_cuda_f16(
// (const half *)src0_d, (half *)dst_d, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
// attn_factor, corr_dims, freq_factors, stream
// );
} else {
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_norm_cuda_f32(