[RISCV][ISel] Remove redundant vmerge for the vwadd. #78403
Conversation
@llvm/pr-subscribers-backend-risc-v

Author: Chia (sun-jacobi)

Changes

This patch is aiming at resolving the missed-optimization case below. The pattern can be found in a reduction with a widening destination. Specifically, we first do a fold like (vwadd.wv y, (vmerge cond, x, 0)) -> (vwadd.wv y, x, y, cond), then do pattern matching on it.
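A minimal sketch of the case, using the fixed-vector test this patch adds (the IR below is copied from that test; the before/after codegen described here follows the analogous scalable listing for #80079 quoted near the end of this thread): without the fold, the select is materialized as a vmv.v.i/vmerge.vvm feeding an unmasked vwadd.wv; with it, the whole sequence becomes a single masked vwadd.wv.

```llvm
; Reduction-style pattern this patch targets (copied from the new test
; fixed-vectors-vwadd-mask.ll in the diff below): a select between %x and zero
; is sign-extended and accumulated into a wider vector.
define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
  %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
  %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
  %sa = sext <8 x i32> %a to <8 x i64>
  %ret = add <8 x i64> %sa, %y
  ret <8 x i64> %ret
}
```

Per the CHECK lines in the new test, llc -mtriple=riscv64 -mattr=+v now emits a vmslt.vx to form the mask followed by a single vwadd.wv with a v0.t mask for the add.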
Full diff: https://github.com/llvm/llvm-project/pull/78403.diff

3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index cb9ffabc41236e..a030538e5e8ba9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13457,6 +13457,56 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return InputRootReplacement;
}
+// (vwadd y, (select cond, x, 0)) -> select cond (vwadd y, x), y
+static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
+ unsigned Opc = N->getOpcode();
+ assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
+ Opc == RISCVISD::VWADDU_W_VL);
+
+ SDValue VL = N->getOperand(4);
+ SDValue Y = N->getOperand(0);
+ SDValue Merge = N->getOperand(1);
+
+ if (Merge.getOpcode() != RISCVISD::VMERGE_VL)
+ return SDValue();
+
+ SDValue Cond = Merge->getOperand(0);
+ SDValue X = Merge->getOperand(1);
+ SDValue Z = Merge->getOperand(2);
+
+ if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
+ !isNullConstant(Z.getOperand(2)))
+ return SDValue();
+
+ if (!Merge.hasOneUse())
+ return SDValue();
+
+ SmallVector<SDValue, 6> Ops(N->op_values());
+ Ops[0] = Y;
+ Ops[1] = X;
+
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+
+ SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
+ return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, DAG.getUNDEF(VT),
+ VL);
+}
+
+static SDValue performVWADD_VLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ unsigned Opc = N->getOpcode();
+ assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
+ Opc == RISCVISD::VWADDU_W_VL);
+
+ if (Opc != RISCVISD::VWADD_VL) {
+ if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
+ return V;
+ }
+
+ return combineVWADDSelect(N, DCI.DAG);
+}
+
// Helper function for performMemPairCombine.
// Try to combine the memory loads/stores LSNode1 and LSNode2
// into a single memory pair operation.
@@ -15500,9 +15550,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI))
return V;
return combineToVWMACC(N, DAG, Subtarget);
- case RISCVISD::SUB_VL:
+ case RISCVISD::VWADD_VL:
case RISCVISD::VWADD_W_VL:
case RISCVISD::VWADDU_W_VL:
+ return performVWADD_VLCombine(N, DCI);
+ case RISCVISD::SUB_VL:
case RISCVISD::VWSUB_W_VL:
case RISCVISD::VWSUBU_W_VL:
case RISCVISD::MUL_VL:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 1deb9a709463e8..6744a38d036b00 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -691,6 +691,30 @@ multiclass VPatTiedBinaryNoMaskVL_V<SDNode vop,
GPR:$vl, sew, TU_MU)>;
}
+class VPatTiedBinaryMaskVL_V<SDNode vop,
+ string instruction_name,
+ string suffix,
+ ValueType result_type,
+ ValueType op2_type,
+ ValueType mask_type,
+ int sew,
+ LMULInfo vlmul,
+ VReg result_reg_class,
+ VReg op2_reg_class>
+ : Pat<(riscv_vmerge_vl (mask_type V0),
+ (result_type (vop
+ result_reg_class:$rs1,
+ (op2_type op2_reg_class:$rs2),
+ srcvalue,
+ true_mask,
+ VLOpFrag)),
+ result_reg_class:$rs1, result_reg_class:$merge, VLOpFrag),
+ (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
+ result_reg_class:$merge,
+ result_reg_class:$rs1,
+ op2_reg_class:$rs2,
+ (mask_type V0), GPR:$vl, sew, TAIL_AGNOSTIC)>;
+
multiclass VPatTiedBinaryNoMaskVL_V_RM<SDNode vop,
string instruction_name,
string suffix,
@@ -819,6 +843,9 @@ multiclass VPatBinaryWVL_VV_VX_WV_WX<SDPatternOperator vop, SDNode vop_w,
defm : VPatTiedBinaryNoMaskVL_V<vop_w, instruction_name, "WV",
wti.Vector, vti.Vector, vti.Log2SEW,
vti.LMul, wti.RegClass, vti.RegClass>;
+ def : VPatTiedBinaryMaskVL_V<vop_w, instruction_name, "WV",
+ wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
+ vti.LMul, wti.RegClass, vti.RegClass>;
def : VPatBinaryVL_V<vop_w, instruction_name, "WV",
wti.Vector, wti.Vector, vti.Vector, vti.Mask,
vti.Log2SEW, vti.LMul, wti.RegClass, wti.RegClass,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
new file mode 100644
index 00000000000000..afc59b875d79df
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwadd-mask.ll
@@ -0,0 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=riscv32 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+; RUN: llc -mtriple=riscv64 -mattr=+v -verify-machineinstrs < %s | FileCheck %s --check-prefixes=CHECK
+
+define <8 x i64> @vwadd_mask_v8i32(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %sa, %y
+ ret <8 x i64> %ret
+}
+
+define <8 x i64> @vwadd_mask_v8i32_commutative(<8 x i32> %x, <8 x i64> %y) {
+; CHECK-LABEL: vwadd_mask_v8i32_commutative:
+; CHECK: # %bb.0:
+; CHECK-NEXT: li a0, 42
+; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
+; CHECK-NEXT: vmslt.vx v0, v8, a0
+; CHECK-NEXT: vwadd.wv v16, v12, v8, v0.t
+; CHECK-NEXT: vmv4r.v v8, v16
+; CHECK-NEXT: ret
+ %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
+ %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> zeroinitializer
+ %sa = sext <8 x i32> %a to <8 x i64>
+ %ret = add <8 x i64> %y, %sa
+ ret <8 x i64> %ret
+}
EVT VT = N->getValueType(0);

SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, DAG.getUNDEF(VT),
Operand 2 of the original vmerge is the passthru operand for elements past VL if the VWADD is tail undisturbed. This VMERGE_VL has undef for its passthru. That corrupts the elements past VL.
SDValue Z = Merge->getOperand(2);

if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
    !isNullConstant(Z.getOperand(2)))
This only checks that the insertion index is 0. Where do you check the vector being inserted is 0?
Thank you for pointing this out, I will fix it.

I think this might be an issue in performCombineVMergeAndVOps:

define <vscale x 2 x i32> @f(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y) {
  %mask = icmp slt <vscale x 2 x i32> %x, shufflevector(<vscale x 2 x i32> insertelement(<vscale x 2 x i32> poison, i32 42, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer)
  %a = select <vscale x 2 x i1> %mask, <vscale x 2 x i32> %x, <vscale x 2 x i32> zeroinitializer
  %ret = add <vscale x 2 x i32> %a, %y
  ret <vscale x 2 x i32> %ret
}

define <vscale x 2 x i64> @g(<vscale x 2 x i32> %x, <vscale x 2 x i64> %y) {
  %mask = icmp slt <vscale x 2 x i32> %x, shufflevector(<vscale x 2 x i32> insertelement(<vscale x 2 x i32> poison, i32 42, i32 0), <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer)
  %a = select <vscale x 2 x i1> %mask, <vscale x 2 x i32> %x, <vscale x 2 x i32> zeroinitializer
  %sa = sext <vscale x 2 x i32> %a to <vscale x 2 x i64>
  %ret = add <vscale x 2 x i64> %sa, %y
  ret <vscale x 2 x i64> %ret
}

Update: It's not an issue with performCombineVMergeAndVOps, we're doing a similar combine as this patch somewhere for add.

Update: The combine is DAGCombiner::foldBinOpIntoSelect, which doesn't trigger for the sext case because there's a sign_extend in between the add and vselect.
Yes, the original

@lukel97 Thank you! That's exactly what I mean.
return SDValue();

SmallVector<SDValue, 6> Ops(N->op_values());
Ops[0] = Y;
Isn't Ops[0] already Y?
EVT VT = N->getValueType(0);

SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, Y, VL);
This is still incorrect. You have to use N->getOperand(2) for the passthru operand to the vmerge.
You're also losing any mask that the VWADD_W_VL may have already had.
// (vwadd y, (select cond, x, 0)) -> select cond (vwadd y, x), y
static SDValue combineVWADDSelect(SDNode *N, SelectionDAG &DAG) {
  unsigned Opc = N->getOpcode();
  assert(Opc == RISCVISD::VWADD_VL || Opc == RISCVISD::VWADD_W_VL ||
It can't be VWADD_VL due to the check in performVWADD_VLCombine, right?
The check in performVWADD_VLCombine is for RISCVISD::VWADD_W_VL and RISCVISD::VWADDU_W_VL. We need to first do combineBinOp_VLToVWBinOp_VL on those.
Oops. You're right. Sorry about that.
EVT VT = N->getValueType(0);

SDValue WX = DAG.getNode(Opc, DL, VT, Ops, N->getFlags());
return DAG.getNode(RISCVISD::VMERGE_VL, DL, VT, Cond, WX, Y, Y, VL);
You don't need to create a VMERGE, you just need to change the Mask operand when you create WX. RISCVISD::VWADD_W_VL supports all the operands you need to describe this.
Thank you for the advice. It works.
With this we get a normal masked instruction, but in this case I think we need the MASK_TIED.
SmallVector<SDValue, 6> Ops(N->op_values());
Ops[MergeID] = X;
Ops[3] = Cond;
You can't replace operand 3 without checking that operand 3 was an all 1s mask or the passthru was undef originally. If the mask wasn't all 1s or the passthru wasn't undef, then the original add produced the passthru operand for masked off elements.
SDValue X = Merge->getOperand(1);
SDValue Z = Merge->getOperand(2);

if (Z.getOpcode() != ISD::INSERT_SUBVECTOR ||
This doesn't check what operand 0 of the insert is or the size of the insertion. So you only know some subvector of the input is 0. You don't know the whole vector is 0.
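As an illustration of the concern, a hypothetical input like the following (not taken from the thread) has a false operand that is not all zeros; if it still reaches the combine as an INSERT_SUBVECTOR at index 0 once the fixed vector is wrapped into its scalable container, the index-only check would accept it even though the fold is unsound for the lanes where the mask is false:

```llvm
; Hypothetical counterexample (not from the thread): the select's false operand
; is <0,0,0,0,1,1,1,1> rather than zeroinitializer, so rewriting the vmerge into
; a masked vwadd.wv would drop the non-zero addend in the masked-off lanes.
define <8 x i64> @vwadd_mask_nonzero_false(<8 x i32> %x, <8 x i64> %y) {
  %mask = icmp slt <8 x i32> %x, <i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42, i32 42>
  %a = select <8 x i1> %mask, <8 x i32> %x, <8 x i32> <i32 0, i32 0, i32 0, i32 0, i32 1, i32 1, i32 1, i32 1>
  %sa = sext <8 x i32> %a to <8 x i64>
  %ret = add <8 x i64> %sa, %y
  ret <8 x i64> %ret
}
```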
if (!Merge.hasOneUse())
  return SDValue();

SmallVector<SDValue, 6> Ops(N->op_values());
Why 6? I think there are only 5 operands. LHS, RHS, Passthru, Mask, VL
AFAIU, we may need For

Sorry for the ping.
LGTM
Force-pushed from 263047c to 1386a93.
…80079)

Similar to #78403, but for scalable `vwadd(u).wv`, given that #76785 is recommitted.

### Code

```
define <vscale x 8 x i64> @vwadd_wv_mask_v8i32(<vscale x 8 x i32> %x, <vscale x 8 x i64> %y) {
  %mask = icmp slt <vscale x 8 x i32> %x, shufflevector (<vscale x 8 x i32> insertelement (<vscale x 8 x i32> poison, i32 42, i64 0), <vscale x 8 x i32> poison, <vscale x 8 x i32> zeroinitializer)
  %a = select <vscale x 8 x i1> %mask, <vscale x 8 x i32> %x, <vscale x 8 x i32> zeroinitializer
  %sa = sext <vscale x 8 x i32> %a to <vscale x 8 x i64>
  %ret = add <vscale x 8 x i64> %sa, %y
  ret <vscale x 8 x i64> %ret
}
```

### Before this patch

[Compiler Explorer](https://godbolt.org/z/xsoa5xPrd)

```
vwadd_wv_mask_v8i32:
        li a0, 42
        vsetvli a1, zero, e32, m4, ta, ma
        vmslt.vx v0, v8, a0
        vmv.v.i v12, 0
        vmerge.vvm v24, v12, v8, v0
        vwadd.wv v8, v16, v24
        ret
```

### After this patch

```
vwadd_wv_mask_v8i32:
        li a0, 42
        vsetvli a1, zero, e32, m4, ta, ma
        vmslt.vx v0, v8, a0
        vsetvli zero, zero, e32, m4, tu, mu
        vwadd.wv v16, v16, v8, v0.t
        vmv8r.v v8, v16
        ret
```
Note we can't use vwaddu.wv because it will get combined away with #78403