Improve the handling of SIMD comparisons #104944

Merged: 2 commits, Jul 19, 2024
68 changes: 68 additions & 0 deletions src/coreclr/jit/gentree.cpp
@@ -31012,6 +31012,32 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
}
#endif

case NI_Vector128_op_Equality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Equality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Equality:
case NI_Vector512_op_Equality:
Member: There is also NI_VectorX_EqualsAll (unless they're normalized to op_Equality somewhere). By the way, the last time I tried to constant fold these, you told me it was odd to cover only the EQ/NE relational operators ;-) #85584 (comment)

Member Author: We've already added support for all the other comparisons (elementwise eq/ge/gt/le/lt/ne); what was left were the ==/!= operators, which this PR covers.

EqualsAll/Any and the other All/Any APIs are then imported as an elementwise compare + op ==/!=, so this covers the full set.

#endif // !TARGET_ARM64 && !TARGET_XARCH
{
cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_EQ, isScalar, simdBaseType, otherNode->AsVecCon());
resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsAllBitsSet() ? 1 : 0, retType);
break;
}

case NI_Vector128_op_Inequality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Inequality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Inequality:
case NI_Vector512_op_Inequality:
#endif // !TARGET_ARM64 && !TARGET_XARCH
{
cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_NE, isScalar, simdBaseType, otherNode->AsVecCon());
resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsZero() ? 0 : 1, retType);
break;
}
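As the review thread above notes, EqualsAll/Any reduce to an elementwise compare plus op ==/!=, so folding op_Equality/op_Inequality on constant operands covers those APIs as well. Below is a minimal standalone sketch (plain C++, not JIT code; the Vec128/FoldOpEquality/FoldEqualsAny names are hypothetical) of the evaluation the fold performs via EvaluateBinaryInPlace followed by the IsAllBitsSet/IsZero check:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

using Vec128 = std::array<int32_t, 4>; // stand-in for a 4 x int32 vector constant

// Mirrors EvaluateBinaryInPlace(GT_EQ, ...) followed by IsAllBitsSet():
// each lane becomes all-bits-set or zero, and op_Equality is 1 only if every lane matched.
static int FoldOpEquality(const Vec128& a, const Vec128& b)
{
    Vec128 mask;
    for (size_t i = 0; i < a.size(); i++)
    {
        mask[i] = (a[i] == b[i]) ? -1 : 0;
    }
    for (int32_t lane : mask)
    {
        if (lane != -1)
        {
            return 0;
        }
    }
    return 1;
}

// EqualsAny-style APIs reduce to the same elementwise compare plus op_Inequality
// against Zero: the result is 1 unless the equality mask is entirely zero (IsZero()).
static int FoldEqualsAny(const Vec128& a, const Vec128& b)
{
    for (size_t i = 0; i < a.size(); i++)
    {
        if (a[i] == b[i])
        {
            return 1;
        }
    }
    return 0;
}

int main()
{
    Vec128 x{1, 2, 3, 4};
    Vec128 y{1, 2, 3, 4};
    Vec128 z{9, 2, 9, 9};

    // Prints: 1 0 1
    printf("%d %d %d\n", FoldOpEquality(x, y), FoldOpEquality(x, z), FoldEqualsAny(x, z));
    return 0;
}
```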

default:
{
break;
@@ -31452,6 +31478,48 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
}
#endif

case NI_Vector128_op_Equality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Equality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Equality:
case NI_Vector512_op_Equality:
#endif // !TARGET_ARM64 && !TARGET_XARCH
{
if (varTypeIsFloating(simdBaseType))
Member: How is this path different from the case NI_Vector128_op_Equality above? If it's only for floats, why isn't it an assert?

Member: Ah, is it for when only one of the operands is constant?

Member Author: Right, it's for when one operand is all NaN, which is an optimization we can do for float/double.

{
// Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types
if (cnsNode->IsVectorNaN(simdBaseType))
{
resultNode = gtNewIconNode(0, retType);
resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT);
break;
}
}
break;
}

case NI_Vector128_op_Inequality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Inequality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Inequality:
case NI_Vector512_op_Inequality:
#endif // !TARGET_ARM64 && !TARGET_XARCH
{
if (varTypeIsFloating(simdBaseType))
{
// Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types
if (cnsNode->IsVectorNaN(simdBaseType))
{
resultNode = gtNewIconNode(1, retType);
resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT);
break;
}
}
break;
}
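A quick standalone illustration (plain C++, not JIT code) of the IEEE 754 behavior these two NaN special cases rely on: NaN never compares equal to anything, so `v == allNaN` is always false and `v != allNaN` is always true regardless of `v`; the non-constant operand is kept only for its side effects via gtWrapWithSideEffects.

```cpp
#include <cstdio>
#include <limits>

int main()
{
    const float nan = std::numeric_limits<float>::quiet_NaN();
    const float samples[] = {0.0f, -0.0f, 1.5f, nan};

    for (float x : samples)
    {
        // Prints eq=0 ne=1 for every x, matching the folded constants 0 and 1.
        printf("x=%f  eq=%d  ne=%d\n", (double)x, (int)(x == nan), (int)(x != nan));
    }
    return 0;
}
```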

default:
{
break;
185 changes: 97 additions & 88 deletions src/coreclr/jit/lowerxarch.cpp
@@ -2488,7 +2488,7 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cmpOp)
CorInfoType maskBaseJitType = simdBaseJitType;
var_types maskBaseType = simdBaseType;

if (op1Msk->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector))
if (op1Msk->OperIsConvertMaskToVector())
{
GenTreeHWIntrinsic* cvtMaskToVector = op1Msk->AsHWIntrinsic();

@@ -2499,122 +2499,131 @@
maskBaseType = cvtMaskToVector->GetSimdBaseType();
}

if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && op2->IsVectorZero() &&
comp->compOpportunisticallyDependsOn(InstructionSet_SSE41) && !varTypeIsMask(op1Msk))
if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && !varTypeIsMask(op1Msk))
{
// On SSE4.1 or higher we can optimize comparisons against zero to
// just use PTEST. We can't support it for floating-point, however,
// as it has both +0.0 and -0.0 where +0.0 == -0.0
bool isOp2VectorZero = op2->IsVectorZero();

bool skipReplaceOperands = false;

if (op1->OperIsHWIntrinsic())
if ((isOp2VectorZero || op2->IsVectorAllBitsSet()) &&
comp->compOpportunisticallyDependsOn(InstructionSet_SSE41))
{
GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic();
NamedIntrinsic op1IntrinsicId = op1Intrinsic->GetHWIntrinsicId();
// On SSE4.1 or higher we can optimize comparisons against Zero or AllBitsSet to
// just use PTEST. We can't support it for floating-point, however, as it has
// both +0.0 and -0.0 where +0.0 == -0.0

GenTree* nestedOp1 = nullptr;
GenTree* nestedOp2 = nullptr;
bool isEmbeddedBroadcast = false;
bool skipReplaceOperands = false;

if (op1Intrinsic->GetOperandCount() == 2)
if (!isOp2VectorZero)
{
nestedOp1 = op1Intrinsic->Op(1);
nestedOp2 = op1Intrinsic->Op(2);
// We can optimize to TestC(op1, allbitsset)
//
// This works out because TestC sets CF if (~x & y) == 0, so:
// ~00 & 11 = 11; 11 & 11 = 11; NC
// ~01 & 11 = 01; 10 & 11 = 10; NC
// ~10 & 11 = 10; 01 & 11 = 01; NC
// ~11 & 11 = 11; 00 & 11 = 00; C

assert(!nestedOp1->isContained());
isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic();
}
assert(op2->IsVectorAllBitsSet());
cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;

switch (op1IntrinsicId)
skipReplaceOperands = true;
}
else if (op1->OperIsHWIntrinsic())
{
case NI_SSE_And:
case NI_SSE2_And:
case NI_AVX_And:
case NI_AVX2_And:
assert(op2->IsVectorZero());

GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic();

if (op1Intrinsic->GetOperandCount() == 2)
{
// We can optimize to TestZ(op1.op1, op1.op2)
GenTree* nestedOp1 = op1Intrinsic->Op(1);
GenTree* nestedOp2 = op1Intrinsic->Op(2);

assert(!nestedOp1->isContained());
bool isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic();

if (isEmbeddedBroadcast)
bool isScalar = false;
genTreeOps oper = op1Intrinsic->GetOperForHWIntrinsicId(&isScalar);

switch (oper)
{
// PTEST doesn't support embedded broadcast
break;
}
case GT_AND:
{
// We can optimize to TestZ(op1.op1, op1.op2)

node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;
if (isEmbeddedBroadcast)
{
// PTEST doesn't support embedded broadcast
break;
}

BlockRange().Remove(op1);
BlockRange().Remove(op2);
node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;

skipReplaceOperands = true;
break;
}
BlockRange().Remove(op1);
BlockRange().Remove(op2);

case NI_SSE_AndNot:
case NI_SSE2_AndNot:
case NI_AVX_AndNot:
case NI_AVX2_AndNot:
{
// We can optimize to TestC(op1.op1, op1.op2)
skipReplaceOperands = true;
break;
}

if (isEmbeddedBroadcast)
{
// PTEST doesn't support embedded broadcast
break;
}
case GT_AND_NOT:
{
// We can optimize to TestC(op1.op1, op1.op2)

if (isEmbeddedBroadcast)
{
// PTEST doesn't support embedded broadcast
break;
}

cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;
cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;

node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;
node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;

BlockRange().Remove(op1);
BlockRange().Remove(op2);
BlockRange().Remove(op1);
BlockRange().Remove(op2);

skipReplaceOperands = true;
break;
}
skipReplaceOperands = true;
break;
}

default:
{
break;
default:
{
break;
}
}
}
}
}

if (!skipReplaceOperands)
{
// Default handler, emit a TestZ(op1, op1)

node->Op(1) = op1;
BlockRange().Remove(op2);
if (!skipReplaceOperands)
{
// Default handler, emit a TestZ(op1, op1)
assert(op2->IsVectorZero());

LIR::Use op1Use(BlockRange(), &node->Op(1), node);
ReplaceWithLclVar(op1Use);
op1 = node->Op(1);
node->Op(1) = op1;
BlockRange().Remove(op2);

op2 = comp->gtClone(op1);
BlockRange().InsertAfter(op1, op2);
node->Op(2) = op2;
}
LIR::Use op1Use(BlockRange(), &node->Op(1), node);
ReplaceWithLclVar(op1Use);
op1 = node->Op(1);

if (simdSize == 32)
{
// TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed?
node->ChangeHWIntrinsicId(NI_AVX_TestZ);
LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd);
}
else
{
assert(simdSize == 16);
op2 = comp->gtClone(op1);
BlockRange().InsertAfter(op1, op2);
node->Op(2) = op2;
}

// TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed?
node->ChangeHWIntrinsicId(NI_SSE41_TestZ);
LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd);
if (simdSize == 32)
{
LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd);
}
else
{
assert(simdSize == 16);
LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd);
}
return LowerNode(node);
}

return LowerNode(node);
}
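For reference, a minimal standalone sketch (plain C++ with SSE4.1 intrinsics, compiled with -msse4.1; not JIT code) of the PTEST semantics the lowering above maps onto: PTEST sets ZF when (a & b) == 0 and CF when (~a & b) == 0, which is why `x == Zero` becomes TestZ, `x == AllBitsSet` becomes TestC against all-bits-set, and `(a & b) == Zero` becomes TestZ(a, b).

```cpp
#include <immintrin.h>
#include <cstdio>

int main()
{
    __m128i zero = _mm_setzero_si128();
    __m128i ones = _mm_set1_epi32(-1);
    __m128i x    = _mm_set_epi32(0, 0, 0, 7);

    // x == Zero        => TestZ(x, x): ZF is set only when every bit of x is 0.
    printf("x == Zero       : %d\n", _mm_testz_si128(x, x));       // 0
    printf("zero == Zero    : %d\n", _mm_testz_si128(zero, zero)); // 1

    // x == AllBitsSet  => TestC(x, allbitsset): CF is set only when (~x & ones) == 0,
    // i.e. every bit of x is 1 (the truth table in the comment above).
    printf("x == AllBits    : %d\n", _mm_testc_si128(x, ones));    // 0
    printf("ones == AllBits : %d\n", _mm_testc_si128(ones, ones)); // 1

    // (a & b) == Zero  => TestZ(a, b), the GT_AND case in the switch.
    __m128i a = _mm_set_epi32(1, 2, 4, 8);
    __m128i b = _mm_set_epi32(2, 1, 8, 4);
    printf("(a & b) == Zero : %d\n", _mm_testz_si128(a, b));       // 1
    return 0;
}
```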

// TODO-XARCH-AVX512: We should handle TYP_SIMD12 here under the EVEX path, but doing
@@ -3490,7 +3499,7 @@ GenTree* Lowering::LowerHWIntrinsicTernaryLogic(GenTreeHWIntrinsic* node)
}
}

if (condition->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector))
if (condition->OperIsConvertMaskToVector())
{
GenTree* tmp = condition->AsHWIntrinsic()->Op(1);
BlockRange().Remove(condition);