From ff5816ad29eba3762e1c5c576c1adf586c35dd91 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <1802579+farzonl@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:40:11 -0400 Subject: [PATCH] [DirectX] Add `all` lowering (#105787) - DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for `all` - DirectX\all.ll: Add test case completes #88946 --- .../Target/DirectX/DXILIntrinsicExpansion.cpp | 51 ++++++------ llvm/test/CodeGen/DirectX/all.ll | 83 +++++++++++++++++++ 2 files changed, 110 insertions(+), 24 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/all.ll diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index e49169cff8aa86a..2daa4f825c3b25b 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) { case Intrinsic::log: case Intrinsic::log10: case Intrinsic::pow: + case Intrinsic::dx_all: case Intrinsic::dx_any: case Intrinsic::dx_clamp: case Intrinsic::dx_uclamp: @@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) { static Value *expandAbs(CallInst *Orig) { Value *X = Orig->getOperand(0); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); Type *Ty = X->getType(); Type *EltTy = Ty->getScalarType(); Constant *Zero = Ty->isVectorTy() @@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig, static Value *expandExpIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); Type *Ty = X->getType(); Type *EltTy = Ty->getScalarType(); Constant *Log2eConst = @@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) { return Exp2Call; } -static Value *expandAnyIntrinsic(CallInst *Orig) { +static Value *expandAnyOrAllIntrinsic(CallInst *Orig, + Intrinsic::ID 
intrinsicId) { Value *X = Orig->getOperand(0); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); Type *Ty = X->getType(); Type *EltTy = Ty->getScalarType(); + auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result, + Value *Elt) { + if (IntrinsicId == Intrinsic::dx_any) + return Builder.CreateOr(Result, Elt); + assert(IntrinsicId == Intrinsic::dx_all); + return Builder.CreateAnd(Result, Elt); + }; + Value *Result = nullptr; if (!Ty->isVectorTy()) { Result = EltTy->isFloatingPointTy() @@ -193,7 +200,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) { Result = Builder.CreateExtractElement(Cond, (uint64_t)0); for (unsigned I = 1; I < XVec->getNumElements(); I++) { Value *Elt = Builder.CreateExtractElement(Cond, I); - Result = Builder.CreateOr(Result, Elt); + Result = ApplyOp(intrinsicId, Result, Elt); } } return Result; @@ -201,8 +208,7 @@ static Value *expandAnyIntrinsic(CallInst *Orig) { static Value *expandLengthIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); Type *Ty = X->getType(); Type *EltTy = Ty->getScalarType(); @@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Value *Y = Orig->getOperand(1); Value *S = Orig->getOperand(2); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); auto *V = Builder.CreateFSub(Y, X); V = Builder.CreateFMul(S, V); return Builder.CreateFAdd(X, V, "dx.lerp"); @@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) { static Value *expandLogIntrinsic(CallInst *Orig, float LogConstVal = numbers::ln2f) { Value *X = Orig->getOperand(0); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); Type *Ty = X->getType(); Type *EltTy = Ty->getScalarType(); Constant *Ln2Const = @@ -266,8 +270,7 @@ static 
Value *expandNormalizeIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Type *Ty = Orig->getType(); Type *EltTy = Ty->getScalarType(); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); auto *XVec = dyn_cast<FixedVectorType>(Ty); if (!XVec) { @@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) { Value *X = Orig->getOperand(0); Value *Y = Orig->getOperand(1); Type *Ty = X->getType(); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); auto *Log2Call = Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2"); @@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig, Value *Min = Orig->getOperand(1); Value *Max = Orig->getOperand(2); Type *Ty = X->getType(); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); auto *MaxCall = Builder.CreateIntrinsic( Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max"); return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic), @@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig, static bool expandIntrinsic(Function &F, CallInst *Orig) { Value *Result = nullptr; - switch (F.getIntrinsicID()) { + Intrinsic::ID IntrinsicId = F.getIntrinsicID(); + switch (IntrinsicId) { case Intrinsic::abs: Result = expandAbs(Orig); break; @@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) { case Intrinsic::pow: Result = expandPowIntrinsic(Orig); break; + case Intrinsic::dx_all: case Intrinsic::dx_any: - Result = expandAnyIntrinsic(Orig); + Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId); break; case Intrinsic::dx_uclamp: case Intrinsic::dx_clamp: - Result = expandClampIntrinsic(Orig, F.getIntrinsicID()); + Result = expandClampIntrinsic(Orig, IntrinsicId); break; case Intrinsic::dx_lerp: Result = expandLerpIntrinsic(Orig); @@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst 
*Orig) { break; case Intrinsic::dx_sdot: case Intrinsic::dx_udot: - Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID()); + Result = expandIntegerDotIntrinsic(Orig, IntrinsicId); break; } diff --git a/llvm/test/CodeGen/DirectX/all.ll b/llvm/test/CodeGen/DirectX/all.ll new file mode 100644 index 000000000000000..1c0b6486dc93588 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/all.ll @@ -0,0 +1,83 @@ +; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s + +; Make sure the `all` intrinsic is expanded to compare (and, for vectors, `and`) instructions for bool, integer, and floating-point operands. + +; CHECK-LABEL: all_bool +; CHECK: icmp ne i1 %{{.*}}, false +define noundef i1 @all_bool(i1 noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.i1(i1 %p0) + ret i1 %dx.all +} + +; CHECK-LABEL: all_int64_t +; CHECK: icmp ne i64 %{{.*}}, 0 +define noundef i1 @all_int64_t(i64 noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.i64(i64 %p0) + ret i1 %dx.all +} + +; CHECK-LABEL: all_int +; CHECK: icmp ne i32 %{{.*}}, 0 +define noundef i1 @all_int(i32 noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.i32(i32 %p0) + ret i1 %dx.all +} + +; CHECK-LABEL: all_int16_t +; CHECK: icmp ne i16 %{{.*}}, 0 +define noundef i1 @all_int16_t(i16 noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.i16(i16 %p0) + ret i1 %dx.all +} + +; CHECK-LABEL: all_double +; CHECK: fcmp une double %{{.*}}, 0.000000e+00 +define noundef i1 @all_double(double noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.f64(double %p0) + ret i1 %dx.all +} + +; CHECK-LABEL: all_float +; CHECK: fcmp une float %{{.*}}, 0.000000e+00 +define noundef i1 @all_float(float noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.f32(float %p0) + ret i1 %dx.all +} + +; CHECK-LABEL: all_half +; CHECK: fcmp une half %{{.*}}, 0xH0000 +define noundef i1 @all_half(half noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.f16(half %p0) + ret i1 %dx.all +} + +; CHECK-LABEL: all_bool4 +; CHECK: icmp 
ne <4 x i1> %{{.*}}, zeroinitializer +; CHECK: extractelement <4 x i1> %{{.*}}, i64 0 +; CHECK: extractelement <4 x i1> %{{.*}}, i64 1 +; CHECK: and i1 %{{.*}}, %{{.*}} +; CHECK: extractelement <4 x i1> %{{.*}}, i64 2 +; CHECK: and i1 %{{.*}}, %{{.*}} +; CHECK: extractelement <4 x i1> %{{.*}}, i64 3 +; CHECK: and i1 %{{.*}}, %{{.*}} +define noundef i1 @all_bool4(<4 x i1> noundef %p0) { +entry: + %dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %p0) + ret i1 %dx.all +} + +declare i1 @llvm.dx.all.v4i1(<4 x i1>) +declare i1 @llvm.dx.all.i1(i1) +declare i1 @llvm.dx.all.i16(i16) +declare i1 @llvm.dx.all.i32(i32) +declare i1 @llvm.dx.all.i64(i64) +declare i1 @llvm.dx.all.f16(half) +declare i1 @llvm.dx.all.f32(float) +declare i1 @llvm.dx.all.f64(double)