Skip to content

Commit

Permalink
[DirectX] Add all lowering (#105787)
Browse files Browse the repository at this point in the history
- DXILIntrinsicExpansion.cpp: Modify `any` codegen expansion to work for
`all`
- DirectX\all.ll: Add test case

completes #88946
  • Loading branch information
farzonl authored Aug 26, 2024
1 parent 4bab038 commit ff5816a
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 24 deletions.
51 changes: 27 additions & 24 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::log:
case Intrinsic::log10:
case Intrinsic::pow:
case Intrinsic::dx_all:
case Intrinsic::dx_any:
case Intrinsic::dx_clamp:
case Intrinsic::dx_uclamp:
Expand All @@ -54,8 +55,7 @@ static bool isIntrinsicExpansion(Function &F) {

static Value *expandAbs(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Zero = Ty->isVectorTy()
Expand Down Expand Up @@ -148,8 +148,7 @@ static Value *expandIntegerDotIntrinsic(CallInst *Orig,

static Value *expandExpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Log2eConst =
Expand All @@ -166,13 +165,21 @@ static Value *expandExpIntrinsic(CallInst *Orig) {
return Exp2Call;
}

static Value *expandAnyIntrinsic(CallInst *Orig) {
static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
Intrinsic::ID intrinsicId) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();

auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
Value *Elt) {
if (IntrinsicId == Intrinsic::dx_any)
return Builder.CreateOr(Result, Elt);
assert(IntrinsicId == Intrinsic::dx_all);
return Builder.CreateAnd(Result, Elt);
};

Value *Result = nullptr;
if (!Ty->isVectorTy()) {
Result = EltTy->isFloatingPointTy()
Expand All @@ -193,16 +200,15 @@ static Value *expandAnyIntrinsic(CallInst *Orig) {
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
Value *Elt = Builder.CreateExtractElement(Cond, I);
Result = Builder.CreateOr(Result, Elt);
Result = ApplyOp(intrinsicId, Result, Elt);
}
}
return Result;
}

static Value *expandLengthIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();

Expand Down Expand Up @@ -230,8 +236,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Value *S = Orig->getOperand(2);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);
auto *V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
return Builder.CreateFAdd(X, V, "dx.lerp");
Expand All @@ -240,8 +245,7 @@ static Value *expandLerpIntrinsic(CallInst *Orig) {
static Value *expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();
Constant *Ln2Const =
Expand All @@ -266,8 +270,7 @@ static Value *expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
Type *EltTy = Ty->getScalarType();
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);

auto *XVec = dyn_cast<FixedVectorType>(Ty);
if (!XVec) {
Expand Down Expand Up @@ -305,8 +308,7 @@ static Value *expandPowIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Type *Ty = X->getType();
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);

auto *Log2Call =
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
Expand Down Expand Up @@ -350,8 +352,7 @@ static Value *expandClampIntrinsic(CallInst *Orig,
Value *Min = Orig->getOperand(1);
Value *Max = Orig->getOperand(2);
Type *Ty = X->getType();
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
IRBuilder<> Builder(Orig);
auto *MaxCall = Builder.CreateIntrinsic(
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
Expand All @@ -360,7 +361,8 @@ static Value *expandClampIntrinsic(CallInst *Orig,

static bool expandIntrinsic(Function &F, CallInst *Orig) {
Value *Result = nullptr;
switch (F.getIntrinsicID()) {
Intrinsic::ID IntrinsicId = F.getIntrinsicID();
switch (IntrinsicId) {
case Intrinsic::abs:
Result = expandAbs(Orig);
break;
Expand All @@ -376,12 +378,13 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::pow:
Result = expandPowIntrinsic(Orig);
break;
case Intrinsic::dx_all:
case Intrinsic::dx_any:
Result = expandAnyIntrinsic(Orig);
Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_uclamp:
case Intrinsic::dx_clamp:
Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
Result = expandClampIntrinsic(Orig, IntrinsicId);
break;
case Intrinsic::dx_lerp:
Result = expandLerpIntrinsic(Orig);
Expand All @@ -397,7 +400,7 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
Result = expandIntegerDotIntrinsic(Orig, F.getIntrinsicID());
Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
break;
}

Expand Down
83 changes: 83 additions & 0 deletions llvm/test/CodeGen/DirectX/all.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
; RUN: opt -S -passes=dxil-intrinsic-expansion,dxil-op-lower -mtriple=dxil-pc-shadermodel6.0-library < %s | FileCheck %s

; Make sure dxil operation function calls for all are generated for float and half.

; CHECK-LABEL: all_bool
; CHECK: icmp ne i1 %{{.*}}, false
define noundef i1 @all_bool(i1 noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.i1(i1 %p0)
ret i1 %dx.all
}

; CHECK-LABEL: all_int64_t
; CHECK: icmp ne i64 %{{.*}}, 0
define noundef i1 @all_int64_t(i64 noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.i64(i64 %p0)
ret i1 %dx.all
}

; CHECK-LABEL: all_int
; CHECK: icmp ne i32 %{{.*}}, 0
define noundef i1 @all_int(i32 noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.i32(i32 %p0)
ret i1 %dx.all
}

; CHECK-LABEL: all_int16_t
; CHECK: icmp ne i16 %{{.*}}, 0
define noundef i1 @all_int16_t(i16 noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.i16(i16 %p0)
ret i1 %dx.all
}

; CHECK-LABEL: all_double
; CHECK: fcmp une double %{{.*}}, 0.000000e+00
define noundef i1 @all_double(double noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.f64(double %p0)
ret i1 %dx.all
}

; CHECK-LABEL: all_float
; CHECK: fcmp une float %{{.*}}, 0.000000e+00
define noundef i1 @all_float(float noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.f32(float %p0)
ret i1 %dx.all
}

; CHECK-LABEL: all_half
; CHECK: fcmp une half %{{.*}}, 0xH0000
define noundef i1 @all_half(half noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.f16(half %p0)
ret i1 %dx.all
}

; CHECK-LABEL: all_bool4
; CHECK: icmp ne <4 x i1> %{{.*}}, zeroinitialize
; CHECK: extractelement <4 x i1> %{{.*}}, i64 0
; CHECK: extractelement <4 x i1> %{{.*}}, i64 1
; CHECK: and i1 %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x i1> %{{.*}}, i64 2
; CHECK: and i1 %{{.*}}, %{{.*}}
; CHECK: extractelement <4 x i1> %{{.*}}, i64 3
; CHECK: and i1 %{{.*}}, %{{.*}}
define noundef i1 @all_bool4(<4 x i1> noundef %p0) {
entry:
%dx.all = call i1 @llvm.dx.all.v4i1(<4 x i1> %p0)
ret i1 %dx.all
}

declare i1 @llvm.dx.all.v4i1(<4 x i1>)
declare i1 @llvm.dx.all.i1(i1)
declare i1 @llvm.dx.all.i16(i16)
declare i1 @llvm.dx.all.i32(i32)
declare i1 @llvm.dx.all.i64(i64)
declare i1 @llvm.dx.all.f16(half)
declare i1 @llvm.dx.all.f32(float)
declare i1 @llvm.dx.all.f64(double)

0 comments on commit ff5816a

Please sign in to comment.