Skip to content

Commit

Permalink
Improve the handling of SIMD comparisons (#104944)
Browse files Browse the repository at this point in the history
* Ensure that we can constant fold op_Equality and op_Inequality for SIMD

* Optimize comparisons against AllBitsSet on pre-AVX512 hardware
  • Loading branch information
tannergooding authored Jul 19, 2024
1 parent 39968e7 commit e0ecd1f
Show file tree
Hide file tree
Showing 3 changed files with 397 additions and 88 deletions.
68 changes: 68 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30940,6 +30940,32 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
}
#endif

case NI_Vector128_op_Equality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Equality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Equality:
case NI_Vector512_op_Equality:
#endif // !TARGET_ARM64 && !TARGET_XARCH
{
cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_EQ, isScalar, simdBaseType, otherNode->AsVecCon());
resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsAllBitsSet() ? 1 : 0, retType);
break;
}

case NI_Vector128_op_Inequality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Inequality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Inequality:
case NI_Vector512_op_Inequality:
#endif // !TARGET_ARM64 && !TARGET_XARCH
{
cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_NE, isScalar, simdBaseType, otherNode->AsVecCon());
resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsZero() ? 0 : 1, retType);
break;
}

default:
{
break;
Expand Down Expand Up @@ -31380,6 +31406,48 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
}
#endif

case NI_Vector128_op_Equality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Equality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Equality:
case NI_Vector512_op_Equality:
#endif // !TARGET_ARM64 && !TARGET_XARCH
{
if (varTypeIsFloating(simdBaseType))
{
// Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types
if (cnsNode->IsVectorNaN(simdBaseType))
{
resultNode = gtNewIconNode(0, retType);
resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT);
break;
}
}
break;
}

case NI_Vector128_op_Inequality:
#if defined(TARGET_ARM64)
case NI_Vector64_op_Inequality:
#elif defined(TARGET_XARCH)
case NI_Vector256_op_Inequality:
case NI_Vector512_op_Inequality:
#endif // !TARGET_ARM64 && !TARGET_XARCH
{
if (varTypeIsFloating(simdBaseType))
{
// Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types
if (cnsNode->IsVectorNaN(simdBaseType))
{
resultNode = gtNewIconNode(1, retType);
resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT);
break;
}
}
break;
}

default:
{
break;
Expand Down
185 changes: 97 additions & 88 deletions src/coreclr/jit/lowerxarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2577,7 +2577,7 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm
CorInfoType maskBaseJitType = simdBaseJitType;
var_types maskBaseType = simdBaseType;

if (op1Msk->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector))
if (op1Msk->OperIsConvertMaskToVector())
{
GenTreeHWIntrinsic* cvtMaskToVector = op1Msk->AsHWIntrinsic();

Expand All @@ -2588,122 +2588,131 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm
maskBaseType = cvtMaskToVector->GetSimdBaseType();
}

if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && op2->IsVectorZero() &&
comp->compOpportunisticallyDependsOn(InstructionSet_SSE41) && !varTypeIsMask(op1Msk))
if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && !varTypeIsMask(op1Msk))
{
// On SSE4.1 or higher we can optimize comparisons against zero to
// just use PTEST. We can't support it for floating-point, however,
// as it has both +0.0 and -0.0 where +0.0 == -0.0
bool isOp2VectorZero = op2->IsVectorZero();

bool skipReplaceOperands = false;

if (op1->OperIsHWIntrinsic())
if ((isOp2VectorZero || op2->IsVectorAllBitsSet()) &&
comp->compOpportunisticallyDependsOn(InstructionSet_SSE41))
{
GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic();
NamedIntrinsic op1IntrinsicId = op1Intrinsic->GetHWIntrinsicId();
// On SSE4.1 or higher we can optimize comparisons against Zero or AllBitsSet to
// just use PTEST. We can't support it for floating-point, however, as it has
// both +0.0 and -0.0 where +0.0 == -0.0

GenTree* nestedOp1 = nullptr;
GenTree* nestedOp2 = nullptr;
bool isEmbeddedBroadcast = false;
bool skipReplaceOperands = false;

if (op1Intrinsic->GetOperandCount() == 2)
if (!isOp2VectorZero)
{
nestedOp1 = op1Intrinsic->Op(1);
nestedOp2 = op1Intrinsic->Op(2);
// We can optimize to TestC(op1, allbitsset)
//
// This works out because TestC sets CF if (~x & y) == 0, so with y = 11 (AllBitsSet):
//   x = 00: ~x = 11; 11 & 11 = 11; NC
//   x = 01: ~x = 10; 10 & 11 = 10; NC
//   x = 10: ~x = 01; 01 & 11 = 01; NC
//   x = 11: ~x = 00; 00 & 11 = 00; C

assert(!nestedOp1->isContained());
isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic();
}
assert(op2->IsVectorAllBitsSet());
cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;

switch (op1IntrinsicId)
skipReplaceOperands = true;
}
else if (op1->OperIsHWIntrinsic())
{
case NI_SSE_And:
case NI_SSE2_And:
case NI_AVX_And:
case NI_AVX2_And:
assert(op2->IsVectorZero());

GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic();

if (op1Intrinsic->GetOperandCount() == 2)
{
// We can optimize to TestZ(op1.op1, op1.op2)
GenTree* nestedOp1 = op1Intrinsic->Op(1);
GenTree* nestedOp2 = op1Intrinsic->Op(2);

assert(!nestedOp1->isContained());
bool isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic();

if (isEmbeddedBroadcast)
bool isScalar = false;
genTreeOps oper = op1Intrinsic->GetOperForHWIntrinsicId(&isScalar);

switch (oper)
{
// PTEST doesn't support embedded broadcast
break;
}
case GT_AND:
{
// We can optimize to TestZ(op1.op1, op1.op2)

node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;
if (isEmbeddedBroadcast)
{
// PTEST doesn't support embedded broadcast
break;
}

BlockRange().Remove(op1);
BlockRange().Remove(op2);
node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;

skipReplaceOperands = true;
break;
}
BlockRange().Remove(op1);
BlockRange().Remove(op2);

case NI_SSE_AndNot:
case NI_SSE2_AndNot:
case NI_AVX_AndNot:
case NI_AVX2_AndNot:
{
// We can optimize to TestC(op1.op1, op1.op2)
skipReplaceOperands = true;
break;
}

if (isEmbeddedBroadcast)
{
// PTEST doesn't support embedded broadcast
break;
}
case GT_AND_NOT:
{
// We can optimize to TestC(op1.op1, op1.op2)

if (isEmbeddedBroadcast)
{
// PTEST doesn't support embedded broadcast
break;
}

cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;
cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC;

node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;
node->Op(1) = nestedOp1;
node->Op(2) = nestedOp2;

BlockRange().Remove(op1);
BlockRange().Remove(op2);
BlockRange().Remove(op1);
BlockRange().Remove(op2);

skipReplaceOperands = true;
break;
}
skipReplaceOperands = true;
break;
}

default:
{
break;
default:
{
break;
}
}
}
}
}

if (!skipReplaceOperands)
{
// Default handler, emit a TestZ(op1, op1)

node->Op(1) = op1;
BlockRange().Remove(op2);
if (!skipReplaceOperands)
{
// Default handler, emit a TestZ(op1, op1)
assert(op2->IsVectorZero());

LIR::Use op1Use(BlockRange(), &node->Op(1), node);
ReplaceWithLclVar(op1Use);
op1 = node->Op(1);
node->Op(1) = op1;
BlockRange().Remove(op2);

op2 = comp->gtClone(op1);
BlockRange().InsertAfter(op1, op2);
node->Op(2) = op2;
}
LIR::Use op1Use(BlockRange(), &node->Op(1), node);
ReplaceWithLclVar(op1Use);
op1 = node->Op(1);

if (simdSize == 32)
{
// TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed?
node->ChangeHWIntrinsicId(NI_AVX_TestZ);
LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd);
}
else
{
assert(simdSize == 16);
op2 = comp->gtClone(op1);
BlockRange().InsertAfter(op1, op2);
node->Op(2) = op2;
}

// TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed?
node->ChangeHWIntrinsicId(NI_SSE41_TestZ);
LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd);
if (simdSize == 32)
{
LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd);
}
else
{
assert(simdSize == 16);
LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd);
}
return LowerNode(node);
}

return LowerNode(node);
}

// TODO-XARCH-AVX512: We should handle TYP_SIMD12 here under the EVEX path, but doing
Expand Down Expand Up @@ -3579,7 +3588,7 @@ GenTree* Lowering::LowerHWIntrinsicTernaryLogic(GenTreeHWIntrinsic* node)
}
}

if (condition->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector))
if (condition->OperIsConvertMaskToVector())
{
GenTree* tmp = condition->AsHWIntrinsic()->Op(1);
BlockRange().Remove(condition);
Expand Down
Loading

0 comments on commit e0ecd1f

Please sign in to comment.