JIT ARM64-SVE: Add Sve.Abs() and Sve.Add() #100134

Closed · wants to merge 7 commits · Changes from all commits
1 change: 1 addition & 0 deletions src/coreclr/jit/compiler.h
@@ -3468,6 +3468,7 @@ class Compiler
#if defined(TARGET_ARM64)
GenTree* gtNewSimdConvertVectorToMaskNode(var_types type, GenTree* node, CorInfoType simdBaseJitType, unsigned simdSize);
GenTree* gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, var_types type);
GenTree* gtNewSimdEmbeddedMaskNode(CorInfoType simdBaseJitType, unsigned simdSize);
#endif

//------------------------------------------------------------------------
4 changes: 2 additions & 2 deletions src/coreclr/jit/emitarm64.cpp
@@ -7875,7 +7875,7 @@ void emitter::emitIns_R_S(instruction ins, emitAttr attr, regNumber reg1, int va

// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
-            scale = NaturalScale_helper(EA_2BYTE);
+            scale = 2;
Member

Given the todo above, are we potentially missing an assert which would ensure this is updated for future instructions that need it?

Contributor Author

> Given the todo above, are we potentially missing an assert which would ensure this is updated for future instructions that need it?

This is within code used only for predicates, so it will always be valid as long as the vector length is 128 bits.

There is a question of what testing to do for >128-bit vectors for .NET 9. I suspect a lot of assumptions are made elsewhere that the vector length is 128 bits, and it will require some major debugging. At some point I can do some testing on machines with larger vector lengths. Due to time constraints, maybe the solution for .NET 9 is to have a check on startup: if the vector length is >128 bits, ask the kernel to reduce it to 128 bits? Or just disable SVE on >128-bit machines.

Member

That question probably needs a much broader discussion.

There's notably not a lot of benefit to 128-bit SVE over AdvSimd. There are a few new instructions and the ability to emit denser assembly in a few cases, but most of that isn't for the typical hot loop of a method, and in some cases the code can be less dense (emitting SVE Abs requires predication and a ptrue to be generated, while AdvSimd Abs does not; so given 128-bit vectors and no predication, AdvSimd is better to use for that instruction).

At the same time, there is hardware (AWS Graviton) with 256-bit SVE support that will most likely run .NET 9. So it would probably be best to ensure it is appropriately handled so that we can take full advantage of such hardware rather than artificially limiting it.

ssize_t mask = (1 << scale) - 1; // the mask of low bits that must be zero to encode the immediate

if (((imm & mask) == 0) && (isValidSimm<9>(imm >> scale)))
@@ -8154,7 +8154,7 @@ void emitter::emitIns_S_R(instruction ins, emitAttr attr, regNumber reg1, int va

// TODO-SVE: Don't assume 128bit vectors
// Predicate size is vector length / 8
-            scale = NaturalScale_helper(EA_2BYTE);
+            scale = 2;
ssize_t mask = (1 << scale) - 1; // the mask of low bits that must be zero to encode the immediate

if (((imm & mask) == 0) && (isValidSimm<9>(imm >> scale)))
70 changes: 54 additions & 16 deletions src/coreclr/jit/hwintrinsic.cpp
@@ -1396,6 +1396,60 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
GenTree* op3 = nullptr;
GenTree* op4 = nullptr;

switch (numArgs)
{
case 4:
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 3:
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
Member

Is there no case where a range check is needed for 3 arguments?

Should there be an assert to validate that?

Member

> Is there no case where a range check is needed for 3 arguments?

That seems to be the case.

The case for 3 args is checked further down and, depending on the intrinsic, addRangeCheckIfNeeded is applied to the appropriate arg. Not sure if we should still add an assert.

op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 2:
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

case 1:
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);
break;

default:
break;
}

#if defined(TARGET_ARM64)
// Embedded masks need inserting as op1.
if (HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinsic))
{
numArgs++;
assert(numArgs <= 4);
switch (numArgs)
{
case 4:
op4 = op3;
FALLTHROUGH;
case 3:
op3 = op2;
FALLTHROUGH;
case 2:
op2 = op1;
FALLTHROUGH;
default:
break;
}
op1 = gtNewSimdEmbeddedMaskNode(simdBaseJitType, simdSize);
Member (@tannergooding, Mar 28, 2024)

Why does this need to be done here?

This seems like it's just inserting the implicit AllTrue mask that some instructions require, which is effectively allocating and forcing an extra node to be carried through all of HIR when the high-level operation doesn't actually care about it.

Seemingly this could just be inserted as part of lowering instead so that it only impacts LSRA and codegen?

}
#endif

switch (numArgs)
{
case 0:
@@ -1407,8 +1461,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 1:
{
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

if ((category == HW_Category_MemoryLoad) && op1->OperIs(GT_CAST))
{
// Although the API specifies a pointer, if what we have is a BYREF, that's what
@@ -1467,10 +1519,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 2:
{
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op2 = addRangeCheckIfNeeded(intrinsic, op2, mustExpand, immLowerBound, immUpperBound);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

retNode = isScalar
? gtNewScalarHWIntrinsicNode(nodeRetType, op1, op2, intrinsic)
: gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, intrinsic, simdBaseJitType, simdSize);
@@ -1524,10 +1572,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 3:
{
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

#ifdef TARGET_ARM64
if (intrinsic == NI_AdvSimd_LoadAndInsertScalar)
{
@@ -1569,12 +1613,6 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,

case 4:
{
op4 = getArgForHWIntrinsic(sigReader.GetOp4Type(), sigReader.op4ClsHnd);
op4 = addRangeCheckIfNeeded(intrinsic, op4, mustExpand, immLowerBound, immUpperBound);
op3 = getArgForHWIntrinsic(sigReader.GetOp3Type(), sigReader.op3ClsHnd);
op2 = getArgForHWIntrinsic(sigReader.GetOp2Type(), sigReader.op2ClsHnd);
op1 = getArgForHWIntrinsic(sigReader.GetOp1Type(), sigReader.op1ClsHnd);

assert(!isScalar);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);
11 changes: 10 additions & 1 deletion src/coreclr/jit/hwintrinsic.h
@@ -191,6 +191,9 @@ enum HWIntrinsicFlag : unsigned int
// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low register.
HW_Flag_LowMaskedOperation = 0x40000,

// The intrinsic uses a mask in arg1 to select elements present in the result; the mask is not present in the API call
HW_Flag_EmbeddedMaskedOperation = 0x80000,

#else
#error Unsupported platform
#endif
@@ -872,7 +875,7 @@ struct HWIntrinsicInfo
static bool IsMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
-        return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id);
+        return ((flags & HW_Flag_MaskedOperation) != 0) || IsLowMaskedOperation(id) || IsEmbeddedMaskedOperation(id);
}

static bool IsLowMaskedOperation(NamedIntrinsic id)
@@ -881,6 +884,12 @@
return (flags & HW_Flag_LowMaskedOperation) != 0;
}

static bool IsEmbeddedMaskedOperation(NamedIntrinsic id)
{
const HWIntrinsicFlag flags = lookupFlags(id);
return (flags & HW_Flag_EmbeddedMaskedOperation) != 0;
}

#endif // TARGET_ARM64

static bool HasSpecialSideEffect(NamedIntrinsic id)
17 changes: 16 additions & 1 deletion src/coreclr/jit/hwintrinsicarm64.cpp
@@ -2222,7 +2222,7 @@ GenTree* Compiler::gtNewSimdConvertVectorToMaskNode(var_types type,
assert(varTypeIsSIMD(node));

// ConvertVectorToMask uses cmpne which requires an embedded mask.
-    GenTree* embeddedMask = gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
+    GenTree* embeddedMask = gtNewSimdEmbeddedMaskNode(simdBaseJitType, simdSize);
return gtNewSimdHWIntrinsicNode(TYP_MASK, embeddedMask, node, NI_Sve_ConvertVectorToMask, simdBaseJitType,
simdSize);
}
@@ -2246,4 +2246,19 @@ GenTree* Compiler::gtNewSimdConvertMaskToVectorNode(GenTreeHWIntrinsic* node, va
node->GetSimdSize());
}

//------------------------------------------------------------------------
// gtNewSimdEmbeddedMaskNode: Create an embedded mask
//
// Arguments:
// simdBaseJitType -- the base jit type of the nodes being masked
// simdSize -- the simd size of the nodes being masked
//
// Return Value:
// The mask
//
GenTree* Compiler::gtNewSimdEmbeddedMaskNode(CorInfoType simdBaseJitType, unsigned simdSize)
{
return gtNewSimdHWIntrinsicNode(TYP_MASK, NI_Sve_CreateTrueMaskAll, simdBaseJitType, simdSize);
}

#endif // FEATURE_HW_INTRINSICS
81 changes: 58 additions & 23 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
@@ -398,6 +398,64 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
unreached();
}
}
else if (isRMW)
{
assert(!hasImmediateOperand);
assert(!HWIntrinsicInfo::SupportsContainment(intrin.id));

// Move the RMW register out of the way and do not pass it to the emit.

if (HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrin.id))
{
// op1Reg contains a mask, op2Reg contains the RMW register.

if (targetReg != op2Reg)
{
assert(targetReg != op3Reg);
GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op2Reg, /* canSkip */ true);
Member

I am not sure if it is needed here, but we should start documenting the rules for using the movprfx instruction. Do you have good documentation that explains it and where it is used? I remember seeing it long back but am not sure what it was.

Member

It's the same general mov needed for any RMW instruction, since you need the destination register and one of the input registers to be the same. The only really different thing is that movprfx is specifically the one you want for SVE instructions that have an embedded mask. -- The potentially confusing part here is that it's checking IsEmbeddedMaskedOperation, when that check as currently implemented just means "I have a mask, but it's all true" (I gave feedback above that it should be renamed).

So I'm not sure there's really any rules to document here, we certainly don't document it at all for x86 or x64 or any of the existing Arm64 RMW cases.

}

switch (intrin.numOperands)
{
case 2:
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
break;

case 3:
assert(targetReg != op3Reg);
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op3Reg, opt);
break;

default:
unreached();
}
}
else
{
// op1Reg contains the RMW register.

if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);
GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg, /* canSkip */ true);
}

switch (intrin.numOperands)
{
case 2:
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op2Reg, opt);
break;

case 3:
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
Member (@tannergooding, Mar 28, 2024)

Do we need an assert that none of the input registers are mask registers?

I'd guess the only RMW instructions with masks are ones that have at least 4 operands, so we shouldn't ever see that here.

break;

default:
unreached();
}
}
}
else
{
assert(!hasImmediateOperand);
@@ -416,35 +474,12 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
{
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
}
else if (isRMW)
{
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);

GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg,
/* canSkip */ true);
}
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op2Reg, opt);
}
else
{
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op1Reg, op2Reg, opt);
}
break;

case 3:
assert(isRMW);
if (targetReg != op1Reg)
{
assert(targetReg != op2Reg);
assert(targetReg != op3Reg);

GetEmitter()->emitIns_Mov(INS_mov, emitTypeSize(node), targetReg, op1Reg, /* canSkip */ true);
}
GetEmitter()->emitIns_R_R_R(ins, emitSize, targetReg, op2Reg, op3Reg, opt);
break;

default:
unreached();
}
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
@@ -17,6 +17,10 @@
// SVE Intrinsics

// Sve
HARDWARE_INTRINSIC(Sve, Abs, -1, -1, false, {INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_fabs, INS_sve_fabs}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation)
Member

Do we need the Scalable flag? Can that not just be detected by the InstructionSet being Sve?


HARDWARE_INTRINSIC(Sve, Add, -1, -1, false, {INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_fadd, INS_sve_fadd}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
Member

Can you elaborate a bit on what LowMaskedOperation means?

I'm looking at the architecture manual and don't see any limitations on the Add (predicated) instruction. The comment attached to the enum entry is `// The intrinsic uses a mask in arg1 to select elements present in the result, and must use a low register.`

Contributor Author

> Can you elaborate a bit on what LowMaskedOperation means?

The instruction only has 3 bits for a predicate register, so it is limited to using predicate registers 0 to 7. This is quite a common pattern across SVE, hence using a common flag for it (I wouldn't do similar for the handful that have only 2 bits).

Member

Ah, I see. I think a flag makes sense then, but I'm not a huge fan of the name given how low is used in various other contexts.

Maybe something more explicit like RestrictedPredicateRegisterSet (or an alternative name) would work and make it clearer what the mask means/implies?

Member

Why does this one in particular need EmbeddedMaskedOperation? The manual has entries for both Add (vectors, predicated) and Add (vectors, unpredicated), and both entries look to be for SVE1.


HARDWARE_INTRINSIC(Sve, CreateTrueMaskByte, -1, 1, false, {INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskDouble, -1, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_ptrue}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
HARDWARE_INTRINSIC(Sve, CreateTrueMaskInt16, -1, 1, false, {INS_invalid, INS_invalid, INS_sve_ptrue, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_EnumPattern, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_ReturnsPerElementMask)
Expand Down