Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[naga] Implement quantizeToF16 #6519

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Bottom level categories:
- Parse `diagnostic(…)` directives, but don't implement any triggering rules yet. By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456).
- Fix an issue where `naga` CLI would incorrectly skip the first positional argument when `--stdin-file-path` was specified. By @ErichDonGubler in [#6480](https://github.com/gfx-rs/wgpu/pull/6480).
- Fix textureNumLevels in the GLSL backend. By @magcius in [#6483](https://github.com/gfx-rs/wgpu/pull/6483).
- Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519).

#### General

Expand Down
45 changes: 44 additions & 1 deletion naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,8 @@ impl<'a, W: Write> Writer<'a, W> {
crate::MathFunction::Pack4xI8
| crate::MathFunction::Pack4xU8
| crate::MathFunction::Unpack4xI8
| crate::MathFunction::Unpack4xU8 => {
| crate::MathFunction::Unpack4xU8
| crate::MathFunction::QuantizeToF16 => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::ExtractBits => {
Expand Down Expand Up @@ -3495,6 +3496,48 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Inverse => "inverse",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => match *ctx.resolve_type(arg, &self.module.types) {
crate::TypeInner::Scalar { .. } => {
write!(self.out, "unpackHalf2x16(packHalf2x16(vec2(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))).x")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
..
} => {
write!(self.out, "unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, "))")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
..
} => {
write!(self.out, "vec3(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zz)).x)")?;
return Ok(());
}
crate::TypeInner::Vector {
size: crate::VectorSize::Quad,
..
} => {
write!(self.out, "vec4(unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".xy)), unpackHalf2x16(packHalf2x16(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ".zw)))")?;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
},
// bits
Mf::CountTrailingZeros => {
match *ctx.resolve_type(arg, &self.module.types) {
Expand Down
7 changes: 7 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3036,6 +3036,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack4x8unorm,
Unpack4xI8,
Unpack4xU8,
QuantizeToF16,
Regular(&'static str),
MissingIntOverload(&'static str),
MissingIntReturnType(&'static str),
Expand Down Expand Up @@ -3102,6 +3103,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
//Mf::Inverse =>,
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
Mf::QuantizeToF16 => Function::QuantizeToF16,
// bits
Mf::CountTrailingZeros => Function::CountTrailingZeros,
Mf::CountLeadingZeros => Function::CountLeadingZeros,
Expand Down Expand Up @@ -3303,6 +3305,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Function::QuantizeToF16 => {
write!(self.out, "f16tof32(f32tof16(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
}
Function::Regular(fun_name) => {
write!(self.out, "{fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
Expand Down
17 changes: 17 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1936,6 +1936,7 @@ impl<W: Write> Writer<W> {
Mf::Inverse => return Err(Error::UnsupportedCall(format!("{fun:?}"))),
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => "",
// bits
Mf::CountTrailingZeros => "ctz",
Mf::CountLeadingZeros => "clz",
Expand Down Expand Up @@ -2144,6 +2145,22 @@ impl<W: Write> Writer<W> {
self.put_expression(arg, context, true)?;
write!(self.out, " >> 24) << 24 >> 24")?;
}
Mf::QuantizeToF16 => {
match *context.resolve_type(arg) {
crate::TypeInner::Scalar { .. } => write!(self.out, "float(half(")?,
crate::TypeInner::Vector { size, .. } => write!(
self.out,
"{NAMESPACE}::float{size}({NAMESPACE}::half{size}(",
size = back::vector_size_str(size),
)?,
_ => unreachable!(
"Correct TypeInner for QuantizeToF16 should be already validated"
),
};

self.put_expression(arg, context, true)?;
write!(self.out, "))")?;
}
_ => {
write!(self.out, "{NAMESPACE}::{fun_name}")?;
self.put_call_parameters(
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1032,6 +1032,12 @@ impl<'w> BlockContext<'w> {
arg0_id,
)),
Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
Mf::QuantizeToF16 => MathOp::Custom(Instruction::unary(
spirv::Op::QuantizeToF16,
result_type_id,
id,
arg0_id,
)),
Mf::ReverseBits => MathOp::Custom(Instruction::unary(
spirv::Op::BitReverse,
result_type_id,
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1723,6 +1723,7 @@ impl<W: Write> Writer<W> {
Mf::InverseSqrt => Function::Regular("inverseSqrt"),
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
Mf::QuantizeToF16 => Function::Regular("quantizeToF16"),
// bits
Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
Expand Down
1 change: 1 addition & 0 deletions naga/src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"inverseSqrt" => Mf::InverseSqrt,
"transpose" => Mf::Transpose,
"determinant" => Mf::Determinant,
"quantizeToF16" => Mf::QuantizeToF16,
// bits
"countTrailingZeros" => Mf::CountTrailingZeros,
"countLeadingZeros" => Mf::CountLeadingZeros,
Expand Down
1 change: 1 addition & 0 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1199,6 +1199,7 @@ pub enum MathFunction {
Inverse,
Transpose,
Determinant,
QuantizeToF16,
// bits
CountTrailingZeros,
CountLeadingZeros,
Expand Down
1 change: 1 addition & 0 deletions naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,7 @@ impl super::MathFunction {
Self::Inverse => 1,
Self::Transpose => 1,
Self::Determinant => 1,
Self::QuantizeToF16 => 1,
// bits
Self::CountTrailingZeros => 1,
Self::CountLeadingZeros => 1,
Expand Down
3 changes: 2 additions & 1 deletion naga/src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,8 @@ impl<'a> ResolveContext<'a> {
| Mf::Exp2
| Mf::Log
| Mf::Log2
| Mf::Pow => res_arg.clone(),
| Mf::Pow
| Mf::QuantizeToF16 => res_arg.clone(),
Mf::Modf | Mf::Frexp => {
let (size, width) = match res_arg.inner_with(types) {
&Ti::Scalar(crate::Scalar {
Expand Down
20 changes: 20 additions & 0 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1363,6 +1363,26 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::QuantizeToF16 => {
if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() {
return Err(ExpressionError::WrongArgumentCount(fun));
}
match *arg_ty {
Ti::Scalar(Sc {
kind: Sk::Float,
width: 4,
})
| Ti::Vector {
scalar:
Sc {
kind: Sk::Float,
width: 4,
},
..
} => {}
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
// Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276
Mf::CountLeadingZeros
| Mf::CountTrailingZeros
Expand Down
4 changes: 4 additions & 0 deletions naga/tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,8 @@ fn main() {
let frexp_b = frexp(1.5).fract;
let frexp_c: i32 = frexp(1.5).exp;
let frexp_d: i32 = frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp.x;
let quantizeToF16_a: f32 = quantizeToF16(1.0);
let quantizeToF16_b: vec2<f32> = quantizeToF16(vec2(1.0, 1.0));
let quantizeToF16_c: vec3<f32> = quantizeToF16(vec3(1.0, 1.0, 1.0));
let quantizeToF16_d: vec4<f32> = quantizeToF16(vec4(1.0, 1.0, 1.0, 1.0));
}
7 changes: 7 additions & 0 deletions naga/tests/out/glsl/math-functions.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,12 @@ void main() {
float frexp_b = naga_frexp(1.5).fract_;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(vec4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = unpackHalf2x16(packHalf2x16(vec2(1.0))).x;
vec2 _e120 = vec2(1.0, 1.0);
vec2 quantizeToF16_b = unpackHalf2x16(packHalf2x16(_e120));
vec3 _e125 = vec3(1.0, 1.0, 1.0);
vec3 quantizeToF16_c = vec3(unpackHalf2x16(packHalf2x16(_e125.xy)), unpackHalf2x16(packHalf2x16(_e125.zz)).x);
vec4 _e131 = vec4(1.0, 1.0, 1.0, 1.0);
vec4 quantizeToF16_d = vec4(unpackHalf2x16(packHalf2x16(_e131.xy)), unpackHalf2x16(packHalf2x16(_e131.zw)));
}

4 changes: 4 additions & 0 deletions naga/tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,8 @@ void main()
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp_;
int frexp_d = naga_frexp(float4(1.5, 1.5, 1.5, 1.5)).exp_.x;
float quantizeToF16_a = f16tof32(f32tof16(1.0));
float2 quantizeToF16_b = f16tof32(f32tof16(float2(1.0, 1.0)));
float3 quantizeToF16_c = f16tof32(f32tof16(float3(1.0, 1.0, 1.0)));
float4 quantizeToF16_d = f16tof32(f32tof16(float4(1.0, 1.0, 1.0, 1.0)));
}
4 changes: 4 additions & 0 deletions naga/tests/out/msl/math-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,8 @@ fragment void main_(
float frexp_b = naga_frexp(1.5).fract;
int frexp_c = naga_frexp(1.5).exp;
int frexp_d = naga_frexp(metal::float4(1.5, 1.5, 1.5, 1.5)).exp.x;
float quantizeToF16_a = float(half(1.0));
metal::float2 quantizeToF16_b = metal::float2(metal::half2(metal::float2(1.0, 1.0)));
metal::float3 quantizeToF16_c = metal::float3(metal::half3(metal::float3(1.0, 1.0, 1.0)));
metal::float4 quantizeToF16_d = metal::float4(metal::half4(metal::float4(1.0, 1.0, 1.0, 1.0)));
}
Loading