Skip to content

Commit

Permalink
[webgpu] Use workgroup_idx instead of workgroup_id.x (#23696)
Browse files Browse the repository at this point in the history
We should always use workgroup_idx instead of workgroup_id.x because
the dispatched workgroups may be normalized.

When the input is large enough, the 1d workgroups will be normalized to
2d/3d, and using workgroup_id.x then produces incorrect results.
  • Loading branch information
qjia7 authored Feb 14, 2025
1 parent 1cd7981 commit 4f66610
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
19 changes: 14 additions & 5 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,24 +535,32 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddOutput("output", ShaderUsage::UseUniform);
shader.AddOutput("scales", ShaderUsage::UseUniform);

shader.AdditionalImplementation() << R"ADDNL_FN(
fn readInput(offset: u32) -> input_a_value_t
{
if (offset > uniforms.input_size) {
return input_a_value_t(0);
}
return input_a[offset];
}
)ADDNL_FN";
shader.MainFunctionBody() << R"MAIN_FN(
var local_a : array<vec4<input_a_element_t>, 32>;
var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
for (var idx:u32=0;idx<32;idx+=1)
{
local_a[idx] = input_a[workgroup_id.x*32 + idx];
local_a[idx] = readInput(workgroup_idx*32 + idx);
max_value = max(max_value, abs(local_a[idx]));
}
var scale = max(max_value.x, max_value.y);
scale = max(scale, max_value.z);
scale = max(scale, max_value.w);
for (var idx:u32=0;idx<32;idx+=1)
{
output[workgroup_id.x*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
output[workgroup_idx*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
}
// 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
scales[workgroup_id.x] = scale/127;
scales[workgroup_idx] = scale/127;
)MAIN_FN";
return Status::OK();
}
Expand Down Expand Up @@ -828,7 +836,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)}})
.AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow<int>(1)},
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
{&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}})
.AddUniformVariable({static_cast<uint32_t>(M * K / kVec4Components)});
ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));

constexpr uint32_t kTileSize = 64;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram
public:
DP4AMatMulQuantizeProgram() : Program{"DP4AMatMulQuantize"} {}
Status GenerateShaderCode(ShaderHelper& sh) const override;
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32});
};

class DP4AMatMulNBitsProgram final : public Program<DP4AMatMulNBitsProgram> {
Expand Down

0 comments on commit 4f66610

Please sign in to comment.