From 4f666108535c1fc31af8ae502c393c9a76822523 Mon Sep 17 00:00:00 2001
From: Jiajia Qin
Date: Sat, 15 Feb 2025 06:41:49 +0800
Subject: [PATCH] [webgpu] Use workgroup_idx instead of workgroup_id.x (#23696)

We should always use workgroup_idx instead of workgroup_id.x because the
dispatched workgroups may be normalized. When the input is large enough,
the 1D workgroup count is normalized to a 2D/3D grid, and indexing with
workgroup_id.x alone produces incorrect results.
---
 .../webgpu/quantization/matmul_nbits.cc | 19 ++++++++++++++-----
 .../webgpu/quantization/matmul_nbits.h  |  1 +
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
index 9f898dfc9ab04..3b566d37fa979 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc
@@ -535,13 +535,21 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
   shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
   shader.AddOutput("output", ShaderUsage::UseUniform);
   shader.AddOutput("scales", ShaderUsage::UseUniform);
-
+  shader.AdditionalImplementation() << R"ADDNL_FN(
+  fn readInput(offset: u32) -> input_a_value_t
+  {
+    if (offset > uniforms.input_size) {
+      return input_a_value_t(0);
+    }
+    return input_a[offset];
+  }
+)ADDNL_FN";
   shader.MainFunctionBody() << R"MAIN_FN(
   var local_a : array<vec4<input_a_element_t>, 32>;
   var max_value:vec4<input_a_element_t> = vec4<input_a_element_t>(0);
   for (var idx:u32=0;idx<32;idx+=1)
   {
-    local_a[idx] = input_a[workgroup_id.x*32 + idx];
+    local_a[idx] = readInput(workgroup_idx*32 + idx);
     max_value = max(max_value, abs(local_a[idx]));
   }
   var scale = max(max_value.x, max_value.y);
@@ -549,10 +557,10 @@ Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const
   scale = max(scale, max_value.w);
   for (var idx:u32=0;idx<32;idx+=1)
   {
-    output[workgroup_id.x*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
+    output[workgroup_idx*32+idx] = pack4x8snorm(vec4<f32>(local_a[idx]/scale));
   }
   // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f.
-  scales[workgroup_id.x] = scale/127;
+  scales[workgroup_idx] = scale/127;
 )MAIN_FN";
   return Status::OK();
 }
@@ -828,7 +836,8 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
   Tensor a_scale = context.CreateGPUTensor(a->DataType(), a_scales_dims);
   quantize_program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)}})
       .AddOutputs({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape(), gsl::narrow<int>(1)},
-                   {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}});
+                   {&a_scale, ProgramTensorMetadataDependency::Rank, a_scale.Shape(), gsl::narrow<int>(1)}})
+      .AddUniformVariable({static_cast<uint32_t>(M * K / kVec4Components)});
   ORT_RETURN_IF_ERROR(context.RunProgram(quantize_program));
 
   constexpr uint32_t kTileSize = 64;
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
index a2470d9268907..3d72629bf6b25 100644
--- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
+++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.h
@@ -39,6 +39,7 @@ class DP4AMatMulQuantizeProgram final : public Program<DP4AMatMulQuantizeProgram> {
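
Background on workgroup_idx vs. workgroup_id.x: WebGPU caps each dispatch
dimension (65535 workgroups per dimension by default), so the WebGPU EP
normalizes a large 1D dispatch into a 2D/3D grid, after which
workgroup_id.x alone no longer equals the logical linear index. The WGSL
sketch below shows the usual row-major linearization that workgroup_idx
stands for; it is illustrative only, and the exact expression ORT's
ShaderHelper emits may differ:

    @compute @workgroup_size(64)
    fn example(@builtin(workgroup_id) workgroup_id : vec3<u32>,
               @builtin(num_workgroups) num_workgroups : vec3<u32>) {
      // Recover the linear index from the normalized 2D/3D grid
      // (assumed row-major flattening; illustrative, not ORT's exact code).
      let workgroup_idx = workgroup_id.z * num_workgroups.x * num_workgroups.y
                        + workgroup_id.y * num_workgroups.x
                        + workgroup_id.x;
      // workgroup_id.x equals workgroup_idx only while num_workgroups.y == 1
      // and num_workgroups.z == 1, i.e. for small dispatches.
      _ = workgroup_idx;
    }

This also explains the readInput guard added in the patch: a normalized
grid can contain more workgroups than requested (x*y*z rounds up), so the
trailing workgroups must clamp their reads against uniforms.input_size
instead of reading past the end of input_a.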