Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NFC][DXIL] move replace/erase in DXIL intrinsic expansion to caller #104626

Merged
merged 3 commits into from
Aug 17, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 60 additions & 75 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static bool isIntrinsicExpansion(Function &F) {
return false;
}

static bool expandAbs(CallInst *Orig) {
static Value *expandAbs(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -66,12 +66,10 @@ static bool expandAbs(CallInst *Orig) {
auto *V = Builder.CreateSub(Zero, X);
auto *MaxCall =
Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
Orig->replaceAllUsesWith(MaxCall);
Orig->eraseFromParent();
return true;
return MaxCall;
}

static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
assert(DotIntrinsic == Intrinsic::dx_sdot ||
DotIntrinsic == Intrinsic::dx_udot);
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
Expand All @@ -97,12 +95,10 @@ static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
ArrayRef<Value *>{Elt0, Elt1, Result},
nullptr, "dx.mad");
}
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Result;
}

static bool expandExpIntrinsic(CallInst *Orig) {
static Value *expandExpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -119,23 +115,21 @@ static bool expandExpIntrinsic(CallInst *Orig) {
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
Exp2Call->setTailCall(Orig->isTailCall());
Exp2Call->setAttributes(Orig->getAttributes());
Orig->replaceAllUsesWith(Exp2Call);
Orig->eraseFromParent();
return true;
return Exp2Call;
}

static bool expandAnyIntrinsic(CallInst *Orig) {
static Value *expandAnyIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Type *Ty = X->getType();
Type *EltTy = Ty->getScalarType();

Value *Result = nullptr;
if (!Ty->isVectorTy()) {
Value *Cond = EltTy->isFloatingPointTy()
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
Orig->replaceAllUsesWith(Cond);
Result = EltTy->isFloatingPointTy()
? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
} else {
auto *XVec = dyn_cast<FixedVectorType>(Ty);
Value *Cond =
Expand All @@ -148,18 +142,16 @@ static bool expandAnyIntrinsic(CallInst *Orig) {
X, ConstantVector::getSplat(
ElementCount::getFixed(XVec->getNumElements()),
ConstantInt::get(EltTy, 0)));
Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
for (unsigned I = 1; I < XVec->getNumElements(); I++) {
Value *Elt = Builder.CreateExtractElement(Cond, I);
Result = Builder.CreateOr(Result, Elt);
}
Orig->replaceAllUsesWith(Result);
}
Orig->eraseFromParent();
return true;
return Result;
}

static bool expandLengthIntrinsic(CallInst *Orig) {
static Value *expandLengthIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -182,30 +174,23 @@ static bool expandLengthIntrinsic(CallInst *Orig) {
Value *Mul = Builder.CreateFMul(Elt, Elt);
Sum = Builder.CreateFAdd(Sum, Mul);
}
Value *Result = Builder.CreateIntrinsic(
EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum}, nullptr, "elt.sqrt");

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateIntrinsic(EltTy, Intrinsic::sqrt, ArrayRef<Value *>{Sum},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes the createIntrinsic call itself is returned, and sometimes the Value * variable that was assigned to this call is returned instead. Should we stick to returning the variable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only places the variable is returned are places where something else is done to it before replacing the old one or where it might come from different creation calls depending on the path taken. Defining a variable on one line just to return it on the next seems unnecessary

nullptr, "elt.sqrt");
}

static bool expandLerpIntrinsic(CallInst *Orig) {
static Value *expandLerpIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Value *S = Orig->getOperand(2);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
auto *V = Builder.CreateFSub(Y, X);
V = Builder.CreateFMul(S, V);
auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFAdd(X, V, "dx.lerp");
}

static bool expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
static Value *expandLogIntrinsic(CallInst *Orig,
float LogConstVal = numbers::ln2f) {
Value *X = Orig->getOperand(0);
IRBuilder<> Builder(Orig->getParent());
Builder.SetInsertPoint(Orig);
Expand All @@ -221,16 +206,13 @@ static bool expandLogIntrinsic(CallInst *Orig,
Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
Log2Call->setTailCall(Orig->isTailCall());
Log2Call->setAttributes(Orig->getAttributes());
auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFMul(Ln2Const, Log2Call);
}
static bool expandLog10Intrinsic(CallInst *Orig) {
static Value *expandLog10Intrinsic(CallInst *Orig) {
return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
}

static bool expandNormalizeIntrinsic(CallInst *Orig) {
static Value *expandNormalizeIntrinsic(CallInst *Orig) {
Value *X = Orig->getOperand(0);
Type *Ty = Orig->getType();
Type *EltTy = Ty->getScalarType();
Expand All @@ -245,11 +227,7 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
report_fatal_error(Twine("Invalid input scalar: length is zero"),
/* gen_crash_diag=*/false);
}
Value *Result = Builder.CreateFDiv(X, X);

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFDiv(X, X);
}

unsigned XVecSize = XVec->getNumElements();
Expand Down Expand Up @@ -291,14 +269,10 @@ static bool expandNormalizeIntrinsic(CallInst *Orig) {
nullptr, "dx.rsqrt");

Value *MultiplicandVec = Builder.CreateVectorSplat(XVecSize, Multiplicand);
Value *Result = Builder.CreateFMul(X, MultiplicandVec);

Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
return true;
return Builder.CreateFMul(X, MultiplicandVec);
}

static bool expandPowIntrinsic(CallInst *Orig) {
static Value *expandPowIntrinsic(CallInst *Orig) {

Value *X = Orig->getOperand(0);
Value *Y = Orig->getOperand(1);
Expand All @@ -313,9 +287,7 @@ static bool expandPowIntrinsic(CallInst *Orig) {
Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
Exp2Call->setTailCall(Orig->isTailCall());
Exp2Call->setAttributes(Orig->getAttributes());
Orig->replaceAllUsesWith(Exp2Call);
Orig->eraseFromParent();
return true;
return Exp2Call;
}

static Intrinsic::ID getMaxForClamp(Type *ElemTy,
Expand Down Expand Up @@ -344,7 +316,8 @@ static Intrinsic::ID getMinForClamp(Type *ElemTy,
return Intrinsic::minnum;
}

static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
static Value *expandClampIntrinsic(CallInst *Orig,
Intrinsic::ID ClampIntrinsic) {
Value *X = Orig->getOperand(0);
Value *Min = Orig->getOperand(1);
Value *Max = Orig->getOperand(2);
Expand All @@ -353,43 +326,55 @@ static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
Builder.SetInsertPoint(Orig);
auto *MaxCall = Builder.CreateIntrinsic(
Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
auto *MinCall =
Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
{MaxCall, Max}, nullptr, "dx.min");

Orig->replaceAllUsesWith(MinCall);
Orig->eraseFromParent();
return true;
return Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
{MaxCall, Max}, nullptr, "dx.min");
}

static bool expandIntrinsic(Function &F, CallInst *Orig) {
Value *Result = nullptr;
switch (F.getIntrinsicID()) {
case Intrinsic::abs:
return expandAbs(Orig);
Result = expandAbs(Orig);
break;
case Intrinsic::exp:
return expandExpIntrinsic(Orig);
Result = expandExpIntrinsic(Orig);
break;
case Intrinsic::log:
return expandLogIntrinsic(Orig);
Result = expandLogIntrinsic(Orig);
break;
case Intrinsic::log10:
return expandLog10Intrinsic(Orig);
Result = expandLog10Intrinsic(Orig);
break;
case Intrinsic::pow:
return expandPowIntrinsic(Orig);
Result = expandPowIntrinsic(Orig);
break;
case Intrinsic::dx_any:
return expandAnyIntrinsic(Orig);
Result = expandAnyIntrinsic(Orig);
break;
case Intrinsic::dx_uclamp:
case Intrinsic::dx_clamp:
return expandClampIntrinsic(Orig, F.getIntrinsicID());
Result = expandClampIntrinsic(Orig, F.getIntrinsicID());
break;
case Intrinsic::dx_lerp:
return expandLerpIntrinsic(Orig);
Result = expandLerpIntrinsic(Orig);
break;
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
Result = expandLengthIntrinsic(Orig);
break;
case Intrinsic::dx_normalize:
return expandNormalizeIntrinsic(Orig);
Result = expandNormalizeIntrinsic(Orig);
break;
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
return expandIntegerDot(Orig, F.getIntrinsicID());
Result = expandIntegerDot(Orig, F.getIntrinsicID());
break;
}
return false;

if (Result) {
Orig->replaceAllUsesWith(Result);
Orig->eraseFromParent();
farzonl marked this conversation as resolved.
Show resolved Hide resolved
}
return !!Result;
farzonl marked this conversation as resolved.
Show resolved Hide resolved
}

static bool expansionIntrinsics(Module &M) {
Expand Down
Loading