From 41fddb53fe2d2ecd1e37852dc0d24aeece515c89 Mon Sep 17 00:00:00 2001 From: Ben Kimock Date: Tue, 6 Feb 2024 14:32:00 -0500 Subject: [PATCH] Add "algebraic" versions of the fast-math intrinsics --- .../src/intrinsics/mod.rs | 21 +++++--- compiler/rustc_codegen_gcc/src/builder.rs | 25 ++++++++++ compiler/rustc_codegen_llvm/src/builder.rs | 48 +++++++++++++++++-- compiler/rustc_codegen_llvm/src/intrinsic.rs | 4 +- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 1 + .../rustc_codegen_ssa/src/mir/intrinsic.rs | 32 +++++++++++++ .../rustc_codegen_ssa/src/traits/builder.rs | 5 ++ .../rustc_hir_analysis/src/check/intrinsic.rs | 12 ++++- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 25 +++++++++- compiler/rustc_span/src/symbol.rs | 5 ++ library/core/src/intrinsics.rs | 40 ++++++++++++++++ tests/codegen/simd/issue-120720-reduce-nan.rs | 21 ++++++++ 12 files changed, 225 insertions(+), 14 deletions(-) create mode 100644 tests/codegen/simd/issue-120720-reduce-nan.rs diff --git a/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs b/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs index 476752c7230a7..199d5df29e7d0 100644 --- a/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs +++ b/compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs @@ -1152,17 +1152,26 @@ fn codegen_regular_intrinsic_call<'tcx>( ret.write_cvalue(fx, ret_val); } - sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => { + sym::fadd_fast + | sym::fsub_fast + | sym::fmul_fast + | sym::fdiv_fast + | sym::frem_fast + | sym::fadd_algebraic + | sym::fsub_algebraic + | sym::fmul_algebraic + | sym::fdiv_algebraic + | sym::frem_algebraic => { intrinsic_args!(fx, args => (x, y); intrinsic); let res = crate::num::codegen_float_binop( fx, match intrinsic { - sym::fadd_fast => BinOp::Add, - sym::fsub_fast => BinOp::Sub, - sym::fmul_fast => BinOp::Mul, - sym::fdiv_fast => BinOp::Div, - sym::frem_fast => BinOp::Rem, + sym::fadd_fast | sym::fadd_algebraic => BinOp::Add, + sym::fsub_fast | sym::fsub_algebraic => BinOp::Sub, + sym::fmul_fast | sym::fmul_algebraic => BinOp::Mul, + sym::fdiv_fast | sym::fdiv_algebraic => BinOp::Div, + sym::frem_fast | sym::frem_algebraic => BinOp::Rem, _ => unreachable!(), }, x, diff --git a/compiler/rustc_codegen_gcc/src/builder.rs b/compiler/rustc_codegen_gcc/src/builder.rs index 42e61b3ccb5ad..5f1e45383765f 100644 --- a/compiler/rustc_codegen_gcc/src/builder.rs +++ b/compiler/rustc_codegen_gcc/src/builder.rs @@ -705,6 +705,31 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> { self.frem(lhs, rhs) } + fn fadd_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> { + // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC. + lhs + rhs + } + + fn fsub_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> { + // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC. + lhs - rhs + } + + fn fmul_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> { + // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC. + lhs * rhs + } + + fn fdiv_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> { + // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC. + lhs / rhs + } + + fn frem_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> { + // NOTE: it seems like we cannot enable fast-mode for a single operation in GCC. + self.frem(lhs, rhs) + } + fn checked_binop(&mut self, oop: OverflowOp, typ: Ty<'_>, lhs: Self::Value, rhs: Self::Value) -> (Self::Value, Self::Value) { self.gcc_checked_binop(oop, typ, lhs, rhs) } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 7ed27b33dceaa..cfa266720d2a8 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -340,6 +340,46 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { } } + fn fadd_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + unsafe { + let instr = llvm::LLVMBuildFAdd(self.llbuilder, lhs, rhs, UNNAMED); + llvm::LLVMRustSetAlgebraicMath(instr); + instr + } + } + + fn fsub_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + unsafe { + let instr = llvm::LLVMBuildFSub(self.llbuilder, lhs, rhs, UNNAMED); + llvm::LLVMRustSetAlgebraicMath(instr); + instr + } + } + + fn fmul_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + unsafe { + let instr = llvm::LLVMBuildFMul(self.llbuilder, lhs, rhs, UNNAMED); + llvm::LLVMRustSetAlgebraicMath(instr); + instr + } + } + + fn fdiv_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + unsafe { + let instr = llvm::LLVMBuildFDiv(self.llbuilder, lhs, rhs, UNNAMED); + llvm::LLVMRustSetAlgebraicMath(instr); + instr + } + } + + fn frem_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value { + unsafe { + let instr = llvm::LLVMBuildFRem(self.llbuilder, lhs, rhs, UNNAMED); + llvm::LLVMRustSetAlgebraicMath(instr); + instr + } + } + fn checked_binop( &mut self, oop: OverflowOp, @@ -1327,17 +1367,17 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> { pub fn vector_reduce_fmul(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value { unsafe { llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src) } } - pub fn vector_reduce_fadd_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value { + pub fn vector_reduce_fadd_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value { unsafe { let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src); - llvm::LLVMRustSetFastMath(instr); + llvm::LLVMRustSetAlgebraicMath(instr); instr } } - pub fn vector_reduce_fmul_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value { + pub fn vector_reduce_fmul_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value { unsafe { let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src); - llvm::LLVMRustSetFastMath(instr); + llvm::LLVMRustSetAlgebraicMath(instr); instr } } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 4415c51acf684..3b091fca28bcb 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1880,14 +1880,14 @@ fn generic_simd_intrinsic<'ll, 'tcx>( arith_red!(simd_reduce_mul_ordered: vector_reduce_mul, vector_reduce_fmul, true, mul, 1.0); arith_red!( simd_reduce_add_unordered: vector_reduce_add, - vector_reduce_fadd_fast, + vector_reduce_fadd_algebraic, false, add, 0.0 ); arith_red!( simd_reduce_mul_unordered: vector_reduce_mul, - vector_reduce_fmul_fast, + vector_reduce_fmul_algebraic, false, mul, 1.0 diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index d0044086c616e..f9eb1da5dc7a4 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1618,6 +1618,7 @@ extern "C" { ) -> &'a Value; pub fn LLVMRustSetFastMath(Instr: &Value); + pub fn LLVMRustSetAlgebraicMath(Instr: &Value); // Miscellaneous instructions pub fn LLVMRustGetInstrProfIncrementIntrinsic(M: &Module) -> &Value; diff --git a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs index e4633acd81740..82488829b6e16 100644 --- a/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs +++ b/compiler/rustc_codegen_ssa/src/mir/intrinsic.rs @@ -250,6 +250,38 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { } } } + sym::fadd_algebraic + | sym::fsub_algebraic + | sym::fmul_algebraic + | sym::fdiv_algebraic + | sym::frem_algebraic => match float_type_width(arg_tys[0]) { + Some(_width) => match name { + sym::fadd_algebraic => { + bx.fadd_algebraic(args[0].immediate(), args[1].immediate()) + } + sym::fsub_algebraic => { + bx.fsub_algebraic(args[0].immediate(), args[1].immediate()) + } + sym::fmul_algebraic => { + bx.fmul_algebraic(args[0].immediate(), args[1].immediate()) + } + sym::fdiv_algebraic => { + bx.fdiv_algebraic(args[0].immediate(), args[1].immediate()) + } + sym::frem_algebraic => { + bx.frem_algebraic(args[0].immediate(), args[1].immediate()) + } + _ => bug!(), + }, + None => { + bx.tcx().dcx().emit_err(InvalidMonomorphization::BasicFloatType { + span, + name, + ty: arg_tys[0], + }); + return Ok(()); + } + }, sym::float_to_int_unchecked => { if float_type_width(arg_tys[0]).is_none() { diff --git a/compiler/rustc_codegen_ssa/src/traits/builder.rs b/compiler/rustc_codegen_ssa/src/traits/builder.rs index 1c5c78e6ca200..86d3d1260c307 100644 --- a/compiler/rustc_codegen_ssa/src/traits/builder.rs +++ b/compiler/rustc_codegen_ssa/src/traits/builder.rs @@ -86,22 +86,27 @@ pub trait BuilderMethods<'a, 'tcx>: fn add(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fadd(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fadd_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; + fn fadd_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn sub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fsub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fsub_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; + fn fsub_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn mul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fmul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fmul_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; + fn fmul_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn udiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn exactudiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn sdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn exactsdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn fdiv_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; + fn fdiv_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn urem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn srem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn frem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn frem_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; + fn frem_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn shl(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn lshr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; fn ashr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value; diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index f0f6bfff64aaa..f4b994df2ce8e 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -123,7 +123,12 @@ pub fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) - | sym::variant_count | sym::is_val_statically_known | sym::ptr_mask - | sym::debug_assertions => hir::Unsafety::Normal, + | sym::debug_assertions + | sym::fadd_algebraic + | sym::fsub_algebraic + | sym::fmul_algebraic + | sym::fdiv_algebraic + | sym::frem_algebraic => hir::Unsafety::Normal, _ => hir::Unsafety::Unsafe, }; @@ -405,6 +410,11 @@ pub fn check_intrinsic_type( sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => { (1, 0, vec![param(0), param(0)], param(0)) } + sym::fadd_algebraic + | sym::fsub_algebraic + | sym::fmul_algebraic + | sym::fdiv_algebraic + | sym::frem_algebraic => (1, 0, vec![param(0), param(0)], param(0)), sym::float_to_int_unchecked => (2, 0, vec![param(0)], param(1)), sym::assume => (0, 0, vec![tcx.types.bool], Ty::new_unit(tcx)), diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index b45706fd1e5b2..7326f2e8e2a20 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -418,7 +418,11 @@ extern "C" LLVMAttributeRef LLVMRustCreateMemoryEffectsAttr(LLVMContextRef C, } } -// Enable a fast-math flag +// Enable all fast-math flags, including those which will cause floating-point operations +// to return poison for some well-defined inputs. This function can only be used to build +// unsafe Rust intrinsics. That unsafety does permit additional optimizations, but at the +// time of writing, their value is not well-understood relative to those enabled by +// LLVMRustSetAlgebraicMath. // // https://llvm.org/docs/LangRef.html#fast-math-flags extern "C" void LLVMRustSetFastMath(LLVMValueRef V) { @@ -427,6 +431,25 @@ extern "C" void LLVMRustSetFastMath(LLVMValueRef V) { } } +// Enable fast-math flags which permit algebraic transformations that are not allowed by +// IEEE floating point. For example: +// a + (b + c) = (a + b) + c +// and +// a / b = a * (1 / b) +// Note that this does NOT enable any flags which can cause a floating-point operation on +// well-defined inputs to return poison, and therefore this function can be used to build +// safe Rust intrinsics (such as fadd_algebraic). +// +// https://llvm.org/docs/LangRef.html#fast-math-flags +extern "C" void LLVMRustSetAlgebraicMath(LLVMValueRef V) { + if (auto I = dyn_cast(unwrap(V))) { + I->setHasAllowReassoc(true); + I->setHasAllowContract(true); + I->setHasAllowReciprocal(true); + I->setHasNoSignedZeros(true); + } +} + extern "C" LLVMValueRef LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source, const char *Name, LLVMAtomicOrdering Order) { diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 29c88783357b7..181ab0d4d56cf 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -764,8 +764,10 @@ symbols! { f64_nan, fabsf32, fabsf64, + fadd_algebraic, fadd_fast, fake_variadic, + fdiv_algebraic, fdiv_fast, feature, fence, @@ -785,6 +787,7 @@ symbols! { fmaf32, fmaf64, fmt, + fmul_algebraic, fmul_fast, fn_align, fn_delegation, @@ -810,6 +813,7 @@ symbols! { format_unsafe_arg, freeze, freg, + frem_algebraic, frem_fast, from, from_desugaring, @@ -823,6 +827,7 @@ symbols! { from_usize, from_yeet, fs_create_dir, + fsub_algebraic, fsub_fast, fundamental, future, diff --git a/library/core/src/intrinsics.rs b/library/core/src/intrinsics.rs index fc6c1eab803d7..f19d169ae686d 100644 --- a/library/core/src/intrinsics.rs +++ b/library/core/src/intrinsics.rs @@ -1882,6 +1882,46 @@ extern "rust-intrinsic" { #[rustc_nounwind] pub fn frem_fast(a: T, b: T) -> T; + /// Float addition that allows optimizations based on algebraic rules. + /// + /// This intrinsic does not have a stable counterpart. + #[rustc_nounwind] + #[rustc_safe_intrinsic] + #[cfg(not(bootstrap))] + pub fn fadd_algebraic(a: T, b: T) -> T; + + /// Float subtraction that allows optimizations based on algebraic rules. + /// + /// This intrinsic does not have a stable counterpart. + #[rustc_nounwind] + #[rustc_safe_intrinsic] + #[cfg(not(bootstrap))] + pub fn fsub_algebraic(a: T, b: T) -> T; + + /// Float multiplication that allows optimizations based on algebraic rules. + /// + /// This intrinsic does not have a stable counterpart. + #[rustc_nounwind] + #[rustc_safe_intrinsic] + #[cfg(not(bootstrap))] + pub fn fmul_algebraic(a: T, b: T) -> T; + + /// Float division that allows optimizations based on algebraic rules. + /// + /// This intrinsic does not have a stable counterpart. + #[rustc_nounwind] + #[rustc_safe_intrinsic] + #[cfg(not(bootstrap))] + pub fn fdiv_algebraic(a: T, b: T) -> T; + + /// Float remainder that allows optimizations based on algebraic rules. + /// + /// This intrinsic does not have a stable counterpart. + #[rustc_nounwind] + #[rustc_safe_intrinsic] + #[cfg(not(bootstrap))] + pub fn frem_algebraic(a: T, b: T) -> T; + /// Convert with LLVM’s fptoui/fptosi, which may return undef for values out of range /// () /// diff --git a/tests/codegen/simd/issue-120720-reduce-nan.rs b/tests/codegen/simd/issue-120720-reduce-nan.rs new file mode 100644 index 0000000000000..1e06a58b41f21 --- /dev/null +++ b/tests/codegen/simd/issue-120720-reduce-nan.rs @@ -0,0 +1,21 @@ +//@ compile-flags: -C opt-level=3 -C target-cpu=cannonlake + +// In a previous implementation, _mm512_reduce_add_pd did the reduction with all fast-math flags +// enabled, making it UB to reduce a vector containing a NaN. + +#![crate_type = "lib"] +#![feature(stdarch_x86_avx512, avx512_target_feature)] +use std::arch::x86_64::*; + +// CHECK-label: @demo( +#[no_mangle] +#[target_feature(enable = "avx512f")] // Function-level target feature mismatches inhibit inlining +pub unsafe fn demo() -> bool { + // CHECK: %0 = tail call reassoc nsz arcp contract double @llvm.vector.reduce.fadd.v8f64( + // CHECK: %_0.i = fcmp uno double %0, 0.000000e+00 + // CHECK: ret i1 %_0.i + let res = unsafe { + _mm512_reduce_add_pd(_mm512_set_pd(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, f64::NAN)) + }; + res.is_nan() +}