[RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension. #76785
@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

Changes: This recreates #72340, which was reverted by 4e347b4.

Patch is 50.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76785.diff

3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27bb69dc9868c8..2fb79c81b7f169 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,8 +1374,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
- ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
- ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+ ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
+ ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
if (Subtarget.is64Bit())
setTargetDAGCombine(ISD::SRA);
@@ -12850,9 +12850,9 @@ struct CombineResult;
/// Helper class for folding sign/zero extensions.
/// In particular, this class is used for the following combines:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
///
/// An object of this class represents an operand of the operation we want to
/// combine.
@@ -12897,6 +12897,8 @@ struct NodeExtensionHelper {
/// E.g., for zext(a), this would return a.
SDValue getSource() const {
switch (OrigOperand.getOpcode()) {
+ case ISD::ZERO_EXTEND:
+ case ISD::SIGN_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
return OrigOperand.getOperand(0);
@@ -12913,7 +12915,8 @@ struct NodeExtensionHelper {
/// Get or create a value that can feed \p Root with the given extension \p
/// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
/// \see ::getSource().
- SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
+ SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget,
std::optional<bool> SExt) const {
if (!SExt.has_value())
return OrigOperand;
@@ -12928,8 +12931,10 @@ struct NodeExtensionHelper {
// If we need an extension, we should be changing the type.
SDLoc DL(Root);
- auto [Mask, VL] = getMaskAndVL(Root);
+ auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
switch (OrigOperand.getOpcode()) {
+ case ISD::ZERO_EXTEND:
+ case ISD::SIGN_EXTEND:
case RISCVISD::VSEXT_VL:
case RISCVISD::VZEXT_VL:
return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
@@ -12969,12 +12974,15 @@ struct NodeExtensionHelper {
/// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
switch (Opcode) {
+ case ISD::ADD:
case RISCVISD::ADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+ case ISD::MUL:
case RISCVISD::MUL_VL:
return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+ case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
@@ -12987,7 +12995,8 @@ struct NodeExtensionHelper {
/// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
/// newOpcode(a, b).
static unsigned getSUOpcode(unsigned Opcode) {
- assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL");
+ assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
+ "SU is only supported for MUL");
return RISCVISD::VWMULSU_VL;
}
@@ -12995,8 +13004,10 @@ struct NodeExtensionHelper {
/// newOpcode(a, b).
static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
switch (Opcode) {
+ case ISD::ADD:
case RISCVISD::ADD_VL:
return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+ case ISD::SUB:
case RISCVISD::SUB_VL:
return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
default:
@@ -13006,19 +13017,44 @@ struct NodeExtensionHelper {
using CombineToTry = std::function<std::optional<CombineResult>(
SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
- const NodeExtensionHelper & /*RHS*/)>;
+ const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
+ const RISCVSubtarget &)>;
/// Check if this node needs to be fully folded or extended for all users.
bool needToPromoteOtherUsers() const { return EnforceOneUse; }
/// Helper method to set the various fields of this struct based on the
/// type of \p Root.
- void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) {
+ void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
SupportsZExt = false;
SupportsSExt = false;
EnforceOneUse = true;
CheckMask = true;
- switch (OrigOperand.getOpcode()) {
+ unsigned Opc = OrigOperand.getOpcode();
+ switch (Opc) {
+ case ISD::ZERO_EXTEND:
+ case ISD::SIGN_EXTEND: {
+ MVT VT = OrigOperand.getSimpleValueType();
+ if (!VT.isVector())
+ break;
+
+ MVT NarrowVT = OrigOperand.getOperand(0)->getSimpleValueType(0);
+
+ unsigned ScalarBits = VT.getScalarSizeInBits();
+ unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits();
+
+ // Ensure the extension's semantic is equivalent to rvv vzext or vsext.
+ if (ScalarBits != NarrowScalarBits * 2)
+ break;
+
+ SupportsZExt = Opc == ISD::ZERO_EXTEND;
+ SupportsSExt = Opc == ISD::SIGN_EXTEND;
+
+ SDLoc DL(Root);
+ std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+ break;
+ }
case RISCVISD::VZEXT_VL:
SupportsZExt = true;
Mask = OrigOperand.getOperand(1);
@@ -13074,8 +13110,16 @@ struct NodeExtensionHelper {
}
/// Check if \p Root supports any extension folding combines.
- static bool isSupportedRoot(const SDNode *Root) {
+ static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
switch (Root->getOpcode()) {
+ case ISD::ADD:
+ case ISD::SUB:
+ case ISD::MUL: {
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ if (!TLI.isTypeLegal(Root->getValueType(0)))
+ return false;
+ return Root->getValueType(0).isScalableVector();
+ }
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
@@ -13090,9 +13134,10 @@ struct NodeExtensionHelper {
}
/// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
- NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) {
- assert(isSupportedRoot(Root) && "Trying to build an helper with an "
- "unsupported root");
+ NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
+ "unsupported root");
assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
OrigOperand = Root->getOperand(OperandIdx);
@@ -13108,7 +13153,7 @@ struct NodeExtensionHelper {
SupportsZExt =
Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
SupportsSExt = !SupportsZExt;
- std::tie(Mask, VL) = getMaskAndVL(Root);
+ std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
CheckMask = true;
// There's no existing extension here, so we don't have to worry about
// making sure it gets removed.
@@ -13117,7 +13162,7 @@ struct NodeExtensionHelper {
}
[[fallthrough]];
default:
- fillUpExtensionSupport(Root, DAG);
+ fillUpExtensionSupport(Root, DAG, Subtarget);
break;
}
}
@@ -13133,14 +13178,27 @@ struct NodeExtensionHelper {
}
/// Helper function to get the Mask and VL from \p Root.
- static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) {
- assert(isSupportedRoot(Root) && "Unexpected root");
- return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+ static std::pair<SDValue, SDValue>
+ getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ assert(isSupportedRoot(Root, DAG) && "Unexpected root");
+ switch (Root->getOpcode()) {
+ case ISD::ADD:
+ case ISD::SUB:
+ case ISD::MUL: {
+ SDLoc DL(Root);
+ MVT VT = Root->getSimpleValueType(0);
+ return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+ }
+ default:
+ return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+ }
}
/// Check if the Mask and VL of this operand are compatible with \p Root.
- bool areVLAndMaskCompatible(const SDNode *Root) const {
- auto [Mask, VL] = getMaskAndVL(Root);
+ bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) const {
+ auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
return isMaskCompatible(Mask) && isVLCompatible(VL);
}
@@ -13148,11 +13206,14 @@ struct NodeExtensionHelper {
/// foldings that are supported by this class.
static bool isCommutative(const SDNode *N) {
switch (N->getOpcode()) {
+ case ISD::ADD:
+ case ISD::MUL:
case RISCVISD::ADD_VL:
case RISCVISD::MUL_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
return true;
+ case ISD::SUB:
case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
@@ -13197,14 +13258,25 @@ struct CombineResult {
/// Return a value that uses TargetOpcode and that can be used to replace
/// Root.
/// The actual replacement is *not* done in that method.
- SDValue materialize(SelectionDAG &DAG) const {
+ SDValue materialize(SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) const {
SDValue Mask, VL, Merge;
- std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
- Merge = Root->getOperand(2);
+ std::tie(Mask, VL) =
+ NodeExtensionHelper::getMaskAndVL(Root, DAG, Subtarget);
+ switch (Root->getOpcode()) {
+ default:
+ Merge = Root->getOperand(2);
+ break;
+ case ISD::ADD:
+ case ISD::SUB:
+ case ISD::MUL:
+ Merge = DAG.getUNDEF(Root->getValueType(0));
+ break;
+ }
return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
- LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
- RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
- Mask, VL);
+ LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
+ RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+ Merge, Mask, VL);
}
};
@@ -13221,15 +13293,16 @@ struct CombineResult {
static std::optional<CombineResult>
canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
const NodeExtensionHelper &RHS, bool AllowSExt,
- bool AllowZExt) {
+ bool AllowZExt, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
- if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+ if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+ !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), /*IsSExt=*/false),
- Root, LHS, /*SExtLHS=*/false, RHS,
- /*SExtRHS=*/false);
+ Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
Root->getOpcode(), /*IsSExt=*/true),
@@ -13246,9 +13319,10 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS) {
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/true);
+ /*AllowZExt=*/true, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13257,8 +13331,9 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS) {
- if (!RHS.areVLAndMaskCompatible(Root))
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
// FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
@@ -13282,9 +13357,10 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS) {
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
- /*AllowZExt=*/false);
+ /*AllowZExt=*/false, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13293,9 +13369,10 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS) {
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
- /*AllowZExt=*/true);
+ /*AllowZExt=*/true, DAG, Subtarget);
}
/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13304,10 +13381,13 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
/// can be used to apply the pattern.
static std::optional<CombineResult>
canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
- const NodeExtensionHelper &RHS) {
+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+
if (!LHS.SupportsSExt || !RHS.SupportsZExt)
return std::nullopt;
- if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+ if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+ !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
return std::nullopt;
return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
@@ -13317,6 +13397,8 @@ SmallVector<NodeExtensionHelper::CombineToTry>
NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
SmallVector<CombineToTry> Strategies;
switch (Root->getOpcode()) {
+ case ISD::ADD:
+ case ISD::SUB:
case RISCVISD::ADD_VL:
case RISCVISD::SUB_VL:
// add|sub -> vwadd(u)|vwsub(u)
@@ -13324,6 +13406,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
// add|sub -> vwadd(u)_w|vwsub(u)_w
Strategies.push_back(canFoldToVW_W);
break;
+ case ISD::MUL:
case RISCVISD::MUL_VL:
// mul -> vwmul(u)
Strategies.push_back(canFoldToVWWithSameExtension);
@@ -13354,12 +13437,14 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// mul_vl -> vwmul(u) | vwmul_su
/// vwadd_w(u) -> vwadd(u)
/// vwsub_w(u) -> vwsub(u)
-static SDValue
-combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
- assert(NodeExtensionHelper::isSupportedRoot(N) &&
- "Shouldn't have called this method");
+ if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
+ return SDValue();
+
SmallVector<SDNode *> Worklist;
SmallSet<SDNode *, 8> Inserted;
Worklist.push_back(N);
@@ -13368,11 +13453,11 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
while (!Worklist.empty()) {
SDNode *Root = Worklist.pop_back_val();
- if (!NodeExtensionHelper::isSupportedRoot(Root))
+ if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
return SDValue();
- NodeExtensionHelper LHS(N, 0, DAG);
- NodeExtensionHelper RHS(N, 1, DAG);
+ NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
+ NodeExtensionHelper RHS(N, 1, DAG, Subtarget);
auto AppendUsersIfNeeded = [&Worklist,
&Inserted](const NodeExtensionHelper &Op) {
if (Op.needToPromoteOtherUsers()) {
@@ -13399,7 +13484,8 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
for (NodeExtensionHelper::CombineToTry FoldingStrategy :
FoldingStrategies) {
- std::optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+ std::optional<CombineResult> Res =
+ FoldingStrategy(N, LHS, RHS, DAG, Subtarget);
if (Res) {
Matched = true;
CombinesToApply.push_back(*Res);
@@ -13428,7 +13514,7 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
ValuesToReplace.reserve(CombinesToApply.size());
for (CombineResult Res : CombinesToApply) {
- SDValue NewValue = Res.materialize(DAG);
+ SDValue NewValue = Res.materialize(DAG, Subtarget);
if (!InputRootReplacement) {
assert(Res.Root == N &&
"First element is expected to be the current node");
@@ -14700,13 +14786,20 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
- assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+ assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
+
+ if (N->getValueType(0).isFixedLengthVector())
+ return SDValue();
+
SDValue Addend = N->getOperand(0);
SDValue MulOp = N->getOperand(1);
- SDValue AddMergeOp = N->getOperand(2);
- if (!AddMergeOp.isUndef())
- return SDValue();
+ if (N->getOpcode() == RISCVISD::ADD_VL) {
+ SDValue AddMergeOp = N->getOperand(2);
+ if (!AddMergeOp.isUndef())
+ return SDValue();
+ }
auto IsVWMulOpc = [](unsigned Opc) {
switch (Opc) {
@@ -14730,8 +14823,16 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
if (!MulMergeOp.isUndef())
return SDValue();
- SDValue AddMask = N->getOperand(3);
- SDValue AddVL = N->getOperand(4);
+ auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
+ if (N->getOpcode() == ISD::ADD) {
+ SDLoc DL(N);
+ return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
+ Subtarget);
+ }
+ re...
[truncated]
Sorry for the ping.
LGTM
The title should describe the whole patch, not just the bug fix. It will become the commit title when it is recommitted.
Force-pushed from 20f10af to fde129a
…/sign extension." (llvm#76785)

This patch was originally introduced in PR llvm#72340, but was reverted due to a bug that produced an invalid extension combine. Specifically, we resolve the case in llvm#72340 (comment):

```
define <vscale x 1 x i32> @foo(<vscale x 1 x i1> %x, <vscale x 1 x i2> %y) {
  %a = zext <vscale x 1 x i1> %x to <vscale x 1 x i32>
  %b = zext <vscale x 1 x i2> %y to <vscale x 1 x i32>
  %c = add <vscale x 1 x i32> %a, %b
  ret <vscale x 1 x i32> %c
}
```

The previous patch didn't check whether the semantics of `ISD::SIGN_EXTEND` and `ISD::ZERO_EXTEND` are equivalent to `vsext.vf2` or `vzext.vf2` (it did not ensure the SEW condition on widening vector arithmetic instructions). Thanks to @topperc for pointing out this bug.

## The original description

This PR mainly aims at resolving the missed-optimization case below, while it could also be considered an extension of the previous patch https://reviews.llvm.org/D133739?id=

### Missed-Optimization Case

Compiler Explorer: https://godbolt.org/z/GzWzP7Pfh

### Source Code:

```
define <vscale x 2 x i16> @multiple_users(ptr %x, ptr %y, ptr %z) {
  %a = load <vscale x 2 x i8>, ptr %x
  %b = load <vscale x 2 x i8>, ptr %y
  %b2 = load <vscale x 2 x i8>, ptr %z
  %c = sext <vscale x 2 x i8> %a to <vscale x 2 x i16>
  %d = sext <vscale x 2 x i8> %b to <vscale x 2 x i16>
  %d2 = sext <vscale x 2 x i8> %b2 to <vscale x 2 x i16>
  %e = mul <vscale x 2 x i16> %c, %d
  %f = add <vscale x 2 x i16> %c, %d2
  %g = sub <vscale x 2 x i16> %c, %d2
  %h = or <vscale x 2 x i16> %e, %f
  %i = or <vscale x 2 x i16> %h, %g
  ret <vscale x 2 x i16> %i
}
```

### Before This Patch

```
# %bb.0:
  vsetvli a3, zero, e16, mf2, ta, ma
  vle8.v v8, (a0)
  vle8.v v9, (a1)
  vle8.v v10, (a2)
  vsext.vf2 v11, v8
  vsext.vf2 v8, v9
  vsext.vf2 v9, v10
  vmul.vv v8, v11, v8
  vadd.vv v10, v11, v9
  vsub.vv v9, v11, v9
  vor.vv v8, v8, v10
  vor.vv v8, v8, v9
  ret
```

### After This Patch

```
# %bb.0:
  vsetvli a3, zero, e8, mf4, ta, ma
  vle8.v v8, (a0)
  vle8.v v9, (a1)
  vle8.v v10, (a2)
  vwmul.vv v11, v8, v9
  vwadd.vv v9, v8, v10
  vwsub.vv v12, v8, v10
  vsetvli zero, zero, e16, mf2, ta, ma
  vor.vv v8, v11, v9
  vor.vv v8, v8, v12
  ret
```

We can see Add/Sub/Mul are combined with the sign extension.

### Relation to the Patch D133739

The patch D133739 introduced an optimization for folding `ADD_VL`/`SUB_VL`/`MUL_VL` with `VSEXT_VL`/`VZEXT_VL`. However, that patch did not consider the non-fixed-length vector case, thus this PR could also be considered an extension of D133739.
…80079) Similar to #78403, but for scalable `vwadd(u).wv`, given that #76785 is recommitted.

### Code

```
define <vscale x 8 x i64> @vwadd_wv_mask_v8i32(<vscale x 8 x i32> %x, <vscale x 8 x i64> %y) {
  %mask = icmp slt <vscale x 8 x i32> %x, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 42, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
  %a = select <vscale x 8 x i1> %mask, <vscale x 8 x i32> %x, <vscale x 8 x i32> zeroinitializer
  %sa = sext <vscale x 8 x i32> %a to <vscale x 8 x i64>
  %ret = add <vscale x 8 x i64> %sa, %y
  ret <vscale x 8 x i64> %ret
}
```

### Before this patch

[Compiler Explorer](https://godbolt.org/z/xsoa5xPrd)

```
vwadd_wv_mask_v8i32:
  li a0, 42
  vsetvli a1, zero, e32, m4, ta, ma
  vmslt.vx v0, v8, a0
  vmv.v.i v12, 0
  vmerge.vvm v24, v12, v8, v0
  vwadd.wv v8, v16, v24
  ret
```

### After this patch

```
vwadd_wv_mask_v8i32:
  li a0, 42
  vsetvli a1, zero, e32, m4, ta, ma
  vmslt.vx v0, v8, a0
  vsetvli zero, zero, e32, m4, tu, mu
  vwadd.wv v16, v16, v8, v0.t
  vmv8r.v v8, v16
  ret
```
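In other words, the fold rewrites `add(sext(select(mask, x, 0)), y)` as a masked `vwadd.wv` whose inactive lanes simply keep `y`. A scalar model of one lane (plain C++, not RVV intrinsics; the function name is invented for illustration):

```cpp
#include <cstdint>

// One lane of the pattern above: select + sext + add. When the mask is
// false the lane contributes 0, so the result is just y, which is exactly
// the behavior of a masked vwadd.wv whose inactive lanes keep the merge
// value.
int64_t vwadd_wv_mask_lane(bool mask, int32_t x, int64_t y) {
  int64_t sa = static_cast<int64_t>(mask ? x : 0); // select, then sign-extend
  return sa + y;                                   // mask ? (int64_t)x + y : y
}
```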
Extend D133739 and #76785 to support vector widening floating-point add/sub/mul instructions. Specifically, this patch works for the below optimization case:

### Source code

```
define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
  %c = fpext <2 x float> %a to <2 x double>
  %d = fpext <2 x float> %b to <2 x double>
  %d2 = fpext <2 x float> %b2 to <2 x double>
  %e = fmul <2 x double> %c, %d
  %f = fadd <2 x double> %c, %d2
  %g = fsub <2 x double> %d, %d2
  store <2 x double> %e, ptr %x
  store <2 x double> %f, ptr %y
  store <2 x double> %g, ptr %z
  ret void
}
```

### Before this patch

[Compiler Explorer](https://godbolt.org/z/aaEMs5s9h)

```
vfwmul_v2f32_multiple_users:
  vsetivli zero, 2, e32, mf2, ta, ma
  vfwcvt.f.f.v v11, v8
  vfwcvt.f.f.v v8, v9
  vfwcvt.f.f.v v9, v10
  vsetvli zero, zero, e64, m1, ta, ma
  vfmul.vv v10, v11, v8
  vfadd.vv v11, v11, v9
  vfsub.vv v8, v8, v9
  vse64.v v10, (a0)
  vse64.v v11, (a1)
  vse64.v v8, (a2)
  ret
```

### After this patch

```
vfwmul_v2f32_multiple_users:
  vsetivli zero, 2, e32, mf2, ta, ma
  vfwmul.vv v11, v8, v9
  vfwadd.vv v12, v8, v10
  vfwsub.vv v8, v9, v10
  vse64.v v11, (a0)
  vse64.v v12, (a1)
  vse64.v v8, (a2)
```
…pes. NFCI (#84125)

I noticed this from a discrepancy in fillUpExtensionSupport between how we apparently need to check for legal types for ISD::{ZERO,SIGN}_EXTEND, but we don't need to for RISCVISD::V{Z,S}EXT_VL.

Prior to #72340, combineBinOp_VLToVWBinOp_VL only ran after type legalization because it only operated on _VL nodes. _VL nodes are only emitted during op legalization, which takes place **after** type legalization, which is presumably why the existing code didn't need to check for legal types.

After #72340 we now handle generic ops like ISD::ADD that exist before op legalization and thus **before** type legalization. This meant that we needed to add extra checks that the narrow type was legal in #76785.

I think the easiest thing to do here is to just maintain the invariant that the types are legal and only run the combine after type legalization.
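A sketch of that invariant in code, not the verbatim #84125 change (it assumes only that the combine is driven through `TargetLowering::DAGCombinerInfo`, whose real `isBeforeLegalize()` accessor reports whether type legalization has run; the function body is abbreviated):

```cpp
// Run the combine only once type legalization is done, so every type the
// NodeExtensionHelper inspects is already legal and the per-opcode
// isTypeLegal checks become unnecessary.
static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
                                           TargetLowering::DAGCombinerInfo &DCI,
                                           const RISCVSubtarget &Subtarget) {
  // isBeforeLegalize() is true until the type-legalization combine phase;
  // returning an empty SDValue tells DAGCombiner nothing was folded.
  if (DCI.isBeforeLegalize())
    return SDValue();
  // ... worklist-driven folding as in the diff above ...
  return SDValue();
}
```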
This patch was originally introduced in PR #72340, but was reverted due to a bug that produced an invalid extension combine.
Specifically, we resolve the case in #72340 (comment):
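```
define <vscale x 1 x i32> @foo(<vscale x 1 x i1> %x, <vscale x 1 x i2> %y) {
  %a = zext <vscale x 1 x i1> %x to <vscale x 1 x i32>
  %b = zext <vscale x 1 x i2> %y to <vscale x 1 x i32>
  %c = add <vscale x 1 x i32> %a, %b
  ret <vscale x 1 x i32> %c
}
```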
The previous patch didn't check whether the semantics of `ISD::SIGN_EXTEND` and `ISD::ZERO_EXTEND` are equivalent to `vsext.vf2` or `vzext.vf2` (it did not ensure the SEW condition on widening Vector Arithmetic Instructions). Thanks to @topperc for pointing out this bug.
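The fix boils down to a width check: RVV's `vzext.vf2`/`vsext.vf2` only widen by exactly a factor of two, so the combine must reject extensions like `i1 -> i32`. A minimal standalone sketch (plain C++; `matchesVF2` is an invented name, not an LLVM API):

```cpp
#include <cassert>

// The recommitted patch only treats ISD::ZERO_EXTEND/ISD::SIGN_EXTEND as a
// vzext.vf2/vsext.vf2 when the destination element is exactly twice as wide
// as the source element (the SEW condition on widening instructions).
bool matchesVF2(unsigned DstScalarBits, unsigned SrcScalarBits) {
  return DstScalarBits == 2 * SrcScalarBits;
}

int main() {
  assert(matchesVF2(16, 8));  // i8 -> i16: eligible for vwadd/vwsub/vwmul
  assert(!matchesVF2(32, 1)); // i1 -> i32: the case that broke #72340
  return 0;
}
```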
The original description
This PR mainly aims at resolving the missed-optimization case below, while it could also be considered an extension of the previous patch https://reviews.llvm.org/D133739?id=
Missed-Optimization Case
Compiler Explorer: https://godbolt.org/z/GzWzP7Pfh
Source Code:
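```
define <vscale x 2 x i16> @multiple_users(ptr %x, ptr %y, ptr %z) {
  %a = load <vscale x 2 x i8>, ptr %x
  %b = load <vscale x 2 x i8>, ptr %y
  %b2 = load <vscale x 2 x i8>, ptr %z
  %c = sext <vscale x 2 x i8> %a to <vscale x 2 x i16>
  %d = sext <vscale x 2 x i8> %b to <vscale x 2 x i16>
  %d2 = sext <vscale x 2 x i8> %b2 to <vscale x 2 x i16>
  %e = mul <vscale x 2 x i16> %c, %d
  %f = add <vscale x 2 x i16> %c, %d2
  %g = sub <vscale x 2 x i16> %c, %d2
  %h = or <vscale x 2 x i16> %e, %f
  %i = or <vscale x 2 x i16> %h, %g
  ret <vscale x 2 x i16> %i
}
```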
Before This Patch
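```
# %bb.0:
  vsetvli a3, zero, e16, mf2, ta, ma
  vle8.v v8, (a0)
  vle8.v v9, (a1)
  vle8.v v10, (a2)
  vsext.vf2 v11, v8
  vsext.vf2 v8, v9
  vsext.vf2 v9, v10
  vmul.vv v8, v11, v8
  vadd.vv v10, v11, v9
  vsub.vv v9, v11, v9
  vor.vv v8, v8, v10
  vor.vv v8, v8, v9
  ret
```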
After This Patch
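```
# %bb.0:
  vsetvli a3, zero, e8, mf4, ta, ma
  vle8.v v8, (a0)
  vle8.v v9, (a1)
  vle8.v v10, (a2)
  vwmul.vv v11, v8, v9
  vwadd.vv v9, v8, v10
  vwsub.vv v12, v8, v10
  vsetvli zero, zero, e16, mf2, ta, ma
  vor.vv v8, v11, v9
  vor.vv v8, v8, v12
  ret
```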
We can see that Add/Sub/Mul are combined with the sign extension.
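For intuition on why the widening forms are always safe here, a scalar model of one `vwmul.vv` lane (plain C++, not RVV intrinsics; the function name is invented for illustration):

```cpp
#include <cstdint>

// One lane of vwmul.vv, modeled on scalars: sign-extend both i8 inputs to
// i16 and multiply. The product of two i8 values always fits in i16
// (e.g. -128 * -128 = 16384), so fusing the extension into the multiply
// loses no precision.
int16_t vwmul_lane(int8_t a, int8_t b) {
  return static_cast<int16_t>(static_cast<int16_t>(a) *
                              static_cast<int16_t>(b));
}
```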
Relation to the Patch D133739
The patch D133739 introduced an optimization for folding `ADD_VL`/`SUB_VL`/`MUL_VL` with `VSEXT_VL`/`VZEXT_VL`. However, that patch did not consider the non-fixed-length vector case, thus this PR could also be considered an extension of D133739.