[AArch64] Implement promotion type legalisation for histogram intrinsic #101017
Conversation
Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.
@llvm/pr-subscribers-backend-aarch64

Author: Max Beck-Jones (DevM-uk)

Changes: Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.

Full diff: https://github.com/llvm/llvm-project/pull/101017.diff (2 files affected)
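For readers unfamiliar with the intrinsic, a scalar C++ equivalent of the operation being legalised is sketched below. This is illustrative only; the function name and signature are hypothetical and not part of LLVM.

```cpp
#include <cstddef>
#include <cstdint>

// Scalar sketch of llvm.experimental.vector.histogram.add: every active
// lane adds Inc to the bucket its index selects, so lanes that share a
// bucket accumulate multiple increments. With this patch, BucketT may be
// uint8_t or uint16_t on SVE2: legalisation promotes the memory type
// instead of rejecting it.
template <typename BucketT, typename IndexT>
void histogramAdd(BucketT *Base, const IndexT *Indices, const bool *Mask,
                  size_t Lanes, BucketT Inc) {
  for (size_t I = 0; I < Lanes; ++I)
    if (Mask[I])
      Base[Indices[I]] += Inc;
}
```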
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e9da9b819bdd..153d5fe28be7b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1775,9 +1775,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
// Histcnt is SVE2 only
- if (Subtarget->hasSVE2())
+ if (Subtarget->hasSVE2()) {
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
+ setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
+ }
}
@@ -28018,9 +28021,17 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
EVT IndexVT = Index.getValueType();
EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
IndexVT.getVectorElementCount());
+ EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
+ ? MVT::i32
+ : MVT::i64;
+ EVT IncSplatVT = EVT::getVectorVT(*DAG.getContext(), IncExtVT,
+ IndexVT.getVectorElementCount());
+ bool ExtTrunc = IncSplatVT != MemVT;
+
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
- SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
- SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
+ SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
+ SDValue IncSplat = DAG.getSplatVector(
+ IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
MachineMemOperand *MMO = HG->getMemOperand();
@@ -28029,18 +28040,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
MMO->getAlign(), MMO->getAAInfo());
ISD::MemIndexType IndexType = HG->getIndexType();
- SDValue Gather =
- DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
- GMMO, IndexType, ISD::NON_EXTLOAD);
+ SDValue Gather = DAG.getMaskedGather(
+ DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
+ ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);
SDValue GChain = Gather.getValue(1);
// Perform the histcnt, multiply by inc, add to bucket data.
- SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
+ SDValue ID =
+ DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
SDValue HistCnt =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
- SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
+ SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);
// Create an MMO for the scatter, without load|store flags.
MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -28049,7 +28061,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
- ScatterOps, SMMO, IndexType, false);
+ ScatterOps, SMMO, IndexType, ExtTrunc);
return Scatter;
}
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index db164e288abde..2874e47511e12 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
ret void
}
+define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
+; CHECK-LABEL: histogram_i32_promote:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1w { z1.d }, p0, [x0, z0.d, lsl #2]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, w1
+; CHECK-NEXT: ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1b { z1.s }, p0, [x0, z0.s, sxtw]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1h { z1.d }, p0, [x0, z0.d, lsl #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8_2_lane:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT: mov z3.d, x1
+; CHECK-NEXT: ld1b { z2.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT: ptrue p1.d
+; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT: st1b { z1.d }, p0, [x0, z0.d]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: add z1.s, z2.s, z1.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_2:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
+ ret void
+}
+
+define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_3:
+; CHECK: // %bb.0:
+; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT: mov z3.s, #3 // =0x3
+; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ptrue p1.s
+; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT: ret
+ %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+ call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
+ ret void
+}
+
attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
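A note on why the approach in the LowerVECTOR_HISTOGRAM hunk above is sound: the gather extends the narrow buckets (ISD::EXTLOAD), the histcnt/mul/add run at the promoted width, and the scatter truncates the result back. Because i8/i16 addition and multiplication wrap, the low bits of the wide result equal the narrow result. A standalone check of that property (illustrative, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

// Exhaustively checks that doing i8 bucket arithmetic at i32 width and
// truncating back matches native (wrapping) i8 arithmetic, which is why the
// lowering can pair an extending gather with a truncating scatter.
int main() {
  for (unsigned Bucket = 0; Bucket < 256; ++Bucket) {
    for (unsigned Count = 0; Count < 8; ++Count) {
      const uint8_t Inc = 200; // large enough that the i8 add wraps
      uint8_t Narrow = static_cast<uint8_t>(Bucket + Count * Inc);
      uint32_t Wide = static_cast<uint32_t>(Bucket) + Count * Inc;
      assert(Narrow == static_cast<uint8_t>(Wide));
    }
  }
}
```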
@@ -28018,9 +28021,17 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
EVT IndexVT = Index.getValueType();
EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
IndexVT.getVectorElementCount());
EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
Since the context and element count are used more than once here, having locals for them would make it a little neater. Ctx and EC are typical names in this file.
Suggested change:
- EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
+ EVT IncExtVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
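In context, the refactor might look like the sketch below, assuming Ctx and EC locals are hoisted as the reviewer suggests. This is the suggestion applied for illustration, not verbatim code from the merged patch.

```cpp
LLVMContext &Ctx = *DAG.getContext();
ElementCount EC = IndexVT.getVectorElementCount();
EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
// SVEBitsPerBlock (128) divided by the lane count yields the promoted
// element width: i32 for 4 lanes, i64 for 2 lanes.
EVT IncExtVT =
    EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);
```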
Done.
LGTM
…ic (llvm#101017)

Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.
…08521)

This PR updates the AArch64 cost model to account for the cheaper cost of sub-i32 histograms, reflecting the improvements from llvm#101017 and llvm#103037.

Work by Max Beck-Jones (@DevM-uk)

Co-authored-by: DevM-uk <max.beck-jones@arm.com>