[ScalarizeMaskedMemIntr] Optimize splat non-constant masks #104537

Merged: 2 commits, Aug 16, 2024
llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp (64 changes: 61 additions, 3 deletions)
@@ -17,6 +17,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -161,7 +162,9 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,

// Short-cut if the mask is all-true.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
NewI->copyMetadata(*CI);
NewI->takeName(CI);
CI->replaceAllUsesWith(NewI);
CI->eraseFromParent();
return;
@@ -188,8 +191,39 @@ static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
return;
}

// Optimize the case where the "masked load" is a predicated load - that is,
// where the mask is the splat of a non-constant scalar boolean. In that case,
// use that splatted value as the guard for a conditional vector load.
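// When the predicate is false, the pass-through value (Src0) is selected
// instead, via a phi in the join block.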
if (isSplatValue(Mask, /*Index=*/0)) {
Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
Contributor: Could you add a debug assert so that I can verify that this behavior worked without inspecting asm/IR?

Contributor (Author): I don't think so - we need the rest of the function. Maybe I could collect a statistic?
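For reference, collecting such a statistic would look roughly like the sketch below. It is not part of this patch, the counter name is hypothetical, and it assumes the DEBUG_TYPE already defined in ScalarizeMaskedMemIntrin.cpp is in scope:

```cpp
#include "llvm/ADT/Statistic.h"

// Hypothetical counter name; relies on the file's existing DEBUG_TYPE.
STATISTIC(NumSplatMaskedLoads,
          "Masked loads lowered to a splat-predicated conditional load");

// Then, inside the isSplatValue(Mask, /*Index=*/0) branch of
// scalarizeMaskedLoad(), bump the counter just before returning:
//   ++NumSplatMaskedLoads;
```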

Mask->getName() + ".first");
Instruction *ThenTerm =
SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);

BasicBlock *CondBlock = ThenTerm->getParent();
CondBlock->setName("cond.load");
Builder.SetInsertPoint(CondBlock->getTerminator());
LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
CI->getName() + ".cond.load");
Load->copyMetadata(*CI);

BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
Builder.SetInsertPoint(PostLoad, PostLoad->begin());
PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
Phi->addIncoming(Load, CondBlock);
Phi->addIncoming(Src0, IfBlock);
Phi->takeName(CI);

CI->replaceAllUsesWith(Phi);
CI->eraseFromParent();
ModifiedDT = true;
return;
}
// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.
// Note: this produces worse code on AMDGPU, where the "i1" is implicitly SIMD
// - what's a good way to detect this?
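// (Sketch, not part of this patch: one candidate signal is
// TargetTransformInfo::hasBranchDivergence(), though TTI would need to be
// threaded into this helper to query it here.)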
Value *SclrMask;
if (VectorWidth != 1) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -297,7 +331,9 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,

// Short-cut if the mask is all-true.
if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Builder.CreateAlignedStore(Src, Ptr, AlignVal);
StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
Store->takeName(CI);
Store->copyMetadata(*CI);
CI->eraseFromParent();
return;
}
@@ -319,8 +355,31 @@ static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
return;
}

// Optimize the case where the "masked store" is a predicated store - that is,
// when the mask is the splat of a non-constant scalar boolean. In that case,
// optimize to a conditional store.
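// Unlike the load case, no pass-through value is needed: when the predicate
// is false, the store is simply skipped.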
if (isSplatValue(Mask, /*Index=*/0)) {
Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
Mask->getName() + ".first");
Instruction *ThenTerm =
SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
BasicBlock *CondBlock = ThenTerm->getParent();
CondBlock->setName("cond.store");
Builder.SetInsertPoint(CondBlock->getTerminator());

StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
Store->takeName(CI);
Store->copyMetadata(*CI);

CI->eraseFromParent();
ModifiedDT = true;
return;
}

// If the mask is not v1i1, use scalar bit test operations. This generates
// better results on X86 at least.

Value *SclrMask;
if (VectorWidth != 1) {
Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
@@ -997,7 +1056,6 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
any_of(II->args(),
[](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
return false;

switch (II->getIntrinsicID()) {
default:
break;
@@ -32,8 +32,8 @@ define <2 x i64> @scalarize_v2i64(ptr %p, <2 x i1> %mask, <2 x i64> %passthru) {

define <2 x i64> @scalarize_v2i64_ones_mask(ptr %p, <2 x i64> %passthru) {
; CHECK-LABEL: @scalarize_v2i64_ones_mask(
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
; CHECK-NEXT: ret <2 x i64> [[TMP1]]
; CHECK-NEXT: [[RET:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
; CHECK-NEXT: ret <2 x i64> [[RET]]
;
%ret = call <2 x i64> @llvm.masked.load.v2i64.p0(ptr %p, i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> %passthru)
ret <2 x i64> %ret
@@ -58,34 +58,18 @@ define <2 x i64> @scalarize_v2i64_const_mask(ptr %p, <2 x i64> %passthru) {
ret <2 x i64> %ret
}

; To be fixed: If the mask is the splat/broadcast of a non-constant value, use a
; vector load
define <2 x i64> @scalarize_v2i64_splat_mask(ptr %p, i1 %mask, <2 x i64> %passthrough) {
; CHECK-LABEL: @scalarize_v2i64_splat_mask(
; CHECK-NEXT: [[MASK_VEC:%.*]] = insertelement <2 x i1> poison, i1 [[MASK:%.*]], i32 0
; CHECK-NEXT: [[MASK_SPLAT:%.*]] = shufflevector <2 x i1> [[MASK_VEC]], <2 x i1> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK_SPLAT]] to i2
; CHECK-NEXT: [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
; CHECK-NEXT: br i1 [[TMP2]], label [[COND_LOAD:%.*]], label [[ELSE:%.*]]
; CHECK-NEXT: [[MASK_SPLAT_FIRST:%.*]] = extractelement <2 x i1> [[MASK_SPLAT]], i64 0
; CHECK-NEXT: br i1 [[MASK_SPLAT_FIRST]], label [[COND_LOAD:%.*]], label [[TMP1:%.*]]
; CHECK: cond.load:
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i32 0
; CHECK-NEXT: [[TMP4:%.*]] = load i64, ptr [[TMP3]], align 8
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <2 x i64> [[PASSTHROUGH:%.*]], i64 [[TMP4]], i64 0
; CHECK-NEXT: br label [[ELSE]]
; CHECK: else:
; CHECK-NEXT: [[RES_PHI_ELSE:%.*]] = phi <2 x i64> [ [[TMP5]], [[COND_LOAD]] ], [ [[PASSTHROUGH]], [[TMP0:%.*]] ]
; CHECK-NEXT: [[TMP6:%.*]] = and i2 [[SCALAR_MASK]], -2
; CHECK-NEXT: [[TMP7:%.*]] = icmp ne i2 [[TMP6]], 0
; CHECK-NEXT: br i1 [[TMP7]], label [[COND_LOAD1:%.*]], label [[ELSE2:%.*]]
; CHECK: cond.load1:
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[P]], i32 1
; CHECK-NEXT: [[TMP9:%.*]] = load i64, ptr [[TMP8]], align 8
; CHECK-NEXT: [[TMP10:%.*]] = insertelement <2 x i64> [[RES_PHI_ELSE]], i64 [[TMP9]], i64 1
; CHECK-NEXT: br label [[ELSE2]]
; CHECK: else2:
; CHECK-NEXT: [[RES_PHI_ELSE3:%.*]] = phi <2 x i64> [ [[TMP10]], [[COND_LOAD1]] ], [ [[RES_PHI_ELSE]], [[ELSE]] ]
; CHECK-NEXT: ret <2 x i64> [[RES_PHI_ELSE3]]
; CHECK-NEXT: [[RET_COND_LOAD:%.*]] = load <2 x i64>, ptr [[P:%.*]], align 8
; CHECK-NEXT: br label [[TMP1]]
; CHECK: 1:
; CHECK-NEXT: [[RET:%.*]] = phi <2 x i64> [ [[RET_COND_LOAD]], [[COND_LOAD]] ], [ [[PASSTHROUGH:%.*]], [[TMP0:%.*]] ]
; CHECK-NEXT: ret <2 x i64> [[RET]]
;
%mask.vec = insertelement <2 x i1> poison, i1 %mask, i32 0
%mask.splat = shufflevector <2 x i1> %mask.vec, <2 x i1> poison, <2 x i32> zeroinitializer
@@ -56,31 +56,16 @@ define void @scalarize_v2i64_const_mask(ptr %p, <2 x i64> %data) {
ret void
}

; To be fixed: If the mask is the splat/broadcast of a non-constant value, use a
; vector store
define void @scalarize_v2i64_splat_mask(ptr %p, <2 x i64> %data, i1 %mask) {
; CHECK-LABEL: @scalarize_v2i64_splat_mask(
; CHECK-NEXT: [[MASK_VEC:%.*]] = insertelement <2 x i1> poison, i1 [[MASK:%.*]], i32 0
; CHECK-NEXT: [[MASK_SPLAT:%.*]] = shufflevector <2 x i1> [[MASK_VEC]], <2 x i1> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[SCALAR_MASK:%.*]] = bitcast <2 x i1> [[MASK_SPLAT]] to i2
; CHECK-NEXT: [[TMP1:%.*]] = and i2 [[SCALAR_MASK]], 1
; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i2 [[TMP1]], 0
; CHECK-NEXT: br i1 [[TMP2]], label [[COND_STORE:%.*]], label [[ELSE:%.*]]
; CHECK-NEXT: [[MASK_SPLAT_FIRST:%.*]] = extractelement <2 x i1> [[MASK_SPLAT]], i64 0
; CHECK-NEXT: br i1 [[MASK_SPLAT_FIRST]], label [[COND_STORE:%.*]], label [[TMP1:%.*]]
; CHECK: cond.store:
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x i64> [[DATA:%.*]], i64 0
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i32 0
; CHECK-NEXT: store i64 [[TMP3]], ptr [[TMP4]], align 8
; CHECK-NEXT: br label [[ELSE]]
; CHECK: else:
; CHECK-NEXT: [[TMP5:%.*]] = and i2 [[SCALAR_MASK]], -2
; CHECK-NEXT: [[TMP6:%.*]] = icmp ne i2 [[TMP5]], 0
; CHECK-NEXT: br i1 [[TMP6]], label [[COND_STORE1:%.*]], label [[ELSE2:%.*]]
; CHECK: cond.store1:
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i64> [[DATA]], i64 1
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i64, ptr [[P]], i32 1
; CHECK-NEXT: store i64 [[TMP7]], ptr [[TMP8]], align 8
; CHECK-NEXT: br label [[ELSE2]]
; CHECK: else2:
; CHECK-NEXT: store <2 x i64> [[DATA:%.*]], ptr [[P:%.*]], align 8
; CHECK-NEXT: br label [[TMP1]]
; CHECK: 1:
; CHECK-NEXT: ret void
;
%mask.vec = insertelement <2 x i1> poison, i1 %mask, i32 0