Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerating Vector512.Sum() #87851

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23864,7 +23864,6 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si

#if defined(TARGET_XARCH)
assert(!varTypeIsByte(simdBaseType) && !varTypeIsLong(simdBaseType));
assert(simdSize != 64);

// HorizontalAdd combines pairs so we need log2(vectorLength) passes to sum all elements together.
unsigned vectorLength = getSIMDVectorLength(simdSize, simdBaseType);
Expand Down Expand Up @@ -23897,6 +23896,35 @@ GenTree* Compiler::gtNewSimdSumNode(var_types type, GenTree* op1, CorInfoType si
intrinsic = NI_SSSE3_HorizontalAdd;
}

if (simdSize == 64)
{
assert(IsBaselineVector512IsaSupportedDebugOnly());
// This is roughly the following managed code:
// ...
// simd64 tmp2 = tmp1;
// tmp3 = tmp2.GetUpper();
// simd32 tmp4 = Isa.Add(tmp1.GetLower(), tmp2);
// tmp5 = tmp4;
// simd16 tmp6 = tmp4.GetUpper();
// tmp1 = Isa.Add(tmp1.GetLower(), tmp2);
// ...
// From here on we can treat this as a simd16 reduction
GenTree* op1Dup = fgMakeMultiUse(&op1);
GenTree* op1Lower32 = gtNewSimdGetUpperNode(TYP_SIMD32, op1Dup, simdBaseJitType, simdSize);
GenTree* op1Upper32 = gtNewSimdGetLowerNode(TYP_SIMD32, op1, simdBaseJitType, simdSize);

simdSize = simdSize / 2;
op1Lower32 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD32, op1Lower32, op1Upper32, simdBaseJitType, simdSize);
haddCount--;

GenTree* op1Dup32 = fgMakeMultiUse(&op1Lower32);
GenTree* op1Lower16 = gtNewSimdGetUpperNode(TYP_SIMD16, op1Lower32, simdBaseJitType, simdSize);
GenTree* op1Upper16 = gtNewSimdGetLowerNode(TYP_SIMD16, op1Dup32, simdBaseJitType, simdSize);
simdSize = simdSize / 2;
op1 = gtNewSimdBinOpNode(GT_ADD, TYP_SIMD16, op1Lower16, op1Upper16, simdBaseJitType, simdSize);
haddCount--;
}
Comment on lines +23899 to +23926
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is "correct", but has an interesting side effect in that can change the result for float/double.

Since float/double are not associative due to the rounding, summing accross as [0] + [1] + [2] + [3] + ... is different from summing pairwise as (([0] + [1]) + ([2] + [3])) + ..., which is different from summing per lane, then combining the lanes, etc.

Today, we're basically doing it per lane, then combining lanes. Within that lane we're typically doing pairwise because that's how addv (add across) works on Arm64, it's how hadd (horizontal add) works on x86/x64, and its trivial to emulate using shufps on older hardware.

We don't really want to have the results subtly change based on what the hardware supports, so we should probably try to ensure this keeps things overall consistent in how it operates.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm busy with something else but will get back to this once I'm done. Sorry for the delay

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries. This one isn't "critical" for .NET 8 and we're already generating decent (but not amazing) code that would be similar to what a user might manually write.


for (int i = 0; i < haddCount; i++)
{
tmp = fgMakeMultiUse(&op1);
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistxarch.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ HARDWARE_INTRINSIC(Vector512, StoreAligned,
HARDWARE_INTRINSIC(Vector512, StoreAlignedNonTemporal, 64, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)
HARDWARE_INTRINSIC(Vector512, StoreUnsafe, 64, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)
HARDWARE_INTRINSIC(Vector512, Subtract, 64, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen)
HARDWARE_INTRINSIC(Vector512, Sum, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_BaseTypeFromFirstArg|HW_Flag_NoCodeGen)
HARDWARE_INTRINSIC(Vector512, ToScalar, 64, 1, true, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_movss, INS_movsd_simd}, HW_Category_SIMDScalar, HW_Flag_SpecialImport|HW_Flag_SpecialCodeGen|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector512, WidenLower, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector512, WidenUpper, 64, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_SpecialImport|HW_Flag_NoCodeGen|HW_Flag_BaseTypeFromFirstArg)
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsicxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2846,6 +2846,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,

case NI_Vector128_Sum:
case NI_Vector256_Sum:
case NI_Vector512_Sum:
{
assert(sig->numArgs == 1);
var_types simdType = getSIMDTypeForSize(simdSize);
Expand Down