Skip to content

Commit

Permalink
[AArch64] Check for streaming mode in HasSME* features. (llvm#96302)
Browse files Browse the repository at this point in the history
This also fixes up some asserts in copyPhysReg, loadRegFromStackSlot and
storeRegToStackSlot.
  • Loading branch information
sdesmalen-arm authored Jun 24, 2024
1 parent 2ae0905 commit 62baf21
Show file tree
Hide file tree
Showing 111 changed files with 186 additions and 172 deletions.
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
addQRType(MVT::v8bf16);
}

if (Subtarget->hasSVEorSME()) {
if (Subtarget->isSVEorStreamingSVEAvailable()) {
// Add legal sve predicate types
addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass);
Expand Down Expand Up @@ -1408,7 +1408,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

// FIXME: Move lowering for more nodes here if those are common between
// SVE and SME.
if (Subtarget->hasSVEorSME()) {
if (Subtarget->isSVEorStreamingSVEAvailable()) {
for (auto VT :
{MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
Expand Down
45 changes: 22 additions & 23 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4497,7 +4497,8 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
// Copy a Predicate register by ORRing with itself.
if (AArch64::PPRRegClass.contains(DestReg) &&
AArch64::PPRRegClass.contains(SrcReg)) {
assert(Subtarget.hasSVEorSME() && "Unexpected SVE register.");
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected SVE register.");
BuildMI(MBB, I, DL, get(AArch64::ORR_PPzPP), DestReg)
.addReg(SrcReg) // Pg
.addReg(SrcReg)
Expand All @@ -4510,8 +4511,6 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
bool DestIsPNR = AArch64::PNRRegClass.contains(DestReg);
bool SrcIsPNR = AArch64::PNRRegClass.contains(SrcReg);
if (DestIsPNR || SrcIsPNR) {
assert((Subtarget.hasSVE2p1() || Subtarget.hasSME2()) &&
"Unexpected predicate-as-counter register.");
auto ToPPR = [](MCRegister R) -> MCRegister {
return (R - AArch64::PN0) + AArch64::P0;
};
Expand All @@ -4532,7 +4531,8 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
// Copy a Z register by ORRing with itself.
if (AArch64::ZPRRegClass.contains(DestReg) &&
AArch64::ZPRRegClass.contains(SrcReg)) {
assert(Subtarget.hasSVEorSME() && "Unexpected SVE register.");
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected SVE register.");
BuildMI(MBB, I, DL, get(AArch64::ORR_ZZZ), DestReg)
.addReg(SrcReg)
.addReg(SrcReg, getKillRegState(KillSrc));
Expand All @@ -4544,7 +4544,8 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
AArch64::ZPR2StridedOrContiguousRegClass.contains(DestReg)) &&
(AArch64::ZPR2RegClass.contains(SrcReg) ||
AArch64::ZPR2StridedOrContiguousRegClass.contains(SrcReg))) {
assert(Subtarget.hasSVEorSME() && "Unexpected SVE register.");
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected SVE register.");
static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1};
copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ,
Indices);
Expand All @@ -4554,7 +4555,8 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
// Copy a Z register triple by copying the individual sub-registers.
if (AArch64::ZPR3RegClass.contains(DestReg) &&
AArch64::ZPR3RegClass.contains(SrcReg)) {
assert(Subtarget.hasSVEorSME() && "Unexpected SVE register.");
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected SVE register.");
static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1,
AArch64::zsub2};
copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ,
Expand All @@ -4567,7 +4569,8 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
AArch64::ZPR4StridedOrContiguousRegClass.contains(DestReg)) &&
(AArch64::ZPR4RegClass.contains(SrcReg) ||
AArch64::ZPR4StridedOrContiguousRegClass.contains(SrcReg))) {
assert(Subtarget.hasSVEorSME() && "Unexpected SVE register.");
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected SVE register.");
static const unsigned Indices[] = {AArch64::zsub0, AArch64::zsub1,
AArch64::zsub2, AArch64::zsub3};
copyPhysRegTuple(MBB, I, DL, DestReg, SrcReg, KillSrc, AArch64::ORR_ZZZ,
Expand Down Expand Up @@ -4830,14 +4833,12 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
Opc = AArch64::STRBui;
break;
case 2: {
bool IsPNR = AArch64::PNRRegClass.hasSubClassEq(RC);
if (AArch64::FPR16RegClass.hasSubClassEq(RC))
Opc = AArch64::STRHui;
else if (IsPNR || AArch64::PPRRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
else if (AArch64::PNRRegClass.hasSubClassEq(RC) ||
AArch64::PPRRegClass.hasSubClassEq(RC)) {
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register store without SVE store instructions");
assert((!IsPNR || Subtarget.hasSVE2p1() || Subtarget.hasSME2()) &&
"Unexpected register store without SVE2p1 or SME2");
Opc = AArch64::STR_PXI;
StackID = TargetStackID::ScalableVector;
}
Expand Down Expand Up @@ -4886,7 +4887,7 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
AArch64::sube64, AArch64::subo64, FI, MMO);
return;
} else if (AArch64::ZPRRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register store without SVE store instructions");
Opc = AArch64::STR_ZXI;
StackID = TargetStackID::ScalableVector;
Expand All @@ -4910,7 +4911,7 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
Offset = false;
} else if (AArch64::ZPR2RegClass.hasSubClassEq(RC) ||
AArch64::ZPR2StridedOrContiguousRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register store without SVE store instructions");
Opc = AArch64::STR_ZZXI;
StackID = TargetStackID::ScalableVector;
Expand All @@ -4922,7 +4923,7 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
Opc = AArch64::ST1Threev2d;
Offset = false;
} else if (AArch64::ZPR3RegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register store without SVE store instructions");
Opc = AArch64::STR_ZZZXI;
StackID = TargetStackID::ScalableVector;
Expand All @@ -4935,7 +4936,7 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
Offset = false;
} else if (AArch64::ZPR4RegClass.hasSubClassEq(RC) ||
AArch64::ZPR4StridedOrContiguousRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register store without SVE store instructions");
Opc = AArch64::STR_ZZZZXI;
StackID = TargetStackID::ScalableVector;
Expand Down Expand Up @@ -5008,10 +5009,8 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
if (AArch64::FPR16RegClass.hasSubClassEq(RC))
Opc = AArch64::LDRHui;
else if (IsPNR || AArch64::PPRRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register load without SVE load instructions");
assert((!IsPNR || Subtarget.hasSVE2p1() || Subtarget.hasSME2()) &&
"Unexpected register load without SVE2p1 or SME2");
if (IsPNR)
PNRReg = DestReg;
Opc = AArch64::LDR_PXI;
Expand Down Expand Up @@ -5062,7 +5061,7 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
AArch64::subo64, FI, MMO);
return;
} else if (AArch64::ZPRRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register load without SVE load instructions");
Opc = AArch64::LDR_ZXI;
StackID = TargetStackID::ScalableVector;
Expand All @@ -5086,7 +5085,7 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
Offset = false;
} else if (AArch64::ZPR2RegClass.hasSubClassEq(RC) ||
AArch64::ZPR2StridedOrContiguousRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register load without SVE load instructions");
Opc = AArch64::LDR_ZZXI;
StackID = TargetStackID::ScalableVector;
Expand All @@ -5098,7 +5097,7 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
Opc = AArch64::LD1Threev2d;
Offset = false;
} else if (AArch64::ZPR3RegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register load without SVE load instructions");
Opc = AArch64::LDR_ZZZXI;
StackID = TargetStackID::ScalableVector;
Expand All @@ -5111,7 +5110,7 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
Offset = false;
} else if (AArch64::ZPR4RegClass.hasSubClassEq(RC) ||
AArch64::ZPR4StridedOrContiguousRegClass.hasSubClassEq(RC)) {
assert(Subtarget.hasSVEorSME() &&
assert(Subtarget.isSVEorStreamingSVEAvailable() &&
"Unexpected register load without SVE load instructions");
Opc = AArch64::LDR_ZZZZXI;
StackID = TargetStackID::ScalableVector;
Expand Down
54 changes: 30 additions & 24 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,35 +141,41 @@ def HasSPE : Predicate<"Subtarget->hasSPE()">,
def HasFuseAES : Predicate<"Subtarget->hasFuseAES()">,
AssemblerPredicateWithAll<(all_of FeatureFuseAES),
"fuse-aes">;
def HasSVE : Predicate<"Subtarget->hasSVE()">,
def HasSVE : Predicate<"Subtarget->isSVEAvailable()">,
AssemblerPredicateWithAll<(all_of FeatureSVE), "sve">;
def HasSVE2 : Predicate<"Subtarget->hasSVE2()">,
def HasSVE2 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2()">,
AssemblerPredicateWithAll<(all_of FeatureSVE2), "sve2">;
def HasSVE2p1 : Predicate<"Subtarget->hasSVE2p1()">,
def HasSVE2p1 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2p1()">,
AssemblerPredicateWithAll<(all_of FeatureSVE2p1), "sve2p1">;
def HasSVE2AES : Predicate<"Subtarget->hasSVE2AES()">,
def HasSVE2AES : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2AES()">,
AssemblerPredicateWithAll<(all_of FeatureSVE2AES), "sve2-aes">;
def HasSVE2SM4 : Predicate<"Subtarget->hasSVE2SM4()">,
def HasSVE2SM4 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2SM4()">,
AssemblerPredicateWithAll<(all_of FeatureSVE2SM4), "sve2-sm4">;
def HasSVE2SHA3 : Predicate<"Subtarget->hasSVE2SHA3()">,
def HasSVE2SHA3 : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2SHA3()">,
AssemblerPredicateWithAll<(all_of FeatureSVE2SHA3), "sve2-sha3">;
def HasSVE2BitPerm : Predicate<"Subtarget->hasSVE2BitPerm()">,
def HasSVE2BitPerm : Predicate<"Subtarget->isSVEAvailable() && Subtarget->hasSVE2BitPerm()">,
AssemblerPredicateWithAll<(all_of FeatureSVE2BitPerm), "sve2-bitperm">;
def HasB16B16 : Predicate<"Subtarget->hasB16B16()">,
AssemblerPredicateWithAll<(all_of FeatureB16B16), "b16b16">;
def HasSME : Predicate<"Subtarget->hasSME()">,
def HasSMEandIsNonStreamingSafe
: Predicate<"Subtarget->hasSME()">,
AssemblerPredicateWithAll<(all_of FeatureSME), "sme">;
def HasSMEF64F64 : Predicate<"Subtarget->hasSMEF64F64()">,
def HasSME : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME()">,
AssemblerPredicateWithAll<(all_of FeatureSME), "sme">;
def HasSMEF64F64 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF64F64()">,
AssemblerPredicateWithAll<(all_of FeatureSMEF64F64), "sme-f64f64">;
def HasSMEF16F16 : Predicate<"Subtarget->hasSMEF16F16()">,
def HasSMEF16F16 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF16F16()">,
AssemblerPredicateWithAll<(all_of FeatureSMEF16F16), "sme-f16f16">;
def HasSMEFA64 : Predicate<"Subtarget->hasSMEFA64()">,
def HasSMEFA64 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEFA64()">,
AssemblerPredicateWithAll<(all_of FeatureSMEFA64), "sme-fa64">;
def HasSMEI16I64 : Predicate<"Subtarget->hasSMEI16I64()">,
def HasSMEI16I64 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEI16I64()">,
AssemblerPredicateWithAll<(all_of FeatureSMEI16I64), "sme-i16i64">;
def HasSME2 : Predicate<"Subtarget->hasSME2()">,
def HasSME2andIsNonStreamingSafe
: Predicate<"Subtarget->hasSME2()">,
AssemblerPredicateWithAll<(all_of FeatureSME2), "sme2">;
def HasSME2 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME2()">,
AssemblerPredicateWithAll<(all_of FeatureSME2), "sme2">;
def HasSME2p1 : Predicate<"Subtarget->hasSME2p1()">,
def HasSME2p1 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME2p1()">,
AssemblerPredicateWithAll<(all_of FeatureSME2p1), "sme2p1">;
def HasFP8 : Predicate<"Subtarget->hasFP8()">,
AssemblerPredicateWithAll<(all_of FeatureFP8), "fp8">;
Expand Down Expand Up @@ -198,39 +204,39 @@ def HasSSVE_FP8DOT4 : Predicate<"Subtarget->hasSSVE_FP8DOT4() || "
"ssve-fp8dot4 or (sve2 and fp8dot4)">;
def HasLUT : Predicate<"Subtarget->hasLUT()">,
AssemblerPredicateWithAll<(all_of FeatureLUT), "lut">;
def HasSME_LUTv2 : Predicate<"Subtarget->hasSME_LUTv2()">,
def HasSME_LUTv2 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSME_LUTv2()">,
AssemblerPredicateWithAll<(all_of FeatureSME_LUTv2), "sme-lutv2">;
def HasSMEF8F16 : Predicate<"Subtarget->hasSMEF8F16()">,
def HasSMEF8F16 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF8F16()">,
AssemblerPredicateWithAll<(all_of FeatureSMEF8F16), "sme-f8f16">;
def HasSMEF8F32 : Predicate<"Subtarget->hasSMEF8F32()">,
def HasSMEF8F32 : Predicate<"Subtarget->isStreaming() && Subtarget->hasSMEF8F32()">,
AssemblerPredicateWithAll<(all_of FeatureSMEF8F32), "sme-f8f32">;

// A subset of SVE(2) instructions are legal in Streaming SVE execution mode,
// they should be enabled if either has been specified.
def HasSVEorSME
: Predicate<"Subtarget->hasSVEorSME()">,
: Predicate<"Subtarget->hasSVE() || (Subtarget->isStreaming() && Subtarget->hasSME())">,
AssemblerPredicateWithAll<(any_of FeatureSVE, FeatureSME),
"sve or sme">;
def HasSVE2orSME
: Predicate<"Subtarget->hasSVE2() || Subtarget->hasSME()">,
: Predicate<"Subtarget->hasSVE2() || (Subtarget->isStreaming() && Subtarget->hasSME())">,
AssemblerPredicateWithAll<(any_of FeatureSVE2, FeatureSME),
"sve2 or sme">;
def HasSVE2orSME2
: Predicate<"Subtarget->hasSVE2() || Subtarget->hasSME2()">,
: Predicate<"Subtarget->hasSVE2() || (Subtarget->isStreaming() && Subtarget->hasSME2())">,
AssemblerPredicateWithAll<(any_of FeatureSVE2, FeatureSME2),
"sve2 or sme2">;
def HasSVE2p1_or_HasSME
: Predicate<"Subtarget->hasSVE2p1() || Subtarget->hasSME()">,
: Predicate<"Subtarget->hasSVE2p1() || (Subtarget->isStreaming() && Subtarget->hasSME())">,
AssemblerPredicateWithAll<(any_of FeatureSME, FeatureSVE2p1), "sme or sve2p1">;
def HasSVE2p1_or_HasSME2
: Predicate<"Subtarget->hasSVE2p1() || Subtarget->hasSME2()">,
: Predicate<"Subtarget->hasSVE2p1() || (Subtarget->isStreaming() && Subtarget->hasSME2())">,
AssemblerPredicateWithAll<(any_of FeatureSME2, FeatureSVE2p1), "sme2 or sve2p1">;
def HasSVE2p1_or_HasSME2p1
: Predicate<"Subtarget->hasSVE2p1() || Subtarget->hasSME2p1()">,
: Predicate<"Subtarget->hasSVE2p1() || (Subtarget->isStreaming() && Subtarget->hasSME2p1())">,
AssemblerPredicateWithAll<(any_of FeatureSME2p1, FeatureSVE2p1), "sme2p1 or sve2p1">;

def HasSMEF16F16orSMEF8F16
: Predicate<"Subtarget->hasSMEF16F16() || Subtarget->hasSMEF8F16()">,
: Predicate<"Subtarget->isStreaming() && (Subtarget->hasSMEF16F16() || Subtarget->hasSMEF8F16())">,
AssemblerPredicateWithAll<(any_of FeatureSMEF16F16, FeatureSMEF8F16),
"sme-f16f16 or sme-f8f16">;

Expand Down
Loading

0 comments on commit 62baf21

Please sign in to comment.