
[AArch64] Implement promotion type legalisation for histogram intrinsic #101017

Merged
2 commits merged into llvm:main on Aug 12, 2024

Conversation

@DevM-uk (Member) commented Jul 29, 2024

Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.
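For illustration, a minimal use of the intrinsic with i16 buckets that this patch makes legal (mirroring one of the new tests in the diff below):

define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
  ret void
}

The i16 buckets are promoted to i32 for the histcnt itself; the gather extends on load and the scatter truncates on store.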

@llvmbot (Member) commented Jul 29, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Max Beck-Jones (DevM-uk)

Changes

Currently the histogram intrinsic (llvm.experimental.vector.histogram.add) only allows i32 and i64 types for the memory locations to be updated, matching the restrictions of the histcnt instruction. This patch adds support for the legalisation of smaller types (i8 and i16) via promotion.


Full diff: https://github.com/llvm/llvm-project/pull/101017.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+22-10)
  • (modified) llvm/test/CodeGen/AArch64/sve2-histcnt.ll (+119)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1e9da9b819bdd..153d5fe28be7b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1775,9 +1775,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
 
     // Histcnt is SVE2 only
-    if (Subtarget->hasSVE2())
+    if (Subtarget->hasSVE2()) {
       setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
                          Custom);
+      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
+      setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
+    }
   }
 
 
@@ -28018,9 +28021,17 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
   EVT IndexVT = Index.getValueType();
   EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
                                IndexVT.getVectorElementCount());
+  EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
+                     ? MVT::i32
+                     : MVT::i64;
+  EVT IncSplatVT = EVT::getVectorVT(*DAG.getContext(), IncExtVT,
+                                    IndexVT.getVectorElementCount());
+  bool ExtTrunc = IncSplatVT != MemVT;
+
   SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
-  SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
-  SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
+  SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
+  SDValue IncSplat = DAG.getSplatVector(
+      IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
   SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
 
   MachineMemOperand *MMO = HG->getMemOperand();
@@ -28029,18 +28040,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
       MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
       MMO->getAlign(), MMO->getAAInfo());
   ISD::MemIndexType IndexType = HG->getIndexType();
-  SDValue Gather =
-      DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
-                          GMMO, IndexType, ISD::NON_EXTLOAD);
+  SDValue Gather = DAG.getMaskedGather(
+      DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
+      ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);
 
   SDValue GChain = Gather.getValue(1);
 
   // Perform the histcnt, multiply by inc, add to bucket data.
-  SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
+  SDValue ID =
+      DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
   SDValue HistCnt =
       DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
-  SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
-  SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
+  SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
+  SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);
 
   // Create an MMO for the scatter, without load|store flags.
   MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
@@ -28049,7 +28061,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
 
   SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
-                                         ScatterOps, SMMO, IndexType, false);
+                                         ScatterOps, SMMO, IndexType, ExtTrunc);
   return Scatter;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
index db164e288abde..2874e47511e12 100644
--- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll
@@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
   ret void
 }
 
+define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
+; CHECK-LABEL: histogram_i32_promote:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1w { z1.d }, p0, [x0, z0.d, lsl #2]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, w1
+; CHECK-NEXT:    ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1b { z1.s }, p0, [x0, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
+; CHECK-LABEL: histogram_i16_2_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1h { z1.d }, p0, [x0, z0.d, lsl #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
+; CHECK-LABEL: histogram_i8_2_lane:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.d, p0/z, z0.d, z0.d
+; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    ld1b { z2.d }, p0/z, [x0, z0.d]
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    mad z1.d, p1/m, z3.d, z2.d
+; CHECK-NEXT:    st1b { z1.d }, p0, [x0, z0.d]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_1:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    add z1.s, z2.s, z1.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_2:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    adr z1.s, [z2.s, z1.s, lsl #1]
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
+; CHECK-LABEL: histogram_i16_literal_3:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    histcnt z1.s, p0/z, z0.s, z0.s
+; CHECK-NEXT:    mov z3.s, #3 // =0x3
+; CHECK-NEXT:    ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    mad z1.s, p1/m, z3.s, z2.s
+; CHECK-NEXT:    st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
+  call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
+  ret void
+}
+
 attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }

Review comment on llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (in LowerVECTOR_HISTOGRAM), on the line:

  EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4

A Collaborator commented:

Since the context and element count are used more than once here, having locals for them would make it a little neater. Ctx and EC are typical names in this file.

Suggested change:
-  EVT IncExtVT = IndexVT.getVectorElementCount().getKnownMinValue() == 4
+  EVT IncExtVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
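As a sketch, the suggested refactor would read roughly as follows (assuming the Ctx and EC locals named in the comment; AArch64::SVEBitsPerBlock is the 128-bit SVE granule, so nxv4 indices give i32 and nxv2 indices give i64):

  LLVMContext &Ctx = *DAG.getContext();
  ElementCount EC = IndexVT.getVectorElementCount();
  // Promoted element width = 128-bit block size / lane count.
  EVT IncExtVT =
      EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
  EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);
  bool ExtTrunc = IncSplatVT != MemVT;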

@DevM-uk (Member, Author) replied:

Done.

@huntergr-arm (Collaborator) left a comment:

LGTM

@DevM-uk merged commit 670d208 into llvm:main on Aug 12, 2024
7 checks passed
bwendling pushed a commit to bwendling/llvm-project that referenced this pull request Aug 15, 2024
[AArch64] Implement promotion type legalisation for histogram intrinsic (llvm#101017)

Currently the histogram intrinsic
(llvm.experimental.vector.histogram.add) only allows i32 and i64 types
for the memory locations to be updated, matching the restrictions of the
histcnt instruction. This patch adds support for the legalisation of
smaller types (i8 and i16) via promotion.
SamTebbs33 pushed a commit to SamTebbs33/llvm-project that referenced this pull request Sep 13, 2024
This PR updates the AArch64 cost model to consider the cheaper cost of
<i32 histograms to reflect the improvements from
llvm#101017 and llvm#103037

Work by Max Beck-Jones (@DevM-uk)
SamTebbs33 added a commit that referenced this pull request Sep 19, 2024
This PR updates the AArch64 cost model to consider the cheaper cost of
<i32 histograms to reflect the improvements from
#101017 and
#103037

Work by Max Beck-Jones (@DevM-uk)

---------

Co-authored-by: DevM-uk <max.beck-jones@arm.com>
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
…08521)

This PR updates the AArch64 cost model to consider the cheaper cost of
<i32 histograms to reflect the improvements from
llvm#101017 and
llvm#103037

Work by Max Beck-Jones (@DevM-uk)

---------

Co-authored-by: DevM-uk <max.beck-jones@arm.com>