Skip to content

Commit

Permalink
vulkan: fix group_norm (llama/10496)
Browse files Browse the repository at this point in the history
Fix bad calculation of the end of the range. Add a backend test that
covers the bad case (taken from stable diffusion).

Fixes leejet/stable-diffusion.cpp#439.
  • Loading branch information
jeffbolznv authored and ggerganov committed Dec 3, 2024
1 parent a1d5c10 commit 965432c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7157,7 +7157,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
const int32_t max_period = tensor->op_params[1];
tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
} else if (tensor->op == GGML_OP_POOL_2D) {
enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
const int32_t k0 = tensor->op_params[1];
const int32_t k1 = tensor->op_params[2];
const int32_t s0 = tensor->op_params[3];
Expand Down
2 changes: 1 addition & 1 deletion src/ggml-vulkan/vulkan-shaders/group_norm.comp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ void main() {

const uint tid = gl_LocalInvocationID.x;
const uint start = gl_WorkGroupID.x * group_size + tid;
const uint end = start + group_size;
const uint end = (gl_WorkGroupID.x + 1) * group_size;

tmp[tid] = 0.0f;

Expand Down
3 changes: 2 additions & 1 deletion tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3774,7 +3774,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_upscale());
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true));
test_cases.emplace_back(new test_upscale_ext());
test_cases.emplace_back(new test_group_norm());
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_acc());
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_arange());
Expand Down

0 comments on commit 965432c

Please sign in to comment.