Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand String.EndsWith/MemoryExtensions.EndsWith in JIT #98593

Merged
merged 6 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4423,12 +4423,18 @@ class Compiler
Eq, // (d1 == cns1) && (s2 == cns2)
Xor, // (d1 ^ cns1) | (s2 ^ cns2)
};
GenTree* impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO* sig, unsigned methodFlags);
GenTree* impSpanEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO* sig, unsigned methodFlags);
enum StringComparisonKind
EgorBo marked this conversation as resolved.
Show resolved Hide resolved
{
Equals,
StartsWith,
EndsWith
};
GenTree* impUtf16StringComparison(StringComparisonKind kind, CORINFO_SIG_INFO* sig, unsigned methodFlags);
GenTree* impUtf16SpanComparison(StringComparisonKind kind, CORINFO_SIG_INFO* sig, unsigned methodFlags);
GenTree* impExpandHalfConstEquals(GenTreeLclVarCommon* data,
GenTree* lengthFld,
bool checkForNull,
bool startsWith,
StringComparisonKind kind,
WCHAR* cnsData,
int len,
int dataOffset,
Expand Down
28 changes: 24 additions & 4 deletions src/coreclr/jit/importercalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2827,26 +2827,38 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,

case NI_System_String_Equals:
{
retNode = impStringEqualsOrStartsWith(/*startsWith:*/ false, sig, methodFlags);
retNode = impUtf16StringComparison(Equals, sig, methodFlags);
break;
}

case NI_System_MemoryExtensions_Equals:
case NI_System_MemoryExtensions_SequenceEqual:
{
retNode = impSpanEqualsOrStartsWith(/*startsWith:*/ false, sig, methodFlags);
retNode = impUtf16SpanComparison(Equals, sig, methodFlags);
break;
}

case NI_System_String_StartsWith:
{
retNode = impStringEqualsOrStartsWith(/*startsWith:*/ true, sig, methodFlags);
retNode = impUtf16StringComparison(StartsWith, sig, methodFlags);
break;
}

case NI_System_String_EndsWith:
{
retNode = impUtf16StringComparison(EndsWith, sig, methodFlags);
break;
}

case NI_System_MemoryExtensions_StartsWith:
{
retNode = impSpanEqualsOrStartsWith(/*startsWith:*/ true, sig, methodFlags);
retNode = impUtf16SpanComparison(StartsWith, sig, methodFlags);
break;
}

case NI_System_MemoryExtensions_EndsWith:
{
retNode = impUtf16SpanComparison(EndsWith, sig, methodFlags);
break;
}

Expand Down Expand Up @@ -8932,6 +8944,10 @@ NamedIntrinsic Compiler::lookupNamedIntrinsic(CORINFO_METHOD_HANDLE method)
{
result = NI_System_MemoryExtensions_StartsWith;
}
else if (strcmp(methodName, "EndsWith") == 0)
{
result = NI_System_MemoryExtensions_EndsWith;
}
}
break;
}
Expand Down Expand Up @@ -9032,6 +9048,10 @@ NamedIntrinsic Compiler::lookupNamedIntrinsic(CORINFO_METHOD_HANDLE method)
{
result = NI_System_String_StartsWith;
}
else if (strcmp(methodName, "EndsWith") == 0)
{
result = NI_System_String_EndsWith;
}
}
break;
}
Expand Down
93 changes: 66 additions & 27 deletions src/coreclr/jit/importervectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
// 8) MemoryExtensions.StartsWith<char>(ROS<char>, ROS<char>)
// 9) MemoryExtensions.StartsWith(ROS<char>, ROS<char>, Ordinal or OrdinalIgnoreCase)
//
// 10) str.EndsWith(string, Ordinal or OrdinalIgnoreCase)
// 11) MemoryExtensions.EndsWith<char>(ROS<char>, ROS<char>)
// 12) MemoryExtensions.EndsWith(ROS<char>, ROS<char>, Ordinal or OrdinalIgnoreCase)
//
// When one of the arguments is a constant string of a [0..32] size, we can inline
// a vectorized comparison against it using SWAR or SIMD techniques (e.g. via two V256 vectors)
//
Expand Down Expand Up @@ -426,7 +430,7 @@ GenTree* Compiler::impExpandHalfConstEqualsSWAR(
// data - Pointer (LCL_VAR) to a data to vectorize
// lengthFld - Pointer (LCL_VAR or GT_IND) to Length field
// checkForNull - Check data for null
// startsWith - Is it StartsWith or Equals?
// kind - Is it StartsWith, Equals or EndsWith?
// cns - Constant data (array of 2-byte chars)
// len - Number of 2-byte chars in the cns
// dataOffset - Offset for data
Expand All @@ -439,7 +443,7 @@ GenTree* Compiler::impExpandHalfConstEqualsSWAR(
GenTree* Compiler::impExpandHalfConstEquals(GenTreeLclVarCommon* data,
GenTree* lengthFld,
bool checkForNull,
bool startsWith,
StringComparisonKind kind,
WCHAR* cnsData,
int len,
int dataOffset,
Expand All @@ -454,30 +458,41 @@ GenTree* Compiler::impExpandHalfConstEquals(GenTreeLclVarCommon* data,
return nullptr;
}

const genTreeOps cmpOp = startsWith ? GT_GE : GT_EQ;
const genTreeOps cmpOp = kind == Equals ? GT_EQ : GT_GE;
GenTree* elementsCount = gtNewIconNode(len);
GenTree* lenCheckNode;
if (len == 0)
{
// For zero length we don't need to compare content, the following expression is enough:
//
// varData != null && lengthFld == 0
// varData != null && lengthFld cmpOp 0
//
lenCheckNode = gtNewOperNode(cmpOp, TYP_INT, lengthFld, elementsCount);
}
else
{
assert(cnsData != nullptr);

GenTreeLclVarCommon* dataAddr = gtClone(data)->AsLclVarCommon();

if (kind == EndsWith)
{
// For EndsWith we need to adjust dataAddr to point to the end of the string minus value's length
// We spawn a local that we're going to set below
unsigned dataTmp = lvaGrabTemp(true DEBUGARG("clonning data ptr"));
lvaTable[dataTmp].lvType = TYP_BYREF;
dataAddr = gtNewLclvNode(dataTmp, TYP_BYREF);
}

GenTree* indirCmp = nullptr;
if (len < 8) // SWAR impl supports len == 8 but we'd better give it to SIMD
{
indirCmp = impExpandHalfConstEqualsSWAR(gtClone(data)->AsLclVarCommon(), cnsData, len, dataOffset, cmpMode);
indirCmp = impExpandHalfConstEqualsSWAR(dataAddr, cnsData, len, dataOffset, cmpMode);
}
#if defined(FEATURE_HW_INTRINSICS)
else if (IsBaselineSimdIsaSupported())
{
indirCmp = impExpandHalfConstEqualsSIMD(gtClone(data)->AsLclVarCommon(), cnsData, len, dataOffset, cmpMode);
indirCmp = impExpandHalfConstEqualsSIMD(dataAddr, cnsData, len, dataOffset, cmpMode);
}
#endif

Expand All @@ -488,9 +503,24 @@ GenTree* Compiler::impExpandHalfConstEquals(GenTreeLclVarCommon* data,
}
assert(indirCmp->TypeIs(TYP_INT, TYP_UBYTE));

if (kind == EndsWith)
{
// len is expected to be small, so no overflow is possible
assert((len * 2) > len);
EgorBo marked this conversation as resolved.
Show resolved Hide resolved

// dataAddr = data + (length * 2 - len * 2)
GenTree* castedLen = gtNewCastNode(TYP_I_IMPL, gtCloneExpr(lengthFld), false, TYP_I_IMPL);
GenTree* byteLen = gtNewOperNode(GT_MUL, TYP_I_IMPL, castedLen, gtNewIconNode(2, TYP_I_IMPL));
GenTreeOp* cmpStart = gtNewOperNode(GT_ADD, TYP_BYREF, gtClone(data),
gtNewOperNode(GT_SUB, TYP_I_IMPL, byteLen,
gtNewIconNode((ssize_t)(len * 2), TYP_I_IMPL)));
GenTree* storeTmp = gtNewTempStore(dataAddr->GetLclNum(), cmpStart);
indirCmp = gtNewOperNode(GT_COMMA, indirCmp->TypeGet(), storeTmp, indirCmp);
}

GenTreeColon* lenCheckColon = gtNewColonNode(TYP_INT, indirCmp, gtNewFalse());

// For StartsWith we use GT_GE, e.g.: `x.Length >= 10`
// For StartsWith/EndsWith we use GT_GE, e.g.: `x.Length >= 10`
lenCheckNode = gtNewQmarkNode(TYP_INT, gtNewOperNode(cmpOp, TYP_INT, lengthFld, elementsCount), lenCheckColon);
}

Expand Down Expand Up @@ -556,7 +586,7 @@ GenTreeStrCon* Compiler::impGetStrConFromSpan(GenTree* span)
}

//------------------------------------------------------------------------
// impStringEqualsOrStartsWith: The main entry-point for String methods
// impUtf16StringComparison: The main entry-point for String methods
// We're going to unroll & vectorize the following cases:
// 1) String.Equals(obj, "cns")
// 2) String.Equals(obj, "cns", Ordinal or OrdinalIgnoreCase)
Expand All @@ -570,26 +600,29 @@ GenTreeStrCon* Compiler::impGetStrConFromSpan(GenTree* span)
// 9) obj.StartsWith("cns", Ordinal or OrdinalIgnoreCase)
// 10) "cns".StartsWith(obj, Ordinal or OrdinalIgnoreCase)
//
// 11) obj.EndsWith("cns", Ordinal or OrdinalIgnoreCase)
// 12) "cns".EndsWith(obj, Ordinal or OrdinalIgnoreCase)
//
// For cases 5, 6, 9 and 11 we don't emit "obj != null"
// NOTE: String.Equals(object) is not supported currently
//
// Arguments:
// startsWith - Is it StartsWith or Equals?
// sig - signature of StartsWith or Equals method
// kind - Is it StartsWith, EndsWith or Equals?
// sig - signature of StartsWith, EndsWith or Equals method
// methodFlags - its flags
//
// Returns:
// GenTree representing vectorized comparison or nullptr
//
GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO* sig, unsigned methodFlags)
GenTree* Compiler::impUtf16StringComparison(StringComparisonKind kind, CORINFO_SIG_INFO* sig, unsigned methodFlags)
{
const bool isStatic = methodFlags & CORINFO_FLG_STATIC;
const int argsCount = sig->numArgs + (isStatic ? 0 : 1);

// This optimization spawns several temps so make sure we have room
if (lvaHaveManyLocals(0.75))
{
JITDUMP("impSpanEqualsOrStartsWith: Method has too many locals - bail out.\n")
JITDUMP("impStringComparison: Method has too many locals - bail out.\n")
EgorBo marked this conversation as resolved.
Show resolved Hide resolved
return nullptr;
}

Expand Down Expand Up @@ -630,9 +663,9 @@ GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO
}
else
{
if (startsWith)
if (kind != Equals)
{
// StartsWith is not commutative
// StartsWith and EndsWith are not commutative
return nullptr;
}
cnsStr = op1->AsStrCon();
Expand All @@ -647,6 +680,7 @@ GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO
// obj.Equals("cns")
// obj.Equals("cns", Ordinal or OrdinalIgnoreCase)
// obj.StartsWith("cns", Ordinal or OrdinalIgnoreCase)
// obj.EndsWith("cns", Ordinal or OrdinalIgnoreCase)
//
// instead, it should throw NRE if it's null
needsNullcheck = false;
Expand All @@ -658,7 +692,7 @@ GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO
{
// check for fake "" first
cnsLength = 0;
JITDUMP("Trying to unroll String.Equals|StartsWith(op1, \"\")...\n", str)
JITDUMP("Trying to unroll String.Equals|StartsWith|EndsWith(op1, \"\")...\n", str)
}
else
{
Expand All @@ -668,7 +702,7 @@ GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO
// We were unable to get the literal (e.g. dynamic context)
return nullptr;
}
JITDUMP("Trying to unroll String.Equals|StartsWith(op1, \"cns\")...\n")
JITDUMP("Trying to unroll String.Equals|StartsWith|EndsWith(op1, \"cns\")...\n")
}

// Create a temp which is safe to gtClone for varStr
Expand All @@ -682,7 +716,7 @@ GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO
GenTree* lenNode = gtNewArrLen(TYP_INT, varStrLcl, strLenOffset, compCurBB);
varStrLcl = gtClone(varStrLcl)->AsLclVar();

GenTree* unrolled = impExpandHalfConstEquals(varStrLcl, lenNode, needsNullcheck, startsWith, (WCHAR*)str, cnsLength,
GenTree* unrolled = impExpandHalfConstEquals(varStrLcl, lenNode, needsNullcheck, kind, (WCHAR*)str, cnsLength,
strLenOffset + sizeof(int), cmpMode);
if (unrolled != nullptr)
{
Expand All @@ -706,7 +740,7 @@ GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO
}

//------------------------------------------------------------------------
// impSpanEqualsOrStartsWith: The main entry-point for [ReadOnly]Span<char> methods
// impUtf16SpanComparison: The main entry-point for [ReadOnly]Span<char> methods
// We're going to unroll & vectorize the following cases:
// 1) MemoryExtensions.SequenceEqual<char>(var, "cns")
// 2) MemoryExtensions.SequenceEqual<char>("cns", var)
Expand All @@ -717,23 +751,28 @@ GenTree* Compiler::impStringEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO
// 7) MemoryExtensions.StartsWith("cns", var, Ordinal or OrdinalIgnoreCase)
// 8) MemoryExtensions.StartsWith(var, "cns", Ordinal or OrdinalIgnoreCase)
//
// 9) MemoryExtensions.EndsWith<char>("cns", var)
// 10) MemoryExtensions.EndsWith<char>(var, "cns")
// 11) MemoryExtensions.EndsWith("cns", var, Ordinal or OrdinalIgnoreCase)
// 12) MemoryExtensions.EndsWith(var, "cns", Ordinal or OrdinalIgnoreCase)
//
// Arguments:
// startsWith - Is it StartsWith or Equals?
// sig - signature of StartsWith or Equals method
// kind - Is it StartsWith, EndsWith or Equals?
// sig - signature of StartsWith, EndsWith or Equals method
// methodFlags - its flags
//
// Returns:
// GenTree representing vectorized comparison or nullptr
//
GenTree* Compiler::impSpanEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO* sig, unsigned methodFlags)
GenTree* Compiler::impUtf16SpanComparison(StringComparisonKind kind, CORINFO_SIG_INFO* sig, unsigned methodFlags)
{
const bool isStatic = methodFlags & CORINFO_FLG_STATIC;
const int argsCount = sig->numArgs + (isStatic ? 0 : 1);

// This optimization spawns several temps so make sure we have room
if (lvaHaveManyLocals(0.75))
{
JITDUMP("impSpanEqualsOrStartsWith: Method has too many locals - bail out.\n")
JITDUMP("impUtf16SpanComparison: Method has too many locals - bail out.\n")
return nullptr;
}

Expand All @@ -760,7 +799,7 @@ GenTree* Compiler::impSpanEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO*
op2 = impStackTop(0).val;
}

// For generic StartsWith and Equals we need to make sure T is char
// For generic StartsWith, EndsWith and Equals we need to make sure T is char
if (sig->sigInst.methInstCount != 0)
{
assert(sig->sigInst.methInstCount == 1);
Expand Down Expand Up @@ -790,9 +829,9 @@ GenTree* Compiler::impSpanEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO*
}
else
{
if (startsWith)
if (kind != Equals)
{
// StartsWith is not commutative
// StartsWith and EndsWith are not commutative
return nullptr;
}
cnsStr = op1Str;
Expand Down Expand Up @@ -835,8 +874,8 @@ GenTree* Compiler::impSpanEqualsOrStartsWith(bool startsWith, CORINFO_SIG_INFO*

GenTreeLclFld* spanReferenceFld = gtNewLclFldNode(spanLclNum, TYP_BYREF, OFFSETOF__CORINFO_Span__reference);
GenTreeLclFld* spanLengthFld = gtNewLclFldNode(spanLclNum, TYP_INT, OFFSETOF__CORINFO_Span__length);
GenTree* unrolled = impExpandHalfConstEquals(spanReferenceFld, spanLengthFld, false, startsWith, (WCHAR*)str,
cnsLength, 0, cmpMode);
GenTree* unrolled =
impExpandHalfConstEquals(spanReferenceFld, spanLengthFld, false, kind, (WCHAR*)str, cnsLength, 0, cmpMode);

if (unrolled != nullptr)
{
Expand Down
2 changes: 2 additions & 0 deletions src/coreclr/jit/namedintrinsiclist.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ enum NamedIntrinsic : unsigned short
NI_System_String_get_Length,
NI_System_String_op_Implicit,
NI_System_String_StartsWith,
NI_System_String_EndsWith,
NI_System_Span_get_Item,
NI_System_Span_get_Length,
NI_System_SpanHelpers_SequenceEqual,
Expand All @@ -125,6 +126,7 @@ enum NamedIntrinsic : unsigned short
NI_System_MemoryExtensions_Equals,
NI_System_MemoryExtensions_SequenceEqual,
NI_System_MemoryExtensions_StartsWith,
NI_System_MemoryExtensions_EndsWith,

NI_System_Threading_Interlocked_And,
NI_System_Threading_Interlocked_Or,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ public static int ToUpperInvariant(this ReadOnlySpan<char> source, Span<char> de
/// <param name="span">The source span.</param>
/// <param name="value">The sequence to compare to the end of the source span.</param>
/// <param name="comparisonType">One of the enumeration values that determines how the <paramref name="span"/> and <paramref name="value"/> are compared.</param>
[Intrinsic] // Unrolled and vectorized for half-constant input (Ordinal)
public static bool EndsWith(this ReadOnlySpan<char> span, ReadOnlySpan<char> value, StringComparison comparisonType)
{
string.CheckStringComparison(comparisonType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2573,6 +2573,7 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
/// Determines whether the specified sequence appears at the end of the span.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[Intrinsic] // Unrolled and vectorized for half-constant input
public static unsafe bool EndsWith<T>(this Span<T> span, ReadOnlySpan<T> value) where T : IEquatable<T>?
{
int spanLength = span.Length;
Expand All @@ -2597,6 +2598,7 @@ ref MemoryMarshal.GetReference(value),
/// Determines whether the specified sequence appears at the end of the span.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[Intrinsic] // Unrolled and vectorized for half-constant input
public static unsafe bool EndsWith<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> value) where T : IEquatable<T>?
{
int spanLength = span.Length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ public bool EndsWith(string value)
return EndsWith(value, StringComparison.CurrentCulture);
}

[Intrinsic] // Unrolled and vectorized for half-constant input (Ordinal)
public bool EndsWith(string value, StringComparison comparisonType)
{
ArgumentNullException.ThrowIfNull(value);
Expand Down
Loading