[ScalarizeMaskedMemIntr] Optimize splat non-constant masks (llvm#104537)
In cases (like the ones added in the tests) where the condition of a
masked load or store is a splat but not a constant (that is, a masked
operation is being used to implement patterns like "load if the current
lane is in-bounds, otherwise return 0"), optimize the 'scalarized' code
to perform an aligned vector load/store if the splatted value is true.

Additionally, while I'm here, take a few steps to preserve aliasing
information and value names when nothing is scalarized.

As motivation, some LLVM IR users will generate masked loads/stores in
cases that map to this kind of predicated operation (where either the
whole vector is loaded/stored or it isn't) in order to take advantage of
hardware primitives, but on AMDGPU, where we don't have a masked load or
store, this pass would scalarize a load or store that was intended to be
- and can be - vectorized, while also introducing expensive branches.

Fixes llvm#104520

Pre-commit tests at llvm#104527

Change-Id: I389e3398af7377108533a8ef0dd7a45e4b20b5ee
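
For illustration (not part of the commit itself), here is a minimal before/after sketch of the load case, assuming a hypothetical <4 x i32> masked load whose mask is the splat of a non-constant i1 %cond; the value and block names follow the naming scheme used in the patch but are otherwise illustrative.

Before (the mask is a splat of %cond, so no lane-by-lane scalarization is needed):

  ; build the splat mask from the scalar condition %cond
  %mask.head = insertelement <4 x i1> poison, i1 %cond, i64 0
  %mask = shufflevector <4 x i1> %mask.head, <4 x i1> poison, <4 x i32> zeroinitializer
  %res = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4, <4 x i1> %mask, <4 x i32> %passthru)

After (a single guarded vector load instead of a chain of per-lane extracts, branches, and inserts):

  %mask.first = extractelement <4 x i1> %mask, i64 0
  br i1 %mask.first, label %cond.load, label %post

cond.load:
  %res.cond.load = load <4 x i32>, ptr %p, align 4
  br label %post

post:
  %res = phi <4 x i32> [ %res.cond.load, %cond.load ], [ %passthru, %entry ]

The store case is handled analogously: the whole aligned vector store is guarded by the extracted lane-0 bit, and no phi is needed.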
krzysz00 authored and David Salinas committed Oct 1, 2024
1 parent 3b93451 commit 76a1d46
Showing 5 changed files with 115 additions and 874 deletions.
64 changes: 61 additions & 3 deletions llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
@@ -17,6 +17,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -161,7 +162,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,

// Short-cut if the mask is all-true.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
NewI->copyMetadata(*CI);
NewI->takeName(CI);
CI->replaceAllUsesWith(NewI);
CI->eraseFromParent();
return;
@@ -188,8 +191,39 @@ return;
return;
}

// Optimize the case where the "masked load" is a predicated load - that is,
// where the mask is the splat of a non-constant scalar boolean. In that case,
// use that splatted value as the guard on a conditional vector load.
if (isSplatValue(Mask, /*Index=*/0)) {
Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
Mask->getName() + ".first");
Instruction *ThenTerm =
SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);

BasicBlock *CondBlock = ThenTerm->getParent();
CondBlock->setName("cond.load");
Builder.SetInsertPoint(CondBlock->getTerminator());
LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
CI->getName() + ".cond.load");
Load->copyMetadata(*CI);

BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
Builder.SetInsertPoint(PostLoad, PostLoad->begin());
PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
Phi->addIncoming(Load, CondBlock);
Phi->addIncoming(Src0, IfBlock);
Phi->takeName(CI);

CI->replaceAllUsesWith(Phi);
CI->eraseFromParent();
ModifiedDT = true;
return;
}
// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
// Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
// - what's a good way to detect this?
Value *SclrMask;
if (VectorWidth != 1) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -297,7 +331,9 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,

// Short-cut if the mask is all-true.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Builder.CreateAlignedStore(Src, Ptr, AlignVal);
StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
Store->takeName(CI);
Store->copyMetadata(*CI);
CI->eraseFromParent();
return;
}
@@ -319,8 +355,31 @@
return;
}

// Optimize the case where the "masked store" is a predicated store - that is,
// when the mask is the splat of a non-constant scalar boolean. In that case,
// optimize to a conditional store.
if (isSplatValue(Mask, /*Index=*/0)) {
Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
Mask->getName() + ".first");
Instruction *ThenTerm =
SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
BasicBlock *CondBlock = ThenTerm->getParent();
CondBlock->setName("cond.store");
Builder.SetInsertPoint(CondBlock->getTerminator());

StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
Store->takeName(CI);
Store->copyMetadata(*CI);

CI->eraseFromParent();
ModifiedDT = true;
return;
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.

Value *SclrMask;
if (VectorWidth != 1) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -997,7 +1056,6 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
any_of(II->args(),
[](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
return false;

switch (II->getIntrinsicID()) {
default:
break;