From e0ecd1f9f980d90d937dafb186b0bef90e6317bd Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Fri, 19 Jul 2024 11:55:27 -0700 Subject: [PATCH] Improve the handling of SIMD comparisons (#104944) * Ensure that we can constant fold op_Equality and op_Inequality for SIMD * Optimize comparisons against AllBitsSet on pre-AVX512 hardware --- src/coreclr/jit/gentree.cpp | 68 ++++++++++ src/coreclr/jit/lowerxarch.cpp | 185 +++++++++++++------------- src/coreclr/jit/valuenum.cpp | 232 +++++++++++++++++++++++++++++++++ 3 files changed, 397 insertions(+), 88 deletions(-) diff --git a/src/coreclr/jit/gentree.cpp b/src/coreclr/jit/gentree.cpp index 759dde9b584fb..9b08a8fea32c5 100644 --- a/src/coreclr/jit/gentree.cpp +++ b/src/coreclr/jit/gentree.cpp @@ -30940,6 +30940,32 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree) } #endif + case NI_Vector128_op_Equality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Equality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Equality: + case NI_Vector512_op_Equality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_EQ, isScalar, simdBaseType, otherNode->AsVecCon()); + resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsAllBitsSet() ? 1 : 0, retType); + break; + } + + case NI_Vector128_op_Inequality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Inequality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Inequality: + case NI_Vector512_op_Inequality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + cnsNode->AsVecCon()->EvaluateBinaryInPlace(GT_NE, isScalar, simdBaseType, otherNode->AsVecCon()); + resultNode = gtNewIconNode(cnsNode->AsVecCon()->IsZero() ? 0 : 1, retType); + break; + } + default: { break; @@ -31380,6 +31406,48 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree) } #endif + case NI_Vector128_op_Equality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Equality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Equality: + case NI_Vector512_op_Equality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + if (varTypeIsFloating(simdBaseType)) + { + // Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types + if (cnsNode->IsVectorNaN(simdBaseType)) + { + resultNode = gtNewIconNode(0, retType); + resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT); + break; + } + } + break; + } + + case NI_Vector128_op_Inequality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Inequality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Inequality: + case NI_Vector512_op_Inequality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + if (varTypeIsFloating(simdBaseType)) + { + // Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types + if (cnsNode->IsVectorNaN(simdBaseType)) + { + resultNode = gtNewIconNode(1, retType); + resultNode = gtWrapWithSideEffects(resultNode, otherNode, GTF_ALL_EFFECT); + break; + } + } + break; + } + default: { break; diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index de78eaa8bfbf6..db997aa3f426c 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -2577,7 +2577,7 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm CorInfoType maskBaseJitType = simdBaseJitType; var_types maskBaseType = simdBaseType; - if (op1Msk->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector)) + if (op1Msk->OperIsConvertMaskToVector()) { GenTreeHWIntrinsic* cvtMaskToVector = op1Msk->AsHWIntrinsic(); @@ -2588,122 +2588,131 @@ GenTree* Lowering::LowerHWIntrinsicCmpOp(GenTreeHWIntrinsic* node, genTreeOps cm maskBaseType = cvtMaskToVector->GetSimdBaseType(); } - if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && op2->IsVectorZero() && - comp->compOpportunisticallyDependsOn(InstructionSet_SSE41) && !varTypeIsMask(op1Msk)) + if (!varTypeIsFloating(simdBaseType) && (simdSize != 64) && !varTypeIsMask(op1Msk)) { - // On SSE4.1 or higher we can optimize comparisons against zero to - // just use PTEST. We can't support it for floating-point, however, - // as it has both +0.0 and -0.0 where +0.0 == -0.0 + bool isOp2VectorZero = op2->IsVectorZero(); - bool skipReplaceOperands = false; - - if (op1->OperIsHWIntrinsic()) + if ((isOp2VectorZero || op2->IsVectorAllBitsSet()) && + comp->compOpportunisticallyDependsOn(InstructionSet_SSE41)) { - GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic(); - NamedIntrinsic op1IntrinsicId = op1Intrinsic->GetHWIntrinsicId(); + // On SSE4.1 or higher we can optimize comparisons against Zero or AllBitsSet to + // just use PTEST. We can't support it for floating-point, however, as it has + // both +0.0 and -0.0 where +0.0 == -0.0 - GenTree* nestedOp1 = nullptr; - GenTree* nestedOp2 = nullptr; - bool isEmbeddedBroadcast = false; + bool skipReplaceOperands = false; - if (op1Intrinsic->GetOperandCount() == 2) + if (!isOp2VectorZero) { - nestedOp1 = op1Intrinsic->Op(1); - nestedOp2 = op1Intrinsic->Op(2); + // We can optimize to TestC(op1, allbitsset) + // + // This works out because TestC sets CF if (~x & y) == 0, so: + // ~00 & 11 = 11; 11 & 11 = 11; NC + // ~01 & 11 = 01; 10 & 11 = 10; NC + // ~10 & 11 = 10; 01 & 11 = 01; NC + // ~11 & 11 = 11; 00 & 11 = 00; C - assert(!nestedOp1->isContained()); - isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic(); - } + assert(op2->IsVectorAllBitsSet()); + cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC; - switch (op1IntrinsicId) + skipReplaceOperands = true; + } + else if (op1->OperIsHWIntrinsic()) { - case NI_SSE_And: - case NI_SSE2_And: - case NI_AVX_And: - case NI_AVX2_And: + assert(op2->IsVectorZero()); + + GenTreeHWIntrinsic* op1Intrinsic = op1->AsHWIntrinsic(); + + if (op1Intrinsic->GetOperandCount() == 2) { - // We can optimize to TestZ(op1.op1, op1.op2) + GenTree* nestedOp1 = op1Intrinsic->Op(1); + GenTree* nestedOp2 = op1Intrinsic->Op(2); + + assert(!nestedOp1->isContained()); + bool isEmbeddedBroadcast = nestedOp2->isContained() && nestedOp2->OperIsHWIntrinsic(); - if (isEmbeddedBroadcast) + bool isScalar = false; + genTreeOps oper = op1Intrinsic->GetOperForHWIntrinsicId(&isScalar); + + switch (oper) { - // PTEST doesn't support embedded broadcast - break; - } + case GT_AND: + { + // We can optimize to TestZ(op1.op1, op1.op2) - node->Op(1) = nestedOp1; - node->Op(2) = nestedOp2; + if (isEmbeddedBroadcast) + { + // PTEST doesn't support embedded broadcast + break; + } - BlockRange().Remove(op1); - BlockRange().Remove(op2); + node->Op(1) = nestedOp1; + node->Op(2) = nestedOp2; - skipReplaceOperands = true; - break; - } + BlockRange().Remove(op1); + BlockRange().Remove(op2); - case NI_SSE_AndNot: - case NI_SSE2_AndNot: - case NI_AVX_AndNot: - case NI_AVX2_AndNot: - { - // We can optimize to TestC(op1.op1, op1.op2) + skipReplaceOperands = true; + break; + } - if (isEmbeddedBroadcast) - { - // PTEST doesn't support embedded broadcast - break; - } + case GT_AND_NOT: + { + // We can optimize to TestC(op1.op1, op1.op2) + + if (isEmbeddedBroadcast) + { + // PTEST doesn't support embedded broadcast + break; + } - cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC; + cmpCnd = (cmpOp == GT_EQ) ? GenCondition::C : GenCondition::NC; - node->Op(1) = nestedOp1; - node->Op(2) = nestedOp2; + node->Op(1) = nestedOp1; + node->Op(2) = nestedOp2; - BlockRange().Remove(op1); - BlockRange().Remove(op2); + BlockRange().Remove(op1); + BlockRange().Remove(op2); - skipReplaceOperands = true; - break; - } + skipReplaceOperands = true; + break; + } - default: - { - break; + default: + { + break; + } + } } } - } - - if (!skipReplaceOperands) - { - // Default handler, emit a TestZ(op1, op1) - node->Op(1) = op1; - BlockRange().Remove(op2); + if (!skipReplaceOperands) + { + // Default handler, emit a TestZ(op1, op1) + assert(op2->IsVectorZero()); - LIR::Use op1Use(BlockRange(), &node->Op(1), node); - ReplaceWithLclVar(op1Use); - op1 = node->Op(1); + node->Op(1) = op1; + BlockRange().Remove(op2); - op2 = comp->gtClone(op1); - BlockRange().InsertAfter(op1, op2); - node->Op(2) = op2; - } + LIR::Use op1Use(BlockRange(), &node->Op(1), node); + ReplaceWithLclVar(op1Use); + op1 = node->Op(1); - if (simdSize == 32) - { - // TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed? - node->ChangeHWIntrinsicId(NI_AVX_TestZ); - LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd); - } - else - { - assert(simdSize == 16); + op2 = comp->gtClone(op1); + BlockRange().InsertAfter(op1, op2); + node->Op(2) = op2; + } - // TODO-Review: LowerHWIntrinsicCC resets the id again, so why is this needed? - node->ChangeHWIntrinsicId(NI_SSE41_TestZ); - LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd); + if (simdSize == 32) + { + LowerHWIntrinsicCC(node, NI_AVX_PTEST, cmpCnd); + } + else + { + assert(simdSize == 16); + LowerHWIntrinsicCC(node, NI_SSE41_PTEST, cmpCnd); + } + return LowerNode(node); } - - return LowerNode(node); } // TODO-XARCH-AVX512: We should handle TYP_SIMD12 here under the EVEX path, but doing @@ -3579,7 +3588,7 @@ GenTree* Lowering::LowerHWIntrinsicTernaryLogic(GenTreeHWIntrinsic* node) } } - if (condition->OperIsHWIntrinsic(NI_EVEX_ConvertMaskToVector)) + if (condition->OperIsConvertMaskToVector()) { GenTree* tmp = condition->AsHWIntrinsic()->Op(1); BlockRange().Remove(condition); diff --git a/src/coreclr/jit/valuenum.cpp b/src/coreclr/jit/valuenum.cpp index f2c4d180cb054..4d5ad347744a0 100644 --- a/src/coreclr/jit/valuenum.cpp +++ b/src/coreclr/jit/valuenum.cpp @@ -8362,6 +8362,113 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, break; } + case GT_EQ: + { + if (varTypeIsFloating(baseType)) + { + // Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNZeroForType(type); + } + } + break; + } + + case GT_GT: + { + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (varTypeIsUnsigned(baseType)) + { + // Handle `(0 > x) == false` for unsigned types. + if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType))) + { + return VNZeroForType(type); + } + } + else if (varTypeIsFloating(baseType)) + { + // Handle `(x > NaN) == false` and `(NaN > x) == false` for floating-point types + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNZeroForType(type); + } + } + break; + } + + case GT_GE: + { + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (varTypeIsUnsigned(baseType)) + { + // Handle `x >= 0 == true` for unsigned types. + if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType))) + { + return VNAllBitsForType(type); + } + } + else if (varTypeIsFloating(baseType)) + { + // Handle `(x >= NaN) == false` and `(NaN >= x) == false` for floating-point types + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNZeroForType(type); + } + } + break; + } + + case GT_LT: + { + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (varTypeIsUnsigned(baseType)) + { + // Handle `x < 0 == false` for unsigned types. + if ((cnsVN == arg1VN) && (cnsVN == VNZeroForType(simdType))) + { + return VNZeroForType(type); + } + } + else if (varTypeIsFloating(baseType)) + { + // Handle `(x < NaN) == false` and `(NaN < x) == false` for floating-point types + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNZeroForType(type); + } + } + break; + } + + case GT_LE: + { + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (varTypeIsUnsigned(baseType)) + { + // Handle `0 <= x == true` for unsigned types. + if ((cnsVN == arg0VN) && (cnsVN == VNZeroForType(simdType))) + { + return VNAllBitsForType(type); + } + } + else if (varTypeIsFloating(baseType)) + { + // Handle `(x <= NaN) == false` and `(NaN <= x) == false` for floating-point types + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNZeroForType(type); + } + } + break; + } + case GT_MUL: { if (!varTypeIsFloating(baseType)) @@ -8409,6 +8516,21 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, break; } + case GT_NE: + { + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (varTypeIsFloating(baseType)) + { + // Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNAllBitsForType(type); + } + } + break; + } + case GT_OR: { // Handle `x | 0 == x` and `0 | x == x` @@ -8576,6 +8698,48 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, } #endif + case NI_Vector128_op_Equality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Equality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Equality: + case NI_Vector512_op_Equality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + if (varTypeIsFloating(baseType)) + { + // Handle `(x == NaN) == false` and `(NaN == x) == false` for floating-point types + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNZeroForType(type); + } + } + break; + } + + case NI_Vector128_op_Inequality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Inequality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Inequality: + case NI_Vector512_op_Inequality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + if (varTypeIsFloating(baseType)) + { + // Handle `(x != NaN) == true` and `(NaN != x) == true` for floating-point types + var_types simdType = Compiler::getSIMDTypeForSize(simdSize); + + if (VNIsVectorNaN(simdType, baseType, cnsVN)) + { + return VNOneForType(type); + } + } + break; + } + default: { break; @@ -8604,6 +8768,32 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, return arg0VN; } + case GT_EQ: + case GT_GE: + case GT_LE: + { + // We can't handle floating-point due to NaN + + if (varTypeIsIntegral(baseType)) + { + return VNAllBitsForType(type); + } + break; + } + + case GT_GT: + case GT_LT: + case GT_NE: + { + // We can't handle floating-point due to NaN + + if (varTypeIsIntegral(baseType)) + { + return VNZeroForType(type); + } + break; + } + case GT_OR: { // Handle `x | x == x` @@ -8631,6 +8821,48 @@ ValueNum ValueNumStore::EvalHWIntrinsicFunBinary(GenTreeHWIntrinsic* tree, default: break; } + + switch (ni) + { + case NI_Vector128_op_Equality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Equality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Equality: + case NI_Vector512_op_Equality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + // We can't handle floating-point due to NaN + + if (varTypeIsIntegral(baseType)) + { + return VNOneForType(type); + } + break; + } + + case NI_Vector128_op_Inequality: +#if defined(TARGET_ARM64) + case NI_Vector64_op_Inequality: +#elif defined(TARGET_XARCH) + case NI_Vector256_op_Inequality: + case NI_Vector512_op_Inequality: +#endif // !TARGET_ARM64 && !TARGET_XARCH + { + // We can't handle floating-point due to NaN + + if (varTypeIsIntegral(baseType)) + { + return VNZeroForType(type); + } + break; + } + + default: + { + break; + } + } } if (encodeResultType)