[RISCV] Move vmv.v.v peephole from SelectionDAG to RISCVVectorPeephole #100367

Merged Aug 16, 2024 (6 commits)
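For context, the peephole being moved folds a PseudoVMV_V_V into the instruction that produces its source operand; as the removed SelectionDAG comment below notes, a vmv.v.v is equivalent to a vmerge with an all-ones mask. A minimal sketch of the transform, adapted from the doc comment this patch adds to RISCVVectorPeephole.cpp (operand names are illustrative):

%x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
%y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
->
%y = PseudoVADD_V_V_M1 %passthru, %a, %b, min(vl1, vl2), sew, policy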
84 changes: 14 additions & 70 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3664,32 +3664,6 @@ static bool IsVMerge(SDNode *N) {
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMERGE_VVM;
}

static bool IsVMv(SDNode *N) {
return RISCV::getRVVMCOpcode(N->getMachineOpcode()) == RISCV::VMV_V_V;
}

static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
switch (LMUL) {
case RISCVII::LMUL_F8:
return RISCV::PseudoVMSET_M_B1;
case RISCVII::LMUL_F4:
return RISCV::PseudoVMSET_M_B2;
case RISCVII::LMUL_F2:
return RISCV::PseudoVMSET_M_B4;
case RISCVII::LMUL_1:
return RISCV::PseudoVMSET_M_B8;
case RISCVII::LMUL_2:
return RISCV::PseudoVMSET_M_B16;
case RISCVII::LMUL_4:
return RISCV::PseudoVMSET_M_B32;
case RISCVII::LMUL_8:
return RISCV::PseudoVMSET_M_B64;
case RISCVII::LMUL_RESERVED:
llvm_unreachable("Unexpected LMUL");
}
llvm_unreachable("Unknown VLMUL enum");
}

// Try to fold away VMERGE_VVM instructions into their true operands:
//
// %true = PseudoVADD_VV ...
@@ -3704,35 +3678,22 @@ static unsigned GetVMSetForLMul(RISCVII::VLMUL LMUL) {
// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
// mask is all ones.
//
// We can also fold a VMV_V_V into its true operand, since it is equivalent to a
// VMERGE_VVM with an all ones mask.
//
// The resulting VL is the minimum of the two VLs.
//
// The resulting policy is the effective policy the vmerge would have had,
// i.e. whether or not its passthru operand was implicit-def.
bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
SDValue Passthru, False, True, VL, Mask, Glue;
// A vmv.v.v is equivalent to a vmerge with an all-ones mask.
if (IsVMv(N)) {
Passthru = N->getOperand(0);
False = N->getOperand(0);
True = N->getOperand(1);
VL = N->getOperand(2);
// A vmv.v.v won't have a Mask or Glue, instead we'll construct an all-ones
// mask later below.
} else {
assert(IsVMerge(N));
Passthru = N->getOperand(0);
False = N->getOperand(1);
True = N->getOperand(2);
Mask = N->getOperand(3);
VL = N->getOperand(4);
// We always have a glue node for the mask at v0.
Glue = N->getOperand(N->getNumOperands() - 1);
}
assert(!Mask || cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
assert(!Glue || Glue.getValueType() == MVT::Glue);
assert(IsVMerge(N));
Passthru = N->getOperand(0);
False = N->getOperand(1);
True = N->getOperand(2);
Mask = N->getOperand(3);
VL = N->getOperand(4);
// We always have a glue node for the mask at v0.
Glue = N->getOperand(N->getNumOperands() - 1);
assert(cast<RegisterSDNode>(Mask)->getReg() == RISCV::V0);
assert(Glue.getValueType() == MVT::Glue);

// If the EEW of True is different from vmerge's SEW, then we can't fold.
if (True.getSimpleValueType() != N->getSimpleValueType(0))
@@ -3780,7 +3741,7 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {

// If True is masked then the vmerge must have either the same mask or an all
// 1s mask, since we're going to keep the mask from True.
if (IsMasked && Mask) {
if (IsMasked) {
// FIXME: Support mask agnostic True instruction which would have an
// undef passthru operand.
SDValue TrueMask =
@@ -3810,11 +3771,9 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
SmallVector<const SDNode *, 4> LoopWorklist;
SmallPtrSet<const SDNode *, 16> Visited;
LoopWorklist.push_back(False.getNode());
if (Mask)
LoopWorklist.push_back(Mask.getNode());
LoopWorklist.push_back(Mask.getNode());
LoopWorklist.push_back(VL.getNode());
if (Glue)
LoopWorklist.push_back(Glue.getNode());
LoopWorklist.push_back(Glue.getNode());
if (SDNode::hasPredecessorHelper(True.getNode(), Visited, LoopWorklist))
return false;
}
@@ -3875,21 +3834,6 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
Glue = True->getOperand(True->getNumOperands() - 1);
assert(Glue.getValueType() == MVT::Glue);
}
// If we end up using the vmerge mask the vmerge is actually a vmv.v.v, create
// an all-ones mask to use.
else if (IsVMv(N)) {
unsigned TSFlags = TII->get(N->getMachineOpcode()).TSFlags;
unsigned VMSetOpc = GetVMSetForLMul(RISCVII::getLMul(TSFlags));
ElementCount EC = N->getValueType(0).getVectorElementCount();
MVT MaskVT = MVT::getVectorVT(MVT::i1, EC);

SDValue AllOnesMask =
SDValue(CurDAG->getMachineNode(VMSetOpc, DL, MaskVT, VL, SEW), 0);
SDValue MaskCopy = CurDAG->getCopyToReg(CurDAG->getEntryNode(), DL,
RISCV::V0, AllOnesMask, SDValue());
Mask = CurDAG->getRegister(RISCV::V0, MaskVT);
Glue = MaskCopy.getValue(1);
}

unsigned MaskedOpc = Info->MaskedPseudo;
#ifndef NDEBUG
@@ -3968,7 +3912,7 @@ bool RISCVDAGToDAGISel::doPeepholeMergeVVMFold() {
if (N->use_empty() || !N->isMachineOpcode())
continue;

if (IsVMerge(N) || IsVMv(N))
if (IsVMerge(N))
MadeChange |= performCombineVMergeAndVOps(N);
}
return MadeChange;
141 changes: 140 additions & 1 deletion llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp
@@ -65,6 +65,7 @@ class RISCVVectorPeephole : public MachineFunctionPass {
bool convertToWholeRegister(MachineInstr &MI) const;
bool convertToUnmasked(MachineInstr &MI) const;
bool convertVMergeToVMv(MachineInstr &MI) const;
bool foldVMV_V_V(MachineInstr &MI);
Contributor:
Just a concern about the naming here: it is not consistent.

Contributor (author):
Yeah, I noticed this as well. I think the camel case in convertVMergeToVMv came from trying to please clang-tidy, but VMV_V_V is more accurate to the actual instruction name. Do we have a preference?

Contributor:
Hard to decide... I don't have a preference; maybe just keep it, we already have a lot of exceptions in RISCVISelLowering now.


bool isAllOnesMask(const MachineInstr *MaskDef) const;
std::optional<unsigned> getConstant(const MachineOperand &VL) const;
@@ -324,6 +325,143 @@ bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
return true;
}

/// Given two VL operands, returns the one known to be the smallest or nullptr
/// if unknown.
static const MachineOperand *getKnownMinVL(const MachineOperand *LHS,
const MachineOperand *RHS) {
if (LHS->isReg() && RHS->isReg() && LHS->getReg().isVirtual() &&
LHS->getReg() == RHS->getReg())
return LHS;
if (LHS->isImm() && LHS->getImm() == RISCV::VLMaxSentinel)
return RHS;
if (RHS->isImm() && RHS->getImm() == RISCV::VLMaxSentinel)
return LHS;
if (!LHS->isImm() || !RHS->isImm())
return nullptr;
return LHS->getImm() <= RHS->getImm() ? LHS : RHS;
}

/// Check if it's safe to move From down to To, checking that no physical
/// registers are clobbered.
static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
assert(From.getParent() == To.getParent() && !From.hasImplicitDef());
SmallVector<Register> PhysUses;
for (const MachineOperand &MO : From.all_uses())
if (MO.getReg().isPhysical())
PhysUses.push_back(MO.getReg());
bool SawStore = false;
for (auto II = From.getIterator(); II != To.getIterator(); II++) {
for (Register PhysReg : PhysUses)
if (II->definesRegister(PhysReg, nullptr))
return false;
if (II->mayStore()) {
SawStore = true;
break;
}
}
return From.isSafeToMove(SawStore);
}

static unsigned getSEWLMULRatio(const MachineInstr &MI) {
RISCVII::VLMUL LMUL = RISCVII::getLMul(MI.getDesc().TSFlags);
unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
return RISCVVType::getSEWLMULRatio(1 << Log2SEW, LMUL);
}

/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
/// into it.
///
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
///
/// ->
///
/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, min(vl1, vl2), sew, policy
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
return false;

MachineOperand &Passthru = MI.getOperand(1);

if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
return false;

MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
if (!Src || Src->hasUnmodeledSideEffects() ||
Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
!RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) ||
!RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
!RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags))
return false;

// Src needs to have the same VLMAX as MI
if (getSEWLMULRatio(MI) != getSEWLMULRatio(*Src))
return false;

// Src needs to have the same passthru as VMV_V_V
MachineOperand &SrcPassthru = Src->getOperand(1);
if (SrcPassthru.getReg() != RISCV::NoRegister &&
SrcPassthru.getReg() != Passthru.getReg())
return false;

// Because Src and MI have the same passthru, we can use either AVL as long as
// it's the smaller of the two.
//
// (src pt, ..., vl=5) x x x x x|. . .
// (vmv.v.v pt, src, vl=3) x x x|. . . . .
// ->
// (src pt, ..., vl=3) x x x|. . . . .
//
// (src pt, ..., vl=3) x x x|. . . . .
// (vmv.v.v pt, src, vl=6) x x x . . .|. .
// ->
// (src pt, ..., vl=3) x x x|. . . . .
MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
const MachineOperand *MinVL = getKnownMinVL(&MI.getOperand(3), &SrcVL);
if (!MinVL)
return false;

bool VLChanged = !MinVL->isIdenticalTo(SrcVL);
bool ActiveElementsAffectResult = RISCVII::activeElementsAffectResult(
TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);

if (VLChanged && (ActiveElementsAffectResult || Src->mayRaiseFPException()))
return false;

// If Src ends up using MI's passthru/VL, move it so it can access it.
// TODO: We don't need to do this if they already dominate Src.
if (!SrcVL.isIdenticalTo(*MinVL) || !SrcPassthru.isIdenticalTo(Passthru)) {
if (!isSafeToMove(*Src, MI))
return false;
Src->moveBefore(&MI);
}

if (SrcPassthru.getReg() != Passthru.getReg()) {
SrcPassthru.setReg(Passthru.getReg());
// If Src is masked then its passthru needs to be in VRNoV0.
if (Passthru.getReg() != RISCV::NoRegister)
MRI->constrainRegClass(Passthru.getReg(),
TII->getRegClass(Src->getDesc(), 1, TRI,
*Src->getParent()->getParent()));
}

if (MinVL->isImm())
Collaborator:
Reading through this code, I'm left with a question. If the Src instruction uses a different SEW than the vmv.v.v, why is it legal to reduce the VL without accounting for the different size of the elements? I can't find that check in the DAG version of this code either. Am I forgetting something here?

Contributor (author):
Yikes, good catch. It looks like we're miscompiling this in the DAG version too, and I don't think it's legal to fold in the mask either. Fix incoming.

Contributor (author):
This should be fixed now in this PR, but by checking that the VLMAXes are the same, since we don't have access to the MVTs here.
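A note on why comparing SEW/LMUL ratios is enough here (a reasoning sketch, not text from the patch): VLMAX = VLEN * LMUL / SEW, so for a fixed VLEN, two pseudos with the same SEW/LMUL ratio have the same VLMAX even if their individual SEW and LMUL differ. For example, with VLEN = 128, SEW=32/LMUL=1 and SEW=64/LMUL=2 both give VLMAX = 4 (ratio 32), whereas SEW=64/LMUL=1 gives VLMAX = 2 (ratio 64), so folding a VL between instructions with different ratios could change how many elements are actually written.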

SrcVL.ChangeToImmediate(MinVL->getImm());
else if (MinVL->isReg())
SrcVL.ChangeToRegister(MinVL->getReg(), false);

// Use a conservative tu,mu policy, RISCVInsertVSETVLI will relax it if
// passthru is undef.
Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()))
.setImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED);

MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
MI.eraseFromParent();
V0Defs.erase(&MI);

return true;
}

bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
if (skipFunction(MF.getFunction()))
return false;
@@ -358,11 +496,12 @@ bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
}

for (MachineBasicBlock &MBB : MF) {
for (MachineInstr &MI : MBB) {
for (MachineInstr &MI : make_early_inc_range(MBB)) {
Changed |= convertToVLMAX(MI);
Changed |= convertToUnmasked(MI);
Changed |= convertToWholeRegister(MI);
Changed |= convertVMergeToVMv(MI);
Changed |= foldVMV_V_V(MI);
}
}
