Skip to content

Commit

Permalink
[DAG][RISCV] Use vp_reduce_* when widening illegal types for reductio…
Browse files Browse the repository at this point in the history
…ns (#105455)

This allows the use a single wider operation with a restricted EVL
instead of padding the vector with the neutral element.

For RISCV specifically, it's worth noting that an alternate padded
lowering is available when VL is one less than a power of two, and LMUL
<= m1. We could slide the vector operand up by one, and insert the
padding via a vslide1up. We don't currently pattern match this, but we
could. This form would arguably be better iff the surrounding code
wanted VL=4. This patch will force a VL toggle in that case instead.

Basically, it comes down to a question of whether we think odd sized
vectors are going to appear clustered with odd size vector operations,
or mixed in with larger power of two operations.

Note there is a potential downside of using vp nodes; we loose any
generic DAG combines which might have applied to the widened form.
  • Loading branch information
preames authored Aug 22, 2024
1 parent 41dcdfb commit 00baa1a
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 202 deletions.
29 changes: 15 additions & 14 deletions llvm/include/llvm/IR/VPIntrinsics.def
Original file line number Diff line number Diff line change
Expand Up @@ -651,63 +651,64 @@ END_REGISTER_VP(vp_gather, VP_GATHER)
#error \
"The internal helper macro HELPER_REGISTER_REDUCTION_VP is already defined!"
#endif
#define HELPER_REGISTER_REDUCTION_VP(VPID, VPSD, INTRIN) \
#define HELPER_REGISTER_REDUCTION_VP(VPID, VPSD, INTRIN, SDOPC) \
BEGIN_REGISTER_VP(VPID, 2, 3, VPSD, 1) \
VP_PROPERTY_FUNCTIONAL_INTRINSIC(INTRIN) \
VP_PROPERTY_FUNCTIONAL_SDOPC(SDOPC) \
VP_PROPERTY_REDUCTION(0, 1) \
END_REGISTER_VP(VPID, VPSD)

// llvm.vp.reduce.add(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_add, VP_REDUCE_ADD,
vector_reduce_add)
vector_reduce_add, VECREDUCE_ADD)

// llvm.vp.reduce.mul(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_mul, VP_REDUCE_MUL,
vector_reduce_mul)
vector_reduce_mul, VECREDUCE_MUL)

// llvm.vp.reduce.and(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_and, VP_REDUCE_AND,
vector_reduce_and)
vector_reduce_and, VECREDUCE_AND)

// llvm.vp.reduce.or(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_or, VP_REDUCE_OR,
vector_reduce_or)
vector_reduce_or, VECREDUCE_OR)

// llvm.vp.reduce.xor(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_xor, VP_REDUCE_XOR,
vector_reduce_xor)
vector_reduce_xor, VECREDUCE_XOR)

// llvm.vp.reduce.smax(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_smax, VP_REDUCE_SMAX,
vector_reduce_smax)
vector_reduce_smax, VECREDUCE_SMAX)

// llvm.vp.reduce.smin(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_smin, VP_REDUCE_SMIN,
vector_reduce_smin)
vector_reduce_smin, VECREDUCE_SMIN)

// llvm.vp.reduce.umax(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_umax, VP_REDUCE_UMAX,
vector_reduce_umax)
vector_reduce_umax, VECREDUCE_UMAX)

// llvm.vp.reduce.umin(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_umin, VP_REDUCE_UMIN,
vector_reduce_umin)
vector_reduce_umin, VECREDUCE_UMIN)

// llvm.vp.reduce.fmax(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_fmax, VP_REDUCE_FMAX,
vector_reduce_fmax)
vector_reduce_fmax, VECREDUCE_FMAX)

// llvm.vp.reduce.fmin(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_fmin, VP_REDUCE_FMIN,
vector_reduce_fmin)
vector_reduce_fmin, VECREDUCE_FMIN)

