-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Move vnclip patterns into DAGCombiner. #93728
Conversation
@llvm/pr-subscribers-backend-risc-v Author: Craig Topper (topperc) ChangesSimilar to #93596, this moves the signed vnclip patterns into DAG combine. This will allows us to support more than 1 level of truncate in a There's a pre-commit that refactors the vnclipu code to make it is easier to share code. Full diff: https://github.com/llvm/llvm-project/pull/93728.diff 3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0242cfe178524..f4b64df927418 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16185,8 +16185,8 @@ static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
// Combine (truncate_vector_vl (umin X, C)) -> (vnclipu_vl X) if C is maximum
// value for the truncated type.
-static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
- const RISCVSubtarget &Subtarget) {
+static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
+ const RISCVSubtarget &Subtarget) {
assert(N->getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL);
MVT VT = N->getSimpleValueType(0);
@@ -16194,15 +16194,15 @@ static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
SDValue Mask = N->getOperand(1);
SDValue VL = N->getOperand(2);
- SDValue Src = N->getOperand(0);
+ auto MatchMinMax = [&VL, &Mask](SDValue V, unsigned Opc, unsigned OpcVL,
+ APInt &SplatVal) {
+ if (V.getOpcode() != Opc &&
+ !(V.getOpcode() == OpcVL && V.getOperand(2).isUndef() &&
+ V.getOperand(3) == Mask && V.getOperand(4) == VL))
+ return SDValue();
- // Src must be a UMIN or UMIN_VL.
- if (Src.getOpcode() != ISD::UMIN &&
- !(Src.getOpcode() == RISCVISD::UMIN_VL && Src.getOperand(2).isUndef() &&
- Src.getOperand(3) == Mask && Src.getOperand(4) == VL))
- return SDValue();
+ SDValue Op = V.getOperand(1);
- auto IsSplat = [&VL](SDValue Op, APInt &SplatVal) {
// Peek through conversion between fixed and scalable vectors.
if (Op.getOpcode() == ISD::INSERT_SUBVECTOR && Op.getOperand(0).isUndef() &&
isNullConstant(Op.getOperand(2)) &&
@@ -16213,32 +16213,67 @@ static SDValue combineTruncToVnclipu(SDNode *N, SelectionDAG &DAG,
Op = Op.getOperand(1).getOperand(0);
if (ISD::isConstantSplatVector(Op.getNode(), SplatVal))
- return true;
+ return V.getOperand(0);
if (Op.getOpcode() == RISCVISD::VMV_V_X_VL && Op.getOperand(0).isUndef() &&
Op.getOperand(2) == VL) {
if (auto *Op1 = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
SplatVal =
Op1->getAPIntValue().sextOrTrunc(Op.getScalarValueSizeInBits());
- return true;
+ return V.getOperand(0);
}
}
- return false;
+ return SDValue();
};
- APInt C;
- if (!IsSplat(Src.getOperand(1), C))
+ auto DetectUSatPattern = [&](SDValue V) {
+ // Src must be a UMIN or UMIN_VL.
+ APInt C;
+ SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, C);
+ if (!UMin)
+ return SDValue();
+
+ if (!C.isMask(VT.getScalarSizeInBits()))
+ return SDValue();
+
+ return UMin;
+ };
+
+ auto DetectSSatPattern = [&](SDValue V) {
+ unsigned NumDstBits = VT.getScalarSizeInBits();
+ unsigned NumSrcBits = V.getScalarValueSizeInBits();
+ APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
+ APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
+
+ APInt CMin, CMax;
+ if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
+ if (CMin == SignedMax && CMax == SignedMin)
+ return SMax;
+
+ if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
+ if (CMin == SignedMax && CMax == SignedMin)
+ return SMin;
+
return SDValue();
+ };
- if (!C.isMask(VT.getScalarSizeInBits()))
+ SDValue Val;
+ unsigned ClipOpc;
+ if ((Val = DetectUSatPattern(N->getOperand(0)))) {
+ ClipOpc = RISCVISD::VNCLIPU_VL;
+ } else if ((Val = DetectSSatPattern(N->getOperand(0)))) {
+ ClipOpc = RISCVISD::VNCLIP_VL;
+ } else
return SDValue();
SDLoc DL(N);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
return DAG.getNode(
- RISCVISD::VNCLIPU_VL, DL, VT,
- {Src.getOperand(0), DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
+ ClipOpc, DL, VT,
+ {Val, DAG.getConstant(0, DL, VT), DAG.getUNDEF(VT), Mask,
DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
VL});
}
@@ -16462,7 +16497,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::TRUNCATE_VECTOR_VL:
if (SDValue V = combineTruncOfSraSext(N, DAG))
return V;
- return combineTruncToVnclipu(N, DAG, Subtarget);
+ return combineTruncToVnclip(N, DAG, Subtarget);
case ISD::TRUNCATE:
return performTRUNCATECombine(N, DAG, Subtarget);
case ISD::SELECT:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index 691f2052ab29d..3163e4bafd4b0 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -1168,40 +1168,6 @@ defm : VPatAVGADD_VV_VX_RM<avgflooru, 0b10, suffix = "U">;
defm : VPatAVGADD_VV_VX_RM<avgceils, 0b00>;
defm : VPatAVGADD_VV_VX_RM<avgceilu, 0b00, suffix = "U">;
-// 12.5. Vector Narrowing Fixed-Point Clip Instructions
-multiclass VPatTruncSatClipSDNode<VTypeInfo vti, VTypeInfo wti> {
- defvar sew = vti.SEW;
- defvar uminval = !sub(!shl(1, sew), 1);
- defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
- defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));
-
- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates) in {
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (smin
- (wti.Vector (smax (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (smax
- (wti.Vector (smin (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))))),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))))),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
- }
-}
-
-foreach vtiToWti = AllWidenableIntVectors in
- defm : VPatTruncSatClipSDNode<vtiToWti.Vti, vtiToWti.Wti>;
-
// 15. Vector Mask Instructions
// 15.1. Vector Mask-Register Logical Instructions
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 610a72dd02b38..ce8133a5a297b 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2470,46 +2470,6 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
-// 12.5. Vector Narrowing Fixed-Point Clip Instructions
-multiclass VPatTruncSatClipVL<VTypeInfo vti, VTypeInfo wti> {
- defvar sew = vti.SEW;
- defvar uminval = !sub(!shl(1, sew), 1);
- defvar sminval = !sub(!shl(1, !sub(sew, 1)), 1);
- defvar smaxval = !sub(0, !shl(1, !sub(sew, 1)));
-
- let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
- GetVTypePredicates<wti>.Predicates) in {
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (riscv_smin_vl
- (wti.Vector (riscv_smax_vl
- (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
- (wti.Vector undef),(wti.Mask V0), VLOpFrag)),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
- (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
-
- def : Pat<(vti.Vector (riscv_trunc_vector_vl
- (wti.Vector (riscv_smax_vl
- (wti.Vector (riscv_smin_vl
- (wti.Vector wti.RegClass:$rs1),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), sminval, (XLenVT srcvalue))),
- (wti.Vector undef),(wti.Mask V0), VLOpFrag)),
- (wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), smaxval, (XLenVT srcvalue))),
- (wti.Vector undef), (wti.Mask V0), VLOpFrag)),
- (vti.Mask V0), VLOpFrag)),
- (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
- (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
- (vti.Mask V0), 0, GPR:$vl, vti.Log2SEW, TA_MA)>;
- }
-}
-
-foreach vtiToWti = AllWidenableIntVectors in
- defm : VPatTruncSatClipVL<vtiToWti.Vti, vtiToWti.Wti>;
-
// 13. Vector Floating-Point Instructions
// 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Aside - This is starting to look a bit like isSaturatingMinMax in DAGCombine. This version handles the _VL variants, and that one handles the select variants, so neither is a subset of the other. Do we have any room to optimize the saturating fp_to_sint patterns the DAG is working on? I don't see anything but figured it's worth asking.. Alternatively, is there anything we can do to optimize/recognize the clamp idiom on it's own?
Looks like there are some vnclips in test/CodeGen/RISCV/rvv/fpclamptosat_vec.ll but also some min/max/vnsrl. |
…nclip support. NFC Reviewed as part of #93728.
Similar to llvm#93596, this moves the signed vnclip patterns into DAG combine. This will allows us to support more than 1 level of truncate in a future patch.
459845a
to
c19db6f
Compare
Similar to #93596, this moves the signed vnclip patterns into DAG combine.
This will allows us to support more than 1 level of truncate in a
future patch.
There's a pre-commit that refactors the vnclipu code to make it is easier to share code.