Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Move vnclip patterns into DAGCombiner. #93728

Merged
merged 1 commit into from
May 29, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented May 29, 2024

Similar to #93596, this moves the signed vnclip patterns into DAG combine.

This will allows us to support more than 1 level of truncate in a
future patch.

There's a pre-commit that refactors the vnclipu code to make it is easier to share code.

@llvmbot
Copy link
Member

llvmbot commented May 29, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

Similar to #93596, this moves the signed vnclip patterns into DAG combine.

This will allows us to support more than 1 level of truncate in a
future patch.

There's a pre-commit that refactors the vnclipu code to make it is easier to share code.


Full diff: https://github.com/llvm/llvm-project/pull/93728.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+53-18)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td (-34)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (-40)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0242cfe178524..f4b64df927418 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16185,8 +16185,8 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
 
 // Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
 // value for the truncated type.
-static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
-                                     const RISCVSubtarget &Subtarget) {
+static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
+                                    const RISCVSubtarget &Subtarget) {
   assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
 
   MVT VT = N->getSimpleValueType(0);
@@ -16194,15 +16194,15 @@ static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
   SDValue Mask = N->getOperand(1);
   SDValue VL = N->getOperand(2);
 
-  SDValue Src = N->getOperand(0);
+  auto MatchMinMax = [&VL, &Mask](SDValue V, unsigned Opc, unsigned OpcVL,
+                                  APInt &SplatVal) {
+    if (V.getOpcode() != Opc &&
+        !(V.getOpcode() == OpcVL && V.getOperand(2).isUndef() &&
+          V.getOperand(3) == Mask && V.getOperand(4) == VL))
+      return SDValue();
 
-  // Src must be a UMIN or UMIN_VL.
-  if (Src.getOpcode() != ISD::UMIN &&
-      !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
-        Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
-    return SDValue();
+    SDValue Op = V.getOperand(1);
 
-  auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
     // Peek through conversion between fixed and scalable vectors.
     if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
         isNullConstant(Op.getOperand(2)) &&
@@ -16213,32 +16213,67 @@ static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
       Op = Op.getOperand(1).getOperand(0);
 
     if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
-      return true;
+      return V.getOperand(0);
 
     if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
         Op.getOperand(2) == VL) {
       if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
         SplatVal =
             Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
-        return true;
+        return V.getOperand(0);
       }
     }
 
-    return false;
+    return SDValue();
   };
 
-  APInt C;
-  if (!IsSplat(Src.getOperand(1), C))
+  auto DetectUSatPattern = [&](SDValue V) {
+    // Src must be a UMIN or UMIN_VL.
+    APInt C;
+    SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, C);
+    if (!UMin)
+      return SDValue();
+
+    if (!C.isMask(VT.getScalarSizeInBits()))
+      return SDValue();
+
+    return UMin;
+  };
+
+  auto DetectSSatPattern = [&](SDValue V) {
+    unsigned NumDstBits = VT.getScalarSizeInBits();
+    unsigned NumSrcBits = V.getScalarValueSizeInBits();
+    APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+    APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+
+    APInt CMin, CMax;
+    if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
+      if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
+        if (CMin == SignedMax && CMax == SignedMin)
+          return SMax;
+
+    if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
+      if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
+        if (CMin == SignedMax && CMax == SignedMin)
+          return SMin;
+
     return SDValue();
+  };
 
-  if (!C.isMask(VT.getScalarSizeInBits()))
+  SDValue Val;
+  unsigned ClipOpc;
+  if ((Val = DetectUSatPattern(N->getOperand(0)))) {
+    ClipOpc = RISCVISD::VNCLIPU_VL;
+  } else if ((Val = DetectSSatPattern(N->getOperand(0)))) {
+    ClipOpc = RISCVISD::VNCLIP_VL;
+  } else
     return SDValue();
 
   SDLoc DL(N);
   // Rounding mode here is arbitrary since we aren't shifting out any bits.
   return DAG.getNode(
-      RISCVISD::VNCLIPU_VL, DL, VT,
-      {Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
+      ClipOpc, DL, VT,
+      {Val, DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
        DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
        VL});
 }
@@ -16462,7 +16497,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::TRUNCATE_VECTOR_VL:
     if (SDValue V = combineTruncOfSraSext(N, DAG))
       return V;
-    return combineTruncToVnclipu(N, DAG, Subtarget);
+    return combineTruncToVnclip(N, DAG, Subtarget);
   case ISD::TRUNCATE:
     return performTRUNCATECombine(N, DAG, Subtarget);
   case ISD::SELECT:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 691f2052ab29d..3163e4bafd4b0 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1168,40 +1168,6 @@ defm : VPatAVGADD_VV_VX_RM<avgflooru, 0b10, suffix = "U">;
 defm : VPatAVGADD_VV_VX_RM<avgceils, 0b00>;
 defm : VPatAVGADD_VV_VX_RM<avgceilu, 0b00, suffix = "U">;
 
-// 12.5. Vector Narrowing Fixed-Point Clip Instructions
-multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
-  defvar sew = vti.SEW;
-  defvar uminval = !sub(!shl(1, sew), 1);
-  defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
-  defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));
-
-  let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
-                               GetVTypePredicates<wti>.Predicates) in {
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (smin
-          (wti.Vector (smax (wti.Vector wti.RegClass:$rs1),
-            (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
-        (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (smax
-          (wti.Vector (smin (wti.Vector wti.RegClass:$rs1),
-            (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
-        (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-  }
-}
-
-foreach vtiToWti = AllWidenableIntVectors in
-  defm : VPatTruncSatClipSDNode<vtiToWti.Vti, vtiToWti.Wti>;
-
 // 15. Vector Mask Instructions
 
 // 15.1. Vector Mask-Register Logical Instructions
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 610a72dd02b38..ce8133a5a297b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2470,46 +2470,6 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
 defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
 defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
 
-// 12.5. Vector Narrowing Fixed-Point Clip Instructions
-multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
-  defvar sew = vti.SEW;
-  defvar uminval = !sub(!shl(1, sew), 1);
-  defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
-  defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));
-
-  let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
-                               GetVTypePredicates<wti>.Predicates) in {
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (riscv_smin_vl
-          (wti.Vector (riscv_smax_vl
-            (wti.Vector wti.RegClass:$rs1),
-            (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
-            (wti.Vector undef),(wti.Mask V0), VLOpFrag)),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
-          (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
-        (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
-    def : Pat<(vti.Vector (riscv_trunc_vector_vl
-        (wti.Vector (riscv_smax_vl
-          (wti.Vector (riscv_smin_vl
-            (wti.Vector wti.RegClass:$rs1),
-            (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
-            (wti.Vector undef),(wti.Mask V0), VLOpFrag)),
-          (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
-          (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
-        (vti.Mask V0), VLOpFrag)),
-      (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
-        (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
-        (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-  }
-}
-
-foreach vtiToWti = AllWidenableIntVectors in
-  defm : VPatTruncSatClipVL<vtiToWti.Vti, vtiToWti.Wti>;
-
 // 13. Vector Floating-Point Instructions
 
 // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Aside - This is starting to look a bit like isSaturatingMinMax in DAGCombine. This version handles the _VL variants, and that one handles the select variants, so neither is a subset of the other. Do we have any room to optimize the saturating fp_to_sint patterns the DAG is working on? I don't see anything but figured it's worth asking.. Alternatively, is there anything we can do to optimize/recognize the clamp idiom on it's own?

@topperc
Copy link
Collaborator Author

topperc commented May 29, 2024

LGTM

Aside - This is starting to look a bit like isSaturatingMinMax in DAGCombine. This version handles the _VL variants, and that one handles the select variants, so neither is a subset of the other. Do we have any room to optimize the saturating fp_to_sint patterns the DAG is working on? I don't see anything but figured it's worth asking.. Alternatively, is there anything we can do to optimize/recognize the clamp idiom on it's own?

Looks like there are some vnclips in test/CodeGen/RISCV/rvv/fpclamptosat_vec.ll but also some min/max/vnsrl.

topperc added a commit that referenced this pull request May 29, 2024
Similar to llvm#93596, this moves the signed vnclip patterns into DAG combine.

This will allows us to support more than 1 level of truncate in a
future patch.
@topperc topperc force-pushed the pr/vnclip-combiner branch from 459845a to c19db6f Compare May 29, 2024 23:45
@topperc topperc merged commit 8a8cd8a into llvm:main May 29, 2024
4 of 6 checks passed
@topperc topperc deleted the pr/vnclip-combiner branch May 29, 2024 23:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants