Arm64/Sve: Implement SVE Math *Multiply* APIs #102007

Merged: 30 commits, May 11, 2024

Commits (changes below are from all commits):
- 97373ca: Add *Fused* APIs (kunalspathak, May 6, 2024)
- 4e14098: fix an assert in morph (kunalspathak, May 7, 2024)
- 3fb9dea: Map APIs to instructions (kunalspathak, May 7, 2024)
- 600391a: Add test cases (kunalspathak, May 8, 2024)
- 67e4d4d: handle fused* instructions (kunalspathak, May 8, 2024)
- 54899b2: jit format (kunalspathak, May 8, 2024)
- e4a53ae: Added MultiplyAdd/MultiplySubtract (kunalspathak, May 8, 2024)
- bfad7b7: Add mapping of API to instruction (kunalspathak, May 8, 2024)
- 100f289: Add test cases (kunalspathak, May 8, 2024)
- 8ac1840: Handle mov Z, Z instruction (kunalspathak, May 9, 2024)
- 9eb195e: Reuse GetResultOpNumForRmwIntrinsic() for arm64 (kunalspathak, May 9, 2024)
- c182d0d: Reuse HW_Flag_FmaIntrinsic for arm64 (kunalspathak, May 9, 2024)
- 62ea159: Mark FMA APIs as HW_Flag_FmaIntrinsic (kunalspathak, May 9, 2024)
- 28a49cb: Handle FMA in LSRA and codegen (kunalspathak, May 9, 2024)
- 722dd55: Remove the SpecialCodeGen flag from selectedScalar (kunalspathak, May 9, 2024)
- 229f78f: address some more scenarios (kunalspathak, May 10, 2024)
- a21439f: jit format (kunalspathak, May 10, 2024)
- 6a01ca4: Add MultiplyBySelectedScalar (kunalspathak, May 10, 2024)
- 318cbf3: Map the API to the instruction (kunalspathak, May 10, 2024)
- e3fc830: fix a bug where *Indexed API used with ConditionalSelect were failing (kunalspathak, May 10, 2024)
- 1ca5539: unpredicated movprfx should not send opt (kunalspathak, May 10, 2024)
- eb41e1d: Add the missing flags for Subtract/Multiply (kunalspathak, May 10, 2024)
- 7874f25: Added tests for MultiplyBySelectedScalar (kunalspathak, May 10, 2024)
- f756afb: fixes to test cases (kunalspathak, May 10, 2024)
- 2904934: fix the parameter for selectedScalar test (kunalspathak, May 10, 2024)
- 53d29a0: Merge remote-tracking branch 'origin/main' into sve_math6 (kunalspathak, May 10, 2024)
- 98ac0ce: jit format (kunalspathak, May 10, 2024)
- 0f89e10: Contain(op3) of CndSel if op1 is AllTrueMask (kunalspathak, May 10, 2024)
- 8e928ec: Handle FMA properly (kunalspathak, May 10, 2024)
- c713d31: added assert (kunalspathak, May 10, 2024)
16 changes: 14 additions & 2 deletions src/coreclr/jit/emitarm64.cpp
@@ -4250,9 +4250,11 @@ void emitter::emitIns_Mov(

case INS_sve_mov:
{
if (isPredicateRegister(dstReg) && isPredicateRegister(srcReg))
// TODO-SVE: Remove check for insOptsNone() when predicate registers
// are present.
if (insOptsNone(opt) && isPredicateRegister(dstReg) && isPredicateRegister(srcReg))
{
assert(insOptsNone(opt));
// assert(insOptsNone(opt));

opt = INS_OPTS_SCALABLE_B;
attr = EA_SCALABLE;
@@ -4263,6 +4265,16 @@
}
fmt = IF_SVE_CZ_4A_L;
}
else if (isVectorRegister(dstReg) && isVectorRegister(srcReg))
{
assert(insOptsScalable(opt));

if (IsRedundantMov(ins, size, dstReg, srcReg, canSkip))
{
return;
}
fmt = IF_SVE_AU_3A;
}
else
{
unreached();
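For context on the new vector-register path above: emitIns_Mov now routes mov Z, Z through IF_SVE_AU_3A and drops the instruction when it is a skippable self-copy. A minimal standalone sketch of that redundancy idea (IsRedundantMovSketch is a hypothetical stand-in; the JIT's real IsRedundantMov also considers size and other emitter state):

```cpp
// Minimal sketch of the redundancy check behind the new mov Z, Z path.
// IsRedundantMovSketch is a hypothetical stand-in, not the JIT helper.
#include <cstdio>

static bool IsRedundantMovSketch(int dstReg, int srcReg, bool canSkip)
{
    // A self-copy leaves the destination unchanged, so it may be elided
    // when the caller explicitly allows skipping.
    return canSkip && (dstReg == srcReg);
}

int main()
{
    printf("%d\n", IsRedundantMovSketch(3, 3, true));  // 1: elided, nothing emitted
    printf("%d\n", IsRedundantMovSketch(3, 4, true));  // 0: a real copy must be emitted
    return 0;
}
```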
34 changes: 31 additions & 3 deletions src/coreclr/jit/emitarm64sve.cpp
@@ -10374,7 +10374,6 @@ BYTE* emitter::emitOutput_InstrSve(BYTE* dst, instrDesc* id)
case IF_SVE_FN_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply long
case IF_SVE_FO_3A: // ...........mmmmm ......nnnnnddddd -- SVE integer matrix multiply accumulate
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
case IF_SVE_EF_3A: // ...........mmmmm ......nnnnnddddd -- SVE two-way dot product
case IF_SVE_EI_3A: // ...........mmmmm ......nnnnnddddd -- SVE mixed sign dot product
@@ -10396,6 +10395,17 @@ BYTE* emitter::emitOutput_InstrSve(BYTE* dst, instrDesc* id)
dst += emitOutput_Instr(dst, code);
break;

case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
code = emitInsCodeSve(ins, fmt);
code |= insEncodeReg_V<4, 0>(id->idReg1()); // ddddd
code |= insEncodeReg_V<9, 5>(id->idReg2()); // nnnnn
if (id->idIns() != INS_sve_mov)
{
code |= insEncodeReg_V<20, 16>(id->idReg3()); // mmmmm
}
dst += emitOutput_Instr(dst, code);
break;

case IF_SVE_AV_3A: // ...........mmmmm ......kkkkkddddd -- SVE2 bitwise ternary operations
code = emitInsCodeSve(ins, fmt);
code |= insEncodeReg_V<4, 0>(id->idReg1()); // ddddd
@@ -12882,7 +12892,6 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
case IF_SVE_FN_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply long
case IF_SVE_FO_3A: // ...........mmmmm ......nnnnnddddd -- SVE integer matrix multiply accumulate
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
case IF_SVE_EF_3A: // ...........mmmmm ......nnnnnddddd -- SVE two-way dot product
case IF_SVE_EI_3A: // ...........mmmmm ......nnnnnddddd -- SVE mixed sign dot product
@@ -12902,6 +12911,12 @@ void emitter::emitInsSveSanityCheck(instrDesc* id)
assert(isVectorRegister(id->idReg2())); // nnnnn/mmmmm
assert(isVectorRegister(id->idReg3())); // mmmmm/aaaaa
break;
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
assert(insOptsScalable(id->idInsOpt()));
assert(isVectorRegister(id->idReg1())); // ddddd
assert(isVectorRegister(id->idReg2())); // nnnnn/mmmmm
assert((id->idIns() == INS_sve_mov) || isVectorRegister(id->idReg3())); // mmmmm/aaaaa
break;

case IF_SVE_HA_3A_F: // ...........mmmmm ......nnnnnddddd -- SVE BFloat16 floating-point dot product
case IF_SVE_EW_3A: // ...........mmmmm ......nnnnnddddd -- SVE2 multiply-add (checked pointer)
@@ -14526,7 +14541,6 @@ void emitter::emitDispInsSveHelp(instrDesc* id)
case IF_SVE_HD_3A_A: // ...........mmmmm ......nnnnnddddd -- SVE floating point matrix multiply accumulate
// <Zd>.D, <Zn>.D, <Zm>.D
case IF_SVE_AT_3B: // ...........mmmmm ......nnnnnddddd -- SVE integer add/subtract vectors (unpredicated)
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
// <Zd>.B, <Zn>.B, <Zm>.B
case IF_SVE_GF_3A: // ........xx.mmmmm ......nnnnnddddd -- SVE2 histogram generation (segment)
case IF_SVE_BD_3B: // ...........mmmmm ......nnnnnddddd -- SVE2 integer multiply vectors (unpredicated)
@@ -14541,6 +14555,20 @@ void emitter::emitDispInsSveHelp(instrDesc* id)
emitDispSveReg(id->idReg3(), id->idInsOpt(), false); // mmmmm/aaaaa
break;

// <Zd>.D, <Zn>.D, <Zm>.D
case IF_SVE_AU_3A: // ...........mmmmm ......nnnnnddddd -- SVE bitwise logical operations (unpredicated)
emitDispSveReg(id->idReg1(), id->idInsOpt(), true); // ddddd
if (id->idIns() == INS_sve_mov)
{
emitDispSveReg(id->idReg2(), id->idInsOpt(), false); // nnnnn/mmmmm
}
else
{
emitDispSveReg(id->idReg2(), id->idInsOpt(), true); // nnnnn/mmmmm
emitDispSveReg(id->idReg3(), id->idInsOpt(), false); // mmmmm/aaaaa
}
break;

// <Zda>.D, <Zn>.D, <Zm>.D
case IF_SVE_EW_3A: // ...........mmmmm ......nnnnnddddd -- SVE2 multiply-add (checked pointer)
// <Zdn>.D, <Zm>.D, <Za>.D
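A rough illustration of how the new IF_SVE_AU_3A case packs its register operands, mirroring the insEncodeReg_V<4, 0>, <9, 5> and <20, 16> calls in the diff above; the base opcode value here is a placeholder, not the real SVE encoding:

```cpp
// Sketch of the IF_SVE_AU_3A register packing shown in the diff above.
// Bit positions come from the insEncodeReg_V template arguments;
// baseOpcode is a made-up placeholder value.
#include <cstdint>
#include <cstdio>

static uint32_t EncodeAu3aSketch(uint32_t baseOpcode, unsigned zd, unsigned zn, unsigned zm, bool isMov)
{
    uint32_t code = baseOpcode;
    code |= (zd & 0x1Fu) << 0;       // ddddd -> bits [4:0]
    code |= (zn & 0x1Fu) << 5;       // nnnnn -> bits [9:5]
    if (!isMov)
    {
        code |= (zm & 0x1Fu) << 16;  // mmmmm -> bits [20:16], skipped for mov
    }
    return code;
}

int main()
{
    // Operand packing for an orr z0.d, z1.d, z2.d style encoding (opcode bits are fake).
    printf("0x%08X\n", EncodeAu3aSketch(0x04600000u, 0, 1, 2, false));
    return 0;
}
```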
6 changes: 5 additions & 1 deletion src/coreclr/jit/gentree.cpp
@@ -27955,7 +27955,7 @@ bool GenTreeLclVar::IsNeverNegative(Compiler* comp) const
return comp->lvaGetDesc(GetLclNum())->IsNeverNegative();
}

#if defined(TARGET_XARCH) && defined(FEATURE_HW_INTRINSICS)
#if (defined(TARGET_XARCH) || defined(TARGET_ARM64)) && defined(FEATURE_HW_INTRINSICS)
//------------------------------------------------------------------------
// GetResultOpNumForRmwIntrinsic: check if the result is written into one of the operands.
// In the case that none of the operand is overwritten, check if any of them is lastUse.
@@ -27966,7 +27966,11 @@ bool GenTreeLclVar::IsNeverNegative(Compiler* comp) const
//
unsigned GenTreeHWIntrinsic::GetResultOpNumForRmwIntrinsic(GenTree* use, GenTree* op1, GenTree* op2, GenTree* op3)
{
#if defined(TARGET_XARCH)
assert(HWIntrinsicInfo::IsFmaIntrinsic(gtHWIntrinsicId) || HWIntrinsicInfo::IsPermuteVar2x(gtHWIntrinsicId));
#elif defined(TARGET_ARM64)
assert(HWIntrinsicInfo::IsFmaIntrinsic(gtHWIntrinsicId));
#endif

if (use != nullptr && use->OperIs(GT_STORE_LCL_VAR))
{
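Sharing GetResultOpNumForRmwIntrinsic with arm64 lets FMA nodes pick which operand the result should overwrite, so codegen can select the accumulating instruction form instead of inserting an extra copy. A simplified sketch of that selection idea, following the function's own comment (hypothetical types, not the real GenTree logic):

```cpp
// Simplified sketch of the "which operand does the result overwrite?" idea
// behind GetResultOpNumForRmwIntrinsic. OperandSketch is hypothetical.
struct OperandSketch
{
    int  lclNum;     // local variable the operand reads, -1 if none
    bool isLastUse;  // value is dead after this use
};

// Returns 1..3 for the preferred operand to clobber, or 0 for no preference.
static unsigned ResultOpNumSketch(int storeLclNum, const OperandSketch (&ops)[3])
{
    for (unsigned i = 0; i < 3; i++)
    {
        if ((storeLclNum != -1) && (ops[i].lclNum == storeLclNum))
        {
            return i + 1;  // the result is written back into this operand
        }
    }
    for (unsigned i = 0; i < 3; i++)
    {
        if (ops[i].isLastUse)
        {
            return i + 1;  // a dying operand is free to clobber
        }
    }
    return 0;
}

int main()
{
    OperandSketch ops[3] = {{10, false}, {11, false}, {12, true}};
    return (int)ResultOpNumSketch(11, ops);  // returns 2: result overwrites op2
}
```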
24 changes: 12 additions & 12 deletions src/coreclr/jit/hwintrinsic.h
@@ -216,27 +216,27 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic is an RMW intrinsic
HW_Flag_RmwIntrinsic = 0x1000000,

// The intrinsic is a FusedMultiplyAdd intrinsic
HW_Flag_FmaIntrinsic = 0x2000000,

// The intrinsic is a PermuteVar2x intrinsic
HW_Flag_PermuteVar2x = 0x4000000,
HW_Flag_PermuteVar2x = 0x2000000,

// The intrinsic is an embedded broadcast compatible intrinsic
HW_Flag_EmbBroadcastCompatible = 0x8000000,
HW_Flag_EmbBroadcastCompatible = 0x4000000,

// The intrinsic is an embedded rounding compatible intrinsic
HW_Flag_EmbRoundingCompatible = 0x10000000,
HW_Flag_EmbRoundingCompatible = 0x8000000,

// The intrinsic is an embedded masking compatible intrinsic
HW_Flag_EmbMaskingCompatible = 0x20000000,
HW_Flag_EmbMaskingCompatible = 0x10000000,
#elif defined(TARGET_ARM64)

// The intrinsic has an enum operand. Using this implies HW_Flag_HasImmediateOperand.
HW_Flag_HasEnumOperand = 0x1000000,

#endif // TARGET_XARCH

// The intrinsic is a FusedMultiplyAdd intrinsic
HW_Flag_FmaIntrinsic = 0x20000000,

HW_Flag_CanBenefitFromConstantProp = 0x80000000,
};

@@ -935,17 +935,17 @@ struct HWIntrinsicInfo
return (flags & HW_Flag_MaybeNoJmpTableIMM) != 0;
}

#if defined(TARGET_XARCH)
static bool IsRmwIntrinsic(NamedIntrinsic id)
static bool IsFmaIntrinsic(NamedIntrinsic id)
{
HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_RmwIntrinsic) != 0;
return (flags & HW_Flag_FmaIntrinsic) != 0;
}

static bool IsFmaIntrinsic(NamedIntrinsic id)
#if defined(TARGET_XARCH)
static bool IsRmwIntrinsic(NamedIntrinsic id)
{
HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_FmaIntrinsic) != 0;
return (flags & HW_Flag_RmwIntrinsic) != 0;
}

static bool IsPermuteVar2x(NamedIntrinsic id)
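Because HW_Flag_FmaIntrinsic moves out of the xarch-only block and the remaining xarch flags shift down one bit, it is worth sanity-checking that the new values stay disjoint. A standalone check mirroring the constants in the diff (not the actual header):

```cpp
// Standalone bit-overlap check for the reshuffled flag values in the diff.
// These constants mirror the diff; they are not the real header.
#include <cstdint>

constexpr uint32_t RmwIntrinsic            = 0x1000000;
constexpr uint32_t PermuteVar2x            = 0x2000000;
constexpr uint32_t EmbBroadcastCompatible  = 0x4000000;
constexpr uint32_t EmbRoundingCompatible   = 0x8000000;
constexpr uint32_t EmbMaskingCompatible    = 0x10000000;
constexpr uint32_t FmaIntrinsic            = 0x20000000;  // now shared by xarch and arm64
constexpr uint32_t CanBenefitFromConstProp = 0x80000000;

static_assert((FmaIntrinsic & (RmwIntrinsic | PermuteVar2x | EmbBroadcastCompatible |
                               EmbRoundingCompatible | EmbMaskingCompatible |
                               CanBenefitFromConstProp)) == 0,
              "HW_Flag_FmaIntrinsic must use a distinct bit");

int main() { return 0; }
```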
2 changes: 2 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
@@ -277,6 +277,8 @@ void HWIntrinsicInfo::lookupImmBounds(
case NI_AdvSimd_Arm64_StoreSelectedScalarVector128x4:
case NI_AdvSimd_Arm64_DuplicateSelectedScalarToVector128:
case NI_AdvSimd_Arm64_InsertSelectedScalar:
case NI_Sve_FusedMultiplyAddBySelectedScalar:
case NI_Sve_FusedMultiplySubtractBySelectedScalar:
immUpperBound = Compiler::getSIMDVectorLength(simdSize, baseType) - 1;
break;

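The two new NI_Sve_*BySelectedScalar cases bound the lane-index immediate by the vector length. A small sketch of the arithmetic (the helper below is a stand-in for Compiler::getSIMDVectorLength):

```cpp
// Sketch of the immediate upper bound set above for *BySelectedScalar:
// valid lane indices are [0, lanes - 1]. VectorLengthSketch stands in for
// Compiler::getSIMDVectorLength(simdSize, baseType).
#include <cstdio>

static int VectorLengthSketch(int simdSizeBytes, int baseTypeBytes)
{
    return simdSizeBytes / baseTypeBytes;
}

int main()
{
    // e.g. a 16-byte vector of 4-byte floats has 4 lanes, so indices 0..3.
    int immUpperBound = VectorLengthSketch(16, 4) - 1;
    printf("immUpperBound = %d\n", immUpperBound);  // prints 3
    return 0;
}
```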