Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension. #76785

Merged
merged 1 commit into from
Jan 18, 2024

Conversation

sun-jacobi
Copy link
Member

@sun-jacobi sun-jacobi commented Jan 3, 2024

This patch was originally introduced in PR #72340, but was reverted due to a bug on invalid extension combine.

Specifically, we resolve the case in the #72340 (comment)

define <vscale x 1 x i32> @foo(<vscale x 1 x i1> %x, <vscale x 1 x i2> %y) {     
  %a = zext <vscale x 1 x i1> %x to <vscale x 1 x i32>                           
  %b = zext <vscale x 1 x i1> %y to <vscale x 1 x i32>                           
  %c = add <vscale x 1 x i32> %a, %b                                             
  ret <vscale x 1 x i32> %c                                                      
}

The previous patch didn't check if the semantic of ISD::ZERO_EXTEND and ISD::ZERO_EXTEND is equivalent to the vsext.vf2 or vzext.vf2 (not ensuring the SEW condition on widening Vector Arithmetic Instructions).

Thanks for @topperc pointing out this bug.

The original description

This PR mainly aims at resolving the below missed-optimization case, while it could also be considered as an extension of the previous patch https://reviews.llvm.org/D133739?id=

Missed-Optimization Case

Compiler Explorer: https://godbolt.org/z/GzWzP7Pfh

Source Code:

define <vscale x 2 x i16> @multiple_users(ptr  %x, ptr  %y, ptr %z) {
  %a = load <vscale x 2 x i8>, ptr %x
  %b = load <vscale x 2 x i8>, ptr %y
  %b2 = load <vscale x 2 x i8>, ptr %z
  %c = sext <vscale x 2 x i8> %a to <vscale x 2 x i16>
  %d = sext <vscale x 2 x i8> %b to <vscale x 2 x i16>
  %d2 = sext <vscale x 2 x i8> %b2 to <vscale x 2 x i16>
  %e = mul <vscale x 2 x i16> %c, %d
  %f = add <vscale x 2 x i16> %c, %d2
  %g = sub <vscale x 2 x i16> %c, %d2
  %h = or <vscale x 2 x i16> %e, %f
  %i = or <vscale x 2 x i16> %h, %g
  ret <vscale x 2 x i16> %i
}

Before This Patch

# %bb.0:
        vsetvli a3, zero, e16, mf2, ta, ma
        vle8.v  v8, (a0)
        vle8.v  v9, (a1)
        vle8.v  v10, (a2)
        svf2       v11, v8
        vsext.vf2       v8, v9
        vsext.vf2       v9, v10
        vmul.vv v8, v11, v8
        vadd.vv v10, v11, v9
        vsub.vv v9, v11, v9
        vor.vv  v8, v8, v10
        vor.vv  v8, v8, v9
        ret

After This Patch

# %bb.0:
	vsetvli	a3, zero, e8, mf4, ta, ma
	vle8.v	v8, (a0)
	vle8.v	v9, (a1)
	vle8.v	v10, (a2)
	vwmul.vv	v11, v8, v9
	vwadd.vv	v9, v8, v10
	vwsub.vv	v12, v8, v10
	vsetvli	zero, zero, e16, mf2, ta, ma
	vor.vv	v8, v11, v9
	vor.vv	v8, v8, v12
	ret

We can see Add/Sub/Mul are combined with the Sign Extension.

Relation to the Patch D133739

The patch D133739 introduced an optimization for folding ADD_VL/ SUB_VL / MUL_V with VSEXT_VL / VZEXT_VL. However, the patch did not consider the case of non-fixed length vector case, thus this PR could also be considered as an extension for the D133739.

@llvmbot
Copy link
Member

llvmbot commented Jan 3, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

Changes

This recreates the #72340 reverted by 4e347b4.


Patch is 50.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76785.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+172-61)
  • (modified) llvm/test/CodeGen/RISCV/rvv/ctlz-sdnode.ll (+68-60)
  • (modified) llvm/test/CodeGen/RISCV/rvv/vscale-vw-web-simplification.ll (+402-34)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 27bb69dc9868c8..2fb79c81b7f169 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1374,8 +1374,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setPrefLoopAlignment(Subtarget.getPrefLoopAlignment());
 
   setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
-                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
-                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::MUL,
+                       ISD::AND, ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -12850,9 +12850,9 @@ struct CombineResult;
 
 /// Helper class for folding sign/zero extensions.
 /// In particular, this class is used for the following combines:
-/// add_vl -> vwadd(u) | vwadd(u)_w
-/// sub_vl -> vwsub(u) | vwsub(u)_w
-/// mul_vl -> vwmul(u) | vwmul_su
+/// add | add_vl -> vwadd(u) | vwadd(u)_w
+/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
+/// mul | mul_vl -> vwmul(u) | vwmul_su
 ///
 /// An object of this class represents an operand of the operation we want to
 /// combine.
@@ -12897,6 +12897,8 @@ struct NodeExtensionHelper {
   /// E.g., for zext(a), this would return a.
   SDValue getSource() const {
     switch (OrigOperand.getOpcode()) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return OrigOperand.getOperand(0);
@@ -12913,7 +12915,8 @@ struct NodeExtensionHelper {
   /// Get or create a value that can feed \p Root with the given extension \p
   /// SExt. If \p SExt is std::nullopt, this returns the source of this operand.
   /// \see ::getSource().
-  SDValue getOrCreateExtendedOp(const SDNode *Root, SelectionDAG &DAG,
+  SDValue getOrCreateExtendedOp(SDNode *Root, SelectionDAG &DAG,
+                                const RISCVSubtarget &Subtarget,
                                 std::optional<bool> SExt) const {
     if (!SExt.has_value())
       return OrigOperand;
@@ -12928,8 +12931,10 @@ struct NodeExtensionHelper {
 
     // If we need an extension, we should be changing the type.
     SDLoc DL(Root);
-    auto [Mask, VL] = getMaskAndVL(Root);
+    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
     switch (OrigOperand.getOpcode()) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND:
     case RISCVISD::VSEXT_VL:
     case RISCVISD::VZEXT_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
@@ -12969,12 +12974,15 @@ struct NodeExtensionHelper {
   /// \pre \p Opcode represents a supported root (\see ::isSupportedRoot()).
   static unsigned getSameExtensionOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
+    case ISD::ADD:
     case RISCVISD::ADD_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return IsSExt ? RISCVISD::VWADD_VL : RISCVISD::VWADDU_VL;
+    case ISD::MUL:
     case RISCVISD::MUL_VL:
       return IsSExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -12987,7 +12995,8 @@ struct NodeExtensionHelper {
   /// Get the opcode to materialize \p Opcode(sext(a), zext(b)) ->
   /// newOpcode(a, b).
   static unsigned getSUOpcode(unsigned Opcode) {
-    assert(Opcode == RISCVISD::MUL_VL && "SU is only supported for MUL");
+    assert((Opcode == RISCVISD::MUL_VL || Opcode == ISD::MUL) &&
+           "SU is only supported for MUL");
     return RISCVISD::VWMULSU_VL;
   }
 
@@ -12995,8 +13004,10 @@ struct NodeExtensionHelper {
   /// newOpcode(a, b).
   static unsigned getWOpcode(unsigned Opcode, bool IsSExt) {
     switch (Opcode) {
+    case ISD::ADD:
     case RISCVISD::ADD_VL:
       return IsSExt ? RISCVISD::VWADD_W_VL : RISCVISD::VWADDU_W_VL;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
       return IsSExt ? RISCVISD::VWSUB_W_VL : RISCVISD::VWSUBU_W_VL;
     default:
@@ -13006,19 +13017,44 @@ struct NodeExtensionHelper {
 
   using CombineToTry = std::function<std::optional<CombineResult>(
       SDNode * /*Root*/, const NodeExtensionHelper & /*LHS*/,
-      const NodeExtensionHelper & /*RHS*/)>;
+      const NodeExtensionHelper & /*RHS*/, SelectionDAG &,
+      const RISCVSubtarget &)>;
 
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
   /// Helper method to set the various fields of this struct based on the
   /// type of \p Root.
-  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG) {
+  void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
+                              const RISCVSubtarget &Subtarget) {
     SupportsZExt = false;
     SupportsSExt = false;
     EnforceOneUse = true;
     CheckMask = true;
-    switch (OrigOperand.getOpcode()) {
+    unsigned Opc = OrigOperand.getOpcode();
+    switch (Opc) {
+    case ISD::ZERO_EXTEND:
+    case ISD::SIGN_EXTEND: {
+      MVT VT = OrigOperand.getSimpleValueType();
+      if (!VT.isVector())
+        break;
+
+      MVT NarrowVT = OrigOperand.getOperand(0)->getSimpleValueType(0);
+
+      unsigned ScalarBits = VT.getScalarSizeInBits();
+      unsigned NarrowScalarBits = NarrowVT.getScalarSizeInBits();
+
+      // Ensure the extension's semantic is equivalent to rvv vzext or vsext.
+      if (ScalarBits != NarrowScalarBits * 2)
+        break;
+
+      SupportsZExt = Opc == ISD::ZERO_EXTEND;
+      SupportsSExt = Opc == ISD::SIGN_EXTEND;
+
+      SDLoc DL(Root);
+      std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+      break;
+    }
     case RISCVISD::VZEXT_VL:
       SupportsZExt = true;
       Mask = OrigOperand.getOperand(1);
@@ -13074,8 +13110,16 @@ struct NodeExtensionHelper {
   }
 
   /// Check if \p Root supports any extension folding combines.
-  static bool isSupportedRoot(const SDNode *Root) {
+  static bool isSupportedRoot(const SDNode *Root, const SelectionDAG &DAG) {
     switch (Root->getOpcode()) {
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL: {
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      if (!TLI.isTypeLegal(Root->getValueType(0)))
+        return false;
+      return Root->getValueType(0).isScalableVector();
+    }
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
@@ -13090,9 +13134,10 @@ struct NodeExtensionHelper {
   }
 
   /// Build a NodeExtensionHelper for \p Root.getOperand(\p OperandIdx).
-  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG) {
-    assert(isSupportedRoot(Root) && "Trying to build an helper with an "
-                                    "unsupported root");
+  NodeExtensionHelper(SDNode *Root, unsigned OperandIdx, SelectionDAG &DAG,
+                      const RISCVSubtarget &Subtarget) {
+    assert(isSupportedRoot(Root, DAG) && "Trying to build an helper with an "
+                                         "unsupported root");
     assert(OperandIdx < 2 && "Requesting something else than LHS or RHS");
     OrigOperand = Root->getOperand(OperandIdx);
 
@@ -13108,7 +13153,7 @@ struct NodeExtensionHelper {
         SupportsZExt =
             Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
         SupportsSExt = !SupportsZExt;
-        std::tie(Mask, VL) = getMaskAndVL(Root);
+        std::tie(Mask, VL) = getMaskAndVL(Root, DAG, Subtarget);
         CheckMask = true;
         // There's no existing extension here, so we don't have to worry about
         // making sure it gets removed.
@@ -13117,7 +13162,7 @@ struct NodeExtensionHelper {
       }
       [[fallthrough]];
     default:
-      fillUpExtensionSupport(Root, DAG);
+      fillUpExtensionSupport(Root, DAG, Subtarget);
       break;
     }
   }
@@ -13133,14 +13178,27 @@ struct NodeExtensionHelper {
   }
 
   /// Helper function to get the Mask and VL from \p Root.
-  static std::pair<SDValue, SDValue> getMaskAndVL(const SDNode *Root) {
-    assert(isSupportedRoot(Root) && "Unexpected root");
-    return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+  static std::pair<SDValue, SDValue>
+  getMaskAndVL(const SDNode *Root, SelectionDAG &DAG,
+               const RISCVSubtarget &Subtarget) {
+    assert(isSupportedRoot(Root, DAG) && "Unexpected root");
+    switch (Root->getOpcode()) {
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL: {
+      SDLoc DL(Root);
+      MVT VT = Root->getSimpleValueType(0);
+      return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
+    }
+    default:
+      return std::make_pair(Root->getOperand(3), Root->getOperand(4));
+    }
   }
 
   /// Check if the Mask and VL of this operand are compatible with \p Root.
-  bool areVLAndMaskCompatible(const SDNode *Root) const {
-    auto [Mask, VL] = getMaskAndVL(Root);
+  bool areVLAndMaskCompatible(SDNode *Root, SelectionDAG &DAG,
+                              const RISCVSubtarget &Subtarget) const {
+    auto [Mask, VL] = getMaskAndVL(Root, DAG, Subtarget);
     return isMaskCompatible(Mask) && isVLCompatible(VL);
   }
 
@@ -13148,11 +13206,14 @@ struct NodeExtensionHelper {
   /// foldings that are supported by this class.
   static bool isCommutative(const SDNode *N) {
     switch (N->getOpcode()) {
+    case ISD::ADD:
+    case ISD::MUL:
     case RISCVISD::ADD_VL:
     case RISCVISD::MUL_VL:
     case RISCVISD::VWADD_W_VL:
     case RISCVISD::VWADDU_W_VL:
       return true;
+    case ISD::SUB:
     case RISCVISD::SUB_VL:
     case RISCVISD::VWSUB_W_VL:
     case RISCVISD::VWSUBU_W_VL:
@@ -13197,14 +13258,25 @@ struct CombineResult {
   /// Return a value that uses TargetOpcode and that can be used to replace
   /// Root.
   /// The actual replacement is *not* done in that method.
-  SDValue materialize(SelectionDAG &DAG) const {
+  SDValue materialize(SelectionDAG &DAG,
+                      const RISCVSubtarget &Subtarget) const {
     SDValue Mask, VL, Merge;
-    std::tie(Mask, VL) = NodeExtensionHelper::getMaskAndVL(Root);
-    Merge = Root->getOperand(2);
+    std::tie(Mask, VL) =
+        NodeExtensionHelper::getMaskAndVL(Root, DAG, Subtarget);
+    switch (Root->getOpcode()) {
+    default:
+      Merge = Root->getOperand(2);
+      break;
+    case ISD::ADD:
+    case ISD::SUB:
+    case ISD::MUL:
+      Merge = DAG.getUNDEF(Root->getValueType(0));
+      break;
+    }
     return DAG.getNode(TargetOpcode, SDLoc(Root), Root->getValueType(0),
-                       LHS.getOrCreateExtendedOp(Root, DAG, SExtLHS),
-                       RHS.getOrCreateExtendedOp(Root, DAG, SExtRHS), Merge,
-                       Mask, VL);
+                       LHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtLHS),
+                       RHS.getOrCreateExtendedOp(Root, DAG, Subtarget, SExtRHS),
+                       Merge, Mask, VL);
   }
 };
 
@@ -13221,15 +13293,16 @@ struct CombineResult {
 static std::optional<CombineResult>
 canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
                                  const NodeExtensionHelper &RHS, bool AllowSExt,
-                                 bool AllowZExt) {
+                                 bool AllowZExt, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
   assert((AllowSExt || AllowZExt) && "Forgot to set what you want?");
-  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   if (AllowZExt && LHS.SupportsZExt && RHS.SupportsZExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/false),
-                         Root, LHS, /*SExtLHS=*/false, RHS,
-                         /*SExtRHS=*/false);
+                         Root, LHS, /*SExtLHS=*/false, RHS, /*SExtRHS=*/false);
   if (AllowSExt && LHS.SupportsSExt && RHS.SupportsSExt)
     return CombineResult(NodeExtensionHelper::getSameExtensionOpcode(
                              Root->getOpcode(), /*IsSExt=*/true),
@@ -13246,9 +13319,10 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
-                             const NodeExtensionHelper &RHS) {
+                             const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/true);
+                                          /*AllowZExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -13257,8 +13331,9 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
-              const NodeExtensionHelper &RHS) {
-  if (!RHS.areVLAndMaskCompatible(Root))
+              const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+              const RISCVSubtarget &Subtarget) {
+  if (!RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
 
   // FIXME: Is it useful to form a vwadd.wx or vwsub.wx if it removes a scalar
@@ -13282,9 +13357,10 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS) {
+                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                    const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/true,
-                                          /*AllowZExt=*/false);
+                                          /*AllowZExt=*/false, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
@@ -13293,9 +13369,10 @@ canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
-                    const NodeExtensionHelper &RHS) {
+                    const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+                    const RISCVSubtarget &Subtarget) {
   return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, /*AllowSExt=*/false,
-                                          /*AllowZExt=*/true);
+                                          /*AllowZExt=*/true, DAG, Subtarget);
 }
 
 /// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -13304,10 +13381,13 @@ canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
 /// can be used to apply the pattern.
 static std::optional<CombineResult>
 canFoldToVW_SU(SDNode *Root, const NodeExtensionHelper &LHS,
-               const NodeExtensionHelper &RHS) {
+               const NodeExtensionHelper &RHS, SelectionDAG &DAG,
+               const RISCVSubtarget &Subtarget) {
+
   if (!LHS.SupportsSExt || !RHS.SupportsZExt)
     return std::nullopt;
-  if (!LHS.areVLAndMaskCompatible(Root) || !RHS.areVLAndMaskCompatible(Root))
+  if (!LHS.areVLAndMaskCompatible(Root, DAG, Subtarget) ||
+      !RHS.areVLAndMaskCompatible(Root, DAG, Subtarget))
     return std::nullopt;
   return CombineResult(NodeExtensionHelper::getSUOpcode(Root->getOpcode()),
                        Root, LHS, /*SExtLHS=*/true, RHS, /*SExtRHS=*/false);
@@ -13317,6 +13397,8 @@ SmallVector<NodeExtensionHelper::CombineToTry>
 NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
   SmallVector<CombineToTry> Strategies;
   switch (Root->getOpcode()) {
+  case ISD::ADD:
+  case ISD::SUB:
   case RISCVISD::ADD_VL:
   case RISCVISD::SUB_VL:
     // add|sub -> vwadd(u)|vwsub(u)
@@ -13324,6 +13406,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     // add|sub -> vwadd(u)_w|vwsub(u)_w
     Strategies.push_back(canFoldToVW_W);
     break;
+  case ISD::MUL:
   case RISCVISD::MUL_VL:
     // mul -> vwmul(u)
     Strategies.push_back(canFoldToVWWithSameExtension);
@@ -13354,12 +13437,14 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// mul_vl -> vwmul(u) | vwmul_su
 /// vwadd_w(u) -> vwadd(u)
 /// vwub_w(u) -> vwadd(u)
-static SDValue
-combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
+                                           TargetLowering::DAGCombinerInfo &DCI,
+                                           const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
 
-  assert(NodeExtensionHelper::isSupportedRoot(N) &&
-         "Shouldn't have called this method");
+  if (!NodeExtensionHelper::isSupportedRoot(N, DAG))
+    return SDValue();
+
   SmallVector<SDNode *> Worklist;
   SmallSet<SDNode *, 8> Inserted;
   Worklist.push_back(N);
@@ -13368,11 +13453,11 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
   while (!Worklist.empty()) {
     SDNode *Root = Worklist.pop_back_val();
-    if (!NodeExtensionHelper::isSupportedRoot(Root))
+    if (!NodeExtensionHelper::isSupportedRoot(Root, DAG))
       return SDValue();
 
-    NodeExtensionHelper LHS(N, 0, DAG);
-    NodeExtensionHelper RHS(N, 1, DAG);
+    NodeExtensionHelper LHS(N, 0, DAG, Subtarget);
+    NodeExtensionHelper RHS(N, 1, DAG, Subtarget);
     auto AppendUsersIfNeeded = [&Worklist,
                                 &Inserted](const NodeExtensionHelper &Op) {
       if (Op.needToPromoteOtherUsers()) {
@@ -13399,7 +13484,8 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
 
       for (NodeExtensionHelper::CombineToTry FoldingStrategy :
            FoldingStrategies) {
-        std::optional<CombineResult> Res = FoldingStrategy(N, LHS, RHS);
+        std::optional<CombineResult> Res =
+            FoldingStrategy(N, LHS, RHS, DAG, Subtarget);
         if (Res) {
           Matched = true;
           CombinesToApply.push_back(*Res);
@@ -13428,7 +13514,7 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   SmallVector<std::pair<SDValue, SDValue>> ValuesToReplace;
   ValuesToReplace.reserve(CombinesToApply.size());
   for (CombineResult Res : CombinesToApply) {
-    SDValue NewValue = Res.materialize(DAG);
+    SDValue NewValue = Res.materialize(DAG, Subtarget);
     if (!InputRootReplacement) {
       assert(Res.Root == N &&
              "First element is expected to be the current node");
@@ -14700,13 +14786,20 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
-  assert(N->getOpcode() == RISCVISD::ADD_VL);
+
+  assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
+
+  if (N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
   SDValue Addend = N->getOperand(0);
   SDValue MulOp = N->getOperand(1);
-  SDValue AddMergeOp = N->getOperand(2);
 
-  if (!AddMergeOp.isUndef())
-    return SDValue();
+  if (N->getOpcode() == RISCVISD::ADD_VL) {
+    SDValue AddMergeOp = N->getOperand(2);
+    if (!AddMergeOp.isUndef())
+      return SDValue();
+  }
 
   auto IsVWMulOpc = [](unsigned Opc) {
     switch (Opc) {
@@ -14730,8 +14823,16 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
   if (!MulMergeOp.isUndef())
     return SDValue();
 
-  SDValue AddMask = N->getOperand(3);
-  SDValue AddVL = N->getOperand(4);
+  auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
+                             const RISCVSubtarget &Subtarget) {
+    if (N->getOpcode() == ISD::ADD) {
+      SDLoc DL(N);
+      return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
+                                     Subtarget);
+    }
+    re...
[truncated]

@sun-jacobi sun-jacobi changed the title Recreate "[RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension. (#72340)" [RISCV][ISel] fix bug on invalid extension combine in#72340 Jan 3, 2024
@sun-jacobi sun-jacobi changed the title [RISCV][ISel] fix bug on invalid extension combine in#72340 [RISCV][ISel] Fix bug on invalid extension combine in #72340 Jan 3, 2024
@sun-jacobi sun-jacobi requested a review from topperc January 3, 2024 07:31
@sun-jacobi
Copy link
Member Author

Sorry for Ping.

Copy link
Collaborator

@topperc topperc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc
Copy link
Collaborator

topperc commented Jan 13, 2024

Title should describe the whole patch not just the bug fix. It will become the commit title when it is recommitted.

@sun-jacobi sun-jacobi changed the title [RISCV][ISel] Fix bug on invalid extension combine in #72340 [RISCV][ISel] Combine scalable vector add/sub/mul with zero/sign extension. Jan 14, 2024
@sun-jacobi sun-jacobi force-pushed the vscale-add-ext-combine branch from 20f10af to fde129a Compare January 14, 2024 08:47
@topperc topperc merged commit ba81477 into llvm:main Jan 18, 2024
4 checks passed
ampandey-1995 pushed a commit to ampandey-1995/llvm-project that referenced this pull request Jan 19, 2024
…/sign extension." (llvm#76785)

This patch was originally introduced in PR llvm#72340, but was reverted due
to a bug on invalid extension combine.

Specifically, we resolve the case in the
llvm#72340 (comment)

```
define <vscale x 1 x i32> @foo(<vscale x 1 x i1> %x, <vscale x 1 x i2> %y) {     
  %a = zext <vscale x 1 x i1> %x to <vscale x 1 x i32>                           
  %b = zext <vscale x 1 x i1> %y to <vscale x 1 x i32>                           
  %c = add <vscale x 1 x i32> %a, %b                                             
  ret <vscale x 1 x i32> %c                                                      
}
```
The previous patch didn't check if the semantic of `ISD::ZERO_EXTEND`
and `ISD::ZERO_EXTEND` is equivalent to the `vsext.vf2` or `vzext.vf2`
(not ensuring the SEW condition on widening Vector Arithmetic
Instructions).

Thanks for @topperc pointing out this bug.

## The original description 
This PR mainly aims at resolving the below missed-optimization case,
while it could also be considered as an extension of the previous patch
https://reviews.llvm.org/D133739?id=

### Missed-Optimization Case
Compiler Explorer: https://godbolt.org/z/GzWzP7Pfh

### Source Code: 
```
define <vscale x 2 x i16> @multiple_users(ptr  %x, ptr  %y, ptr %z) {
  %a = load <vscale x 2 x i8>, ptr %x
  %b = load <vscale x 2 x i8>, ptr %y
  %b2 = load <vscale x 2 x i8>, ptr %z
  %c = sext <vscale x 2 x i8> %a to <vscale x 2 x i16>
  %d = sext <vscale x 2 x i8> %b to <vscale x 2 x i16>
  %d2 = sext <vscale x 2 x i8> %b2 to <vscale x 2 x i16>
  %e = mul <vscale x 2 x i16> %c, %d
  %f = add <vscale x 2 x i16> %c, %d2
  %g = sub <vscale x 2 x i16> %c, %d2
  %h = or <vscale x 2 x i16> %e, %f
  %i = or <vscale x 2 x i16> %h, %g
  ret <vscale x 2 x i16> %i
}
```
### Before This Patch
```
# %bb.0:
        vsetvli a3, zero, e16, mf2, ta, ma
        vle8.v  v8, (a0)
        vle8.v  v9, (a1)
        vle8.v  v10, (a2)
        svf2       v11, v8
        vsext.vf2       v8, v9
        vsext.vf2       v9, v10
        vmul.vv v8, v11, v8
        vadd.vv v10, v11, v9
        vsub.vv v9, v11, v9
        vor.vv  v8, v8, v10
        vor.vv  v8, v8, v9
        ret
```
###  After This Patch 
```
# %bb.0:
	vsetvli	a3, zero, e8, mf4, ta, ma
	vle8.v	v8, (a0)
	vle8.v	v9, (a1)
	vle8.v	v10, (a2)
	vwmul.vv	v11, v8, v9
	vwadd.vv	v9, v8, v10
	vwsub.vv	v12, v8, v10
	vsetvli	zero, zero, e16, mf2, ta, ma
	vor.vv	v8, v11, v9
	vor.vv	v8, v8, v12
	ret
```
We can see Add/Sub/Mul are combined with the Sign Extension.

### Relation to the Patch D133739
The patch D133739 introduced an optimization for folding `ADD_VL`/
`SUB_VL` / `MUL_V` with `VSEXT_VL` / `VZEXT_VL`. However, the patch did
not consider the case of non-fixed length vector case, thus this PR
could also be considered as an extension for the D133739.
sun-jacobi added a commit that referenced this pull request Jan 31, 2024
…80079)

Similar to #78403, but for scalable `vwadd(u).wv`, given that #76785 is recommited.

### Code
```
define <vscale x 8 x i64> @vwadd_wv_mask_v8i32(<vscale x 8 x i32> %x, <vscale x 8 x i64> %y) {
    %mask = icmp slt <vscale x 8 x i32> %x, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 42, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
    %a = select <vscale x 8 x i1> %mask, <vscale x 8 x i32> %x, <vscale x 8 x i32> zeroinitializer
    %sa = sext <vscale x 8 x i32> %a to <vscale x 8 x i64>
    %ret = add <vscale x 8 x i64> %sa, %y
    ret <vscale x 8 x i64> %ret
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/xsoa5xPrd)
```
vwadd_wv_mask_v8i32:
        li      a0, 42
        vsetvli a1, zero, e32, m4, ta, ma
        vmslt.vx        v0, v8, a0
        vmv.v.i v12, 0
        vmerge.vvm      v24, v12, v8, v0
        vwadd.wv        v8, v16, v24
        ret
```

### After this patch
```
vwadd_wv_mask_v8i32:
        li a0, 42
        vsetvli a1, zero, e32, m4, ta, ma
        vmslt.vx v0, v8, a0
        vsetvli zero, zero, e32, m4, tu, mu
        vwadd.wv v16, v16, v8, v0.t
        vmv8r.v v8, v16
        ret
```
sun-jacobi added a commit that referenced this pull request Feb 21, 2024
Extend D133739 and #76785 to support vector widening floating-point
add/sub/mul instructions.

Specifically, this patch works for the below optimization case:

### Source code
```
define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
  %c = fpext <2 x float> %a to <2 x double>
  %d = fpext <2 x float> %b to <2 x double>
  %d2 = fpext <2 x float> %b2 to <2 x double>
  %e = fmul <2 x double> %c, %d
  %f = fadd <2 x double> %c, %d2
  %g = fsub <2 x double> %d, %d2
  store <2 x double> %e, ptr %x
  store <2 x double> %f, ptr %y
  store <2 x double> %g, ptr %z
  ret void
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/aaEMs5s9h)
```
vfwmul_v2f32_multiple_users:
        vsetivli        zero, 2, e32, mf2, ta, ma
        vfwcvt.f.f.v    v11, v8
        vfwcvt.f.f.v    v8, v9
        vfwcvt.f.f.v    v9, v10
        vsetvli zero, zero, e64, m1, ta, ma
        vfmul.vv        v10, v11, v8
        vfadd.vv        v11, v11, v9
        vfsub.vv        v8, v8, v9
        vse64.v v10, (a0)
        vse64.v v11, (a1)
        vse64.v v8, (a2)
        ret
```

### After this patch
```
vfwmul_v2f32_multiple_users:
        vsetivli zero, 2, e32, mf2, ta, ma
        vfwmul.vv v11, v8, v9
        vfwadd.vv v12, v8, v10
        vfwsub.vv v8, v9, v10
        vse64.v v11, (a0)
        vse64.v v12, (a1)
        vse64.v v8, (a2)
```
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 6, 2024
…pes. NFCI

I noticed this from a discrepancy in fillUpExtensionSupport between how we apparently need to check for legal types for ISD::{ZERO,SIGN}_EXTEND, but we don't need to for RISCVISD::V{Z,S}EXT_VL.

Prior to llvm#72340, combineBinOp_VLToVWBinOp_VL only ran after type legalization because it only operated on _VL nodes.  _VL nodes are only emitted during op legalization, which takes place **after** type legalization, which is presumably why the existing code didn't need to check for legal types.

After llvm#72340 we now handle generic ops like ISD::ADD that exist before op legalization and thus **before** type legalization. This meant that we needed to add extra checks that the narrow type was legal in llvm#76785.

I think the easiest thing to do here is to just maintain the invariant that the types are legal and only run the combine after type legalization.
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 6, 2024
…pes. NFCI

I noticed this from a discrepancy in fillUpExtensionSupport between how we apparently need to check for legal types for ISD::{ZERO,SIGN}_EXTEND, but we don't need to for RISCVISD::V{Z,S}EXT_VL.

Prior to llvm#72340, combineBinOp_VLToVWBinOp_VL only ran after type legalization because it only operated on _VL nodes.  _VL nodes are only emitted during op legalization, which takes place **after** type legalization, which is presumably why the existing code didn't need to check for legal types.

After llvm#72340 we now handle generic ops like ISD::ADD that exist before op legalization and thus **before** type legalization. This meant that we needed to add extra checks that the narrow type was legal in llvm#76785.

I think the easiest thing to do here is to just maintain the invariant that the types are legal and only run the combine after type legalization.
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Mar 11, 2024
…pes. NFCI

I noticed this from a discrepancy in fillUpExtensionSupport between how we apparently need to check for legal types for ISD::{ZERO,SIGN}_EXTEND, but we don't need to for RISCVISD::V{Z,S}EXT_VL.

Prior to llvm#72340, combineBinOp_VLToVWBinOp_VL only ran after type legalization because it only operated on _VL nodes.  _VL nodes are only emitted during op legalization, which takes place **after** type legalization, which is presumably why the existing code didn't need to check for legal types.

After llvm#72340 we now handle generic ops like ISD::ADD that exist before op legalization and thus **before** type legalization. This meant that we needed to add extra checks that the narrow type was legal in llvm#76785.

I think the easiest thing to do here is to just maintain the invariant that the types are legal and only run the combine after type legalization.
lukel97 added a commit that referenced this pull request Mar 11, 2024
…pes. NFCI (#84125)

I noticed this from a discrepancy in fillUpExtensionSupport between how
we apparently need to check for legal types for ISD::{ZERO,SIGN}_EXTEND,
but we don't need to for RISCVISD::V{Z,S}EXT_VL.

Prior to #72340, combineBinOp_VLToVWBinOp_VL only ran after type
legalization because it only operated on _VL nodes. _VL nodes are only
emitted during op legalization, which takes place **after** type
legalization, which is presumably why the existing code didn't need to
check for legal types.

After #72340 we now handle generic ops like ISD::ADD that exist before
op legalization and thus **before** type legalization. This meant that
we needed to add extra checks that the narrow type was legal in #76785.

I think the easiest thing to do here is to just maintain the invariant
that the types are legal and only run the combine after type
legalization.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants