[ScalarizeMaskedMemIntr] Don't use a scalar mask on GPUs (llvm#104842)
ScalarizeMaskedMemIntrin contains an optimization where the <N x i1> mask
is bitcast into an iN and bit-tests with powers of two are then used to
determine whether to load/store/... or not.

However, on machines with branch divergence (mainly GPUs), this is a
mis-optimization: each i1 in the mask is kept in its own condition
register, so each of these "i1"s is likely to be a word or two wide,
making these bit operations counterproductive.

Therefore, amend this pass to skip the optimization on targets where it
pessimizes the generated code.

Pre-commit tests llvm#104645
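
To illustrate the reasoning, here is a minimal sketch (not the committed code) of the two ways the per-lane predicate can be formed. The helper name buildLanePredicate and its parameter list are hypothetical; the IRBuilder calls mirror the ones visible in the diff below.

// Hedged sketch: choose the per-lane predicate based on branch divergence.
// `buildLanePredicate` is a hypothetical helper; the real pass inlines this
// logic into each scalarizeMasked* function.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *buildLanePredicate(IRBuilder<> &Builder, Value *Mask,
                                 Value *SclrMask, unsigned VectorWidth,
                                 unsigned BitIdx, unsigned Idx) {
  if (SclrMask) {
    // CPU-style lowering: the <N x i1> mask was bitcast to an iN, so test the
    // (endian-adjusted) bit BitIdx with an AND followed by a compare.
    Value *Bit = Builder.getInt(APInt::getOneBitSet(VectorWidth, BitIdx));
    return Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Bit),
                                Builder.getIntN(VectorWidth, 0));
  }
  // GPU-style lowering: read the i1 lane directly. On divergent targets each
  // lane's i1 already lives in a wide condition register, so packing the mask
  // into an iN and bit-testing it would only add work.
  return Builder.CreateExtractElement(Mask, Idx, "lane.cond");
}

When SclrMask is null (the divergent-target case this patch introduces), no scalar_mask bitcast is emitted at all.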
krzysz00 authored and cjdb committed Aug 23, 2024
1 parent aae9b4b commit f2ceb78
Showing 5 changed files with 115 additions and 109 deletions.
136 changes: 83 additions & 53 deletions llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -69,10 +69,11 @@ class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
const TargetTransformInfo &TTI, const DataLayout &DL,
DomTreeUpdater *DTU);
bool HasBranchDivergence, DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
const TargetTransformInfo &TTI,
const DataLayout &DL, DomTreeUpdater *DTU);
const DataLayout &DL, bool HasBranchDivergence,
DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

@@ -141,8 +142,9 @@ static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
// %10 = extractelement <16 x i1> %mask, i32 2
// br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Ptr = CI->getArgOperand(0);
Value *Alignment = CI->getArgOperand(1);
Value *Mask = CI->getArgOperand(2);
@@ -221,25 +223,26 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
return;
}
// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
// Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
// - what's a good way to detect this?
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs and other
// machines with divergence, as there each i1 needs a vector register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}

for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
// Fill the "else" block, created in the previous iteration
//
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
// %cond = icmp ne i16 %mask_1, 0
// br i1 %mask_1, label %cond.load, label %else
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = and i16 %scalar_mask, i32 1 << Idx
// %cond = icmp ne i16 %mask_1, 0
// br i1 %mask_1, label %cond.load, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -312,8 +315,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
// store i32 %6, i32* %7
// br label %else2
// . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
CallInst *CI, DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Src = CI->getArgOperand(0);
Value *Ptr = CI->getArgOperand(1);
Value *Alignment = CI->getArgOperand(2);
@@ -378,10 +382,10 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.

Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -393,8 +397,11 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
// %cond = icmp ne i16 %mask_1, 0
// br i1 %mask_1, label %cond.store, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -461,7 +468,8 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedGather(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
Value *Ptrs = CI->getArgOperand(0);
Value *Alignment = CI->getArgOperand(1);
@@ -500,9 +508,10 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there, each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -514,9 +523,12 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
// %cond = icmp ne i16 %mask_1, 0
// br i1 %Mask1, label %cond.load, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead

Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -591,7 +603,8 @@ static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedScatter(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
Value *Src = CI->getArgOperand(0);
Value *Ptrs = CI->getArgOperand(1);
@@ -629,8 +642,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -642,8 +655,11 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
// %cond = icmp ne i16 %mask_1, 0
// br i1 %Mask1, label %cond.store, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -681,7 +697,8 @@ static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
ModifiedDT = true;
}

static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedExpandLoad(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU, bool &ModifiedDT) {
Value *Ptr = CI->getArgOperand(0);
Value *Mask = CI->getArgOperand(1);
@@ -738,23 +755,27 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there, each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}

for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
// Fill the "else" block, created in the previous iteration
//
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
// br i1 %mask_1, label %cond.load, label %else
// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
// br i1 %mask_1, label %cond.load, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead

Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -813,7 +834,8 @@ static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
ModifiedDT = true;
}

static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
static void scalarizeMaskedCompressStore(const DataLayout &DL,
bool HasBranchDivergence, CallInst *CI,
DomTreeUpdater *DTU,
bool &ModifiedDT) {
Value *Src = CI->getArgOperand(0);
@@ -855,9 +877,10 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
Value *SclrMask;
if (VectorWidth != 1) {
// better results on X86 at least. However, don't do this on GPUs or other
// machines with branch divergence, as there, each i1 takes up a register.
Value *SclrMask = nullptr;
if (VectorWidth != 1 && !HasBranchDivergence) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
}
@@ -868,8 +891,11 @@ static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
// %mask_1 = extractelement <16 x i1> %mask, i32 Idx
// br i1 %mask_1, label %cond.store, label %else
//
// On GPUs, use
// %cond = extractelement %mask, Idx
// instead
Value *Predicate;
if (VectorWidth != 1) {
if (SclrMask != nullptr) {
Value *Mask = Builder.getInt(APInt::getOneBitSet(
VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
@@ -993,12 +1019,13 @@ static bool runImpl(Function &F, const TargetTransformInfo &TTI,
bool EverMadeChange = false;
bool MadeChange = true;
auto &DL = F.getDataLayout();
bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
while (MadeChange) {
MadeChange = false;
for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
bool ModifiedDTOnIteration = false;
MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
DTU ? &*DTU : nullptr);
HasBranchDivergence, DTU ? &*DTU : nullptr);

// Restart BB iteration if the dominator tree of the Function was changed
if (ModifiedDTOnIteration)
@@ -1032,13 +1059,14 @@ ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
const TargetTransformInfo &TTI, const DataLayout &DL,
DomTreeUpdater *DTU) {
bool HasBranchDivergence, DomTreeUpdater *DTU) {
bool MadeChange = false;

BasicBlock::iterator CurInstIterator = BB.begin();
while (CurInstIterator != BB.end()) {
if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
MadeChange |=
optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
if (ModifiedDT)
return true;
}
@@ -1048,7 +1076,8 @@ static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
const TargetTransformInfo &TTI,
const DataLayout &DL, DomTreeUpdater *DTU) {
const DataLayout &DL, bool HasBranchDivergence,
DomTreeUpdater *DTU) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
if (II) {
// The scalarization code below does not work for scalable vectors.
@@ -1071,14 +1100,14 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
CI->getType(),
cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
return false;
scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_store:
if (TTI.isLegalMaskedStore(
CI->getArgOperand(0)->getType(),
cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
return false;
scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_gather: {
MaybeAlign MA =
@@ -1089,7 +1118,7 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
!TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
return false;
scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
}
case Intrinsic::masked_scatter: {
@@ -1102,22 +1131,23 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
!TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
Alignment))
return false;
scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
}
case Intrinsic::masked_expandload:
if (TTI.isLegalMaskedExpandLoad(
CI->getType(),
CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
return false;
scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
return true;
case Intrinsic::masked_compressstore:
if (TTI.isLegalMaskedCompressStore(
CI->getArgOperand(0)->getType(),
CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
return false;
scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
ModifiedDT);
return true;
}
}
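The runImpl change above computes the divergence property once per function and threads it through optimizeBlock, optimizeCallInst, and the scalarizeMasked* helpers. A minimal sketch of that call-site pattern, assuming only the TargetTransformInfo query used in the diff (the wrapper name shouldUseScalarBitTestMask is hypothetical):

// Hedged sketch of the guard the patch adds around building the scalar mask.
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"

using namespace llvm;

static bool shouldUseScalarBitTestMask(const TargetTransformInfo &TTI,
                                       const Function &F,
                                       unsigned VectorWidth) {
  // Query once per function; GPU-like targets report true here.
  bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
  // The iN bit-test trick only pays off when lanes share scalar registers,
  // i.e. on non-divergent targets with a multi-element mask.
  return VectorWidth != 1 && !HasBranchDivergence;
}

When this returns false, the per-lane predicate comes straight from an extractelement on the mask, as the updated comments in the diff note.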