// llvm.vp.reduce.fmaximum(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_fmaximum, VP_REDUCE_FMAXIMUM,
vector_reduce_fmaximum)
vector_reduce_fmaximum, VECREDUCE_FMAXIMUM)

// llvm.vp.reduce.fminimum(start,x,mask,vlen)
HELPER_REGISTER_REDUCTION_VP(vp_reduce_fminimum, VP_REDUCE_FMINIMUM,
vector_reduce_fminimum)
vector_reduce_fminimum, VECREDUCE_FMINIMUM)

#undef HELPER_REGISTER_REDUCTION_VP

Expand Down
43 changes: 41 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7271,9 +7271,29 @@ SDValue DAGTypeLegalizer::WidenVecOp_STRICT_FSETCC(SDNode *N) {
return DAG.getBuildVector(VT, dl, Scalars);
}

static unsigned getExtendForIntVecReduction(unsigned Opc) {
switch (Opc) {
default:
llvm_unreachable("Expected integer vector reduction");
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_MUL:
case ISD::VECREDUCE_AND:
case ISD::VECREDUCE_OR:
case ISD::VECREDUCE_XOR:
return ISD::ANY_EXTEND;
case ISD::VECREDUCE_SMAX:
case ISD::VECREDUCE_SMIN:
return ISD::SIGN_EXTEND;
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN:
return ISD::ZERO_EXTEND;
}
}

SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
SDLoc dl(N);
SDValue Op = GetWidenedVector(N->getOperand(0));
EVT VT = N->getValueType(0);
EVT OrigVT = N->getOperand(0).getValueType();
EVT WideVT = Op.getValueType();
EVT ElemVT = OrigVT.getVectorElementType();
Expand All @@ -7288,6 +7308,25 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
unsigned OrigElts = OrigVT.getVectorMinNumElements();
unsigned WideElts = WideVT.getVectorMinNumElements();

// Generate a vp.reduce_op if it is custom/legal for the target. This avoids
// needing to pad the source vector, because the inactive lanes can simply be
// disabled and not contribute to the result.
// TODO: VECREDUCE_FADD, VECREDUCE_FMUL aren't currently mapped correctly,
// and thus don't take this path.
if (auto VPOpcode = ISD::getVPForBaseOpcode(Opc);
VPOpcode && TLI.isOperationLegalOrCustom(*VPOpcode, WideVT)) {
SDValue Start = NeutralElem;
if (VT.isInteger())
Start = DAG.getNode(getExtendForIntVecReduction(Opc), dl, VT, Start);
assert(Start.getValueType() == VT);
EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
WideVT.getVectorElementCount());
SDValue Mask = DAG.getAllOnesConstant(dl, WideMaskVT);
SDValue EVL = DAG.getElementCount(dl, TLI.getVPExplicitVectorLengthTy(),
OrigVT.getVectorElementCount());
return DAG.getNode(*VPOpcode, dl, VT, {Start, Op, Mask, EVL}, Flags);
}

if (WideVT.isScalableVector()) {
unsigned GCD = std::gcd(OrigElts, WideElts);
EVT SplatVT = EVT::getVectorVT(*DAG.getContext(), ElemVT,
Expand All @@ -7296,14 +7335,14 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
for (unsigned Idx = OrigElts; Idx < WideElts; Idx = Idx + GCD)
Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVT, Op, SplatNeutral,
DAG.getVectorIdxConstant(Idx, dl));
return DAG.getNode(Opc, dl, N->getValueType(0), Op, Flags);
return DAG.getNode(Opc, dl, VT, Op, Flags);
}

for (unsigned Idx = OrigElts; Idx < WideElts; Idx++)
Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
DAG.getVectorIdxConstant(Idx, dl));

return DAG.getNode(Opc, dl, N->getValueType(0), Op, Flags);
return DAG.getNode(Opc, dl, VT, Op, Flags);
}

SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
Expand Down
Loading

0 comments on commit 00baa1a

Please sign in to comment.