Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make randomized string equality comparers implement IAlternateEqualityComparer #104252

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public class InternalHashCodeTests_Dictionary_NullComparer : InternalHashCodeTes
protected override Dictionary<string, string> CreateCollection() => new Dictionary<string, string>();
protected override void AddKey(Dictionary<string, string> collection, string key) => collection.Add(key, key);
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override bool ContainsKey(Dictionary<string, string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, string, ReadOnlySpan<char>>().ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
Expand Down Expand Up @@ -58,6 +60,8 @@ public class InternalHashCodeTests_Dictionary_DefaultComparer : InternalHashCode
protected override Dictionary<string, string> CreateCollection() => new Dictionary<string, string>(EqualityComparer<string>.Default);
protected override void AddKey(Dictionary<string, string> collection, string key) => collection.Add(key, key);
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override bool ContainsKey(Dictionary<string, string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, string, ReadOnlySpan<char>>().ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
Expand All @@ -70,6 +74,8 @@ public class InternalHashCodeTests_Dictionary_OrdinalComparer : InternalHashCode
protected override Dictionary<string, string> CreateCollection() => new Dictionary<string, string>(StringComparer.Ordinal);
protected override void AddKey(Dictionary<string, string> collection, string key) => collection.Add(key, key);
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override bool ContainsKey(Dictionary<string, string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, string, ReadOnlySpan<char>>().ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
Expand All @@ -82,6 +88,8 @@ public class InternalHashCodeTests_Dictionary_OrdinalIgnoreCaseComparer : Intern
protected override Dictionary<string, string> CreateCollection() => new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
protected override void AddKey(Dictionary<string, string> collection, string key) => collection.Add(key, key);
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override bool ContainsKey(Dictionary<string, string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, string, ReadOnlySpan<char>>().ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalIgnoreCaseComparerType;
Expand All @@ -94,6 +102,8 @@ public class InternalHashCodeTests_Dictionary_LinguisticComparer : InternalHashC
protected override Dictionary<string, string> CreateCollection() => new Dictionary<string, string>(StringComparer.InvariantCulture);
protected override void AddKey(Dictionary<string, string> collection, string key) => collection.Add(key, key);
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override bool ContainsKey(Dictionary<string, string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, string, ReadOnlySpan<char>>().ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => StringComparer.InvariantCulture.GetType();
Expand All @@ -108,6 +118,8 @@ public class InternalHashCodeTests_Dictionary_GetValueRefOrAddDefault : Internal
protected override Dictionary<string, string> CreateCollection() => new Dictionary<string, string>(StringComparer.Ordinal);
protected override void AddKey(Dictionary<string, string> collection, string key) => CollectionsMarshal.GetValueRefOrAddDefault(collection, key, out _) = null;
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override bool ContainsKey(Dictionary<string, string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, string, ReadOnlySpan<char>>().ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
Expand All @@ -122,6 +134,8 @@ public class InternalHashCodeTests_HashSet_NullComparer : InternalHashCodeTests<
protected override HashSet<string> CreateCollection() => new HashSet<string>();
protected override void AddKey(HashSet<string> collection, string key) => collection.Add(key);
protected override bool ContainsKey(HashSet<string> collection, string key) => collection.Contains(key);
protected override bool ContainsKey(HashSet<string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, ReadOnlySpan<char>>().Contains(key);
protected override IEqualityComparer<string> GetComparer(HashSet<string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
Expand All @@ -134,6 +148,8 @@ public class InternalHashCodeTests_HashSet_DefaultComparer : InternalHashCodeTes
protected override HashSet<string> CreateCollection() => new HashSet<string>(EqualityComparer<string>.Default);
protected override void AddKey(HashSet<string> collection, string key) => collection.Add(key);
protected override bool ContainsKey(HashSet<string> collection, string key) => collection.Contains(key);
protected override bool ContainsKey(HashSet<string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, ReadOnlySpan<char>>().Contains(key);
protected override IEqualityComparer<string> GetComparer(HashSet<string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
Expand All @@ -146,6 +162,8 @@ public class InternalHashCodeTests_HashSet_OrdinalComparer : InternalHashCodeTes
protected override HashSet<string> CreateCollection() => new HashSet<string>(StringComparer.Ordinal);
protected override void AddKey(HashSet<string> collection, string key) => collection.Add(key);
protected override bool ContainsKey(HashSet<string> collection, string key) => collection.Contains(key);
protected override bool ContainsKey(HashSet<string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, ReadOnlySpan<char>>().Contains(key);
protected override IEqualityComparer<string> GetComparer(HashSet<string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
Expand All @@ -158,6 +176,8 @@ public class InternalHashCodeTests_HashSet_OrdinalIgnoreCaseComparer : InternalH
protected override HashSet<string> CreateCollection() => new HashSet<string>(StringComparer.OrdinalIgnoreCase);
protected override void AddKey(HashSet<string> collection, string key) => collection.Add(key);
protected override bool ContainsKey(HashSet<string> collection, string key) => collection.Contains(key);
protected override bool ContainsKey(HashSet<string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, ReadOnlySpan<char>>().Contains(key);
protected override IEqualityComparer<string> GetComparer(HashSet<string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalIgnoreCaseComparerType;
Expand All @@ -170,6 +190,8 @@ public class InternalHashCodeTests_HashSet_LinguisticComparer : InternalHashCode
protected override HashSet<string> CreateCollection() => new HashSet<string>(StringComparer.InvariantCulture);
protected override void AddKey(HashSet<string> collection, string key) => collection.Add(key);
protected override bool ContainsKey(HashSet<string> collection, string key) => collection.Contains(key);
protected override bool ContainsKey(HashSet<string> collection, ReadOnlySpan<char> key) =>
collection.GetAlternateLookup<string, ReadOnlySpan<char>>().Contains(key);
protected override IEqualityComparer<string> GetComparer(HashSet<string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => StringComparer.InvariantCulture.GetType();
Expand All @@ -189,6 +211,8 @@ public class InternalHashCodeTests_OrderedDictionary_NullComparer : InternalHash
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => EqualityComparer<string>.Default.GetType();

protected override bool SupportsAlternateLookup(OrderedDictionary<string, string> collection) => false;
}

public class InternalHashCodeTests_OrderedDictionary_DefaultComparer : InternalHashCodeTests<OrderedDictionary<string, string>>
Expand All @@ -201,6 +225,8 @@ public class InternalHashCodeTests_OrderedDictionary_DefaultComparer : InternalH
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => EqualityComparer<string>.Default.GetType();

protected override bool SupportsAlternateLookup(OrderedDictionary<string, string> collection) => false;
}

public class InternalHashCodeTests_OrderedDictionary_OrdinalComparer : InternalHashCodeTests<OrderedDictionary<string, string>>
Expand All @@ -213,6 +239,8 @@ public class InternalHashCodeTests_OrderedDictionary_OrdinalComparer : InternalH
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => StringComparer.Ordinal;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => StringComparer.Ordinal.GetType();

protected override bool SupportsAlternateLookup(OrderedDictionary<string, string> collection) => false;
}

public class InternalHashCodeTests_OrderedDictionary_OrdinalIgnoreCaseComparer : InternalHashCodeTests<OrderedDictionary<string, string>>
Expand All @@ -225,6 +253,8 @@ public class InternalHashCodeTests_OrderedDictionary_OrdinalIgnoreCaseComparer :
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalIgnoreCaseComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => StringComparer.OrdinalIgnoreCase;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => StringComparer.OrdinalIgnoreCase.GetType();

protected override bool SupportsAlternateLookup(OrderedDictionary<string, string> collection) => false;
}

public class InternalHashCodeTests_OrderedDictionary_LinguisticComparer : InternalHashCodeTests<OrderedDictionary<string, string>> // (not optimized)
Expand All @@ -237,6 +267,8 @@ public class InternalHashCodeTests_OrderedDictionary_LinguisticComparer : Intern
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => StringComparer.InvariantCulture.GetType();
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => StringComparer.InvariantCulture;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => StringComparer.InvariantCulture.GetType();

protected override bool SupportsAlternateLookup(OrderedDictionary<string, string> collection) => false;
}
#endregion

Expand All @@ -251,6 +283,8 @@ public abstract class InternalHashCodeTests<TCollection>
protected abstract void AddKey(TCollection collection, string key);
protected abstract bool ContainsKey(TCollection collection, string key);
protected abstract IEqualityComparer<string> GetComparer(TCollection collection);
protected virtual bool SupportsAlternateLookup(TCollection collection) => true;
protected virtual bool ContainsKey(TCollection collection, ReadOnlySpan<char> key) => throw new NotSupportedException();

protected abstract Type ExpectedInternalComparerTypeBeforeCollisionThreshold { get; }
protected abstract IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold { get; }
Expand Down Expand Up @@ -304,6 +338,10 @@ public void ComparerImplementations_Dictionary_WithWellKnownStringComparers()
foreach (string key in allKeys)
{
Assert.True(ContainsKey(collection, key));
if (SupportsAlternateLookup(collection))
{
Assert.True(ContainsKey(collection, key.AsSpan()));
}
}

// Also make sure we didn't accidentally put the internal comparer in the serialized object data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Collections.Generic
{
Expand Down Expand Up @@ -46,15 +47,29 @@ private struct MarvinSeed
internal uint p1;
}

private sealed class OrdinalComparer : RandomizedStringEqualityComparer
private sealed class OrdinalComparer : RandomizedStringEqualityComparer, IAlternateEqualityComparer<ReadOnlySpan<char>, string?>
{
internal OrdinalComparer(IEqualityComparer<string?> wrappedComparer)
: base(wrappedComparer)
{
}

string IAlternateEqualityComparer<ReadOnlySpan<char>, string?>.Create(ReadOnlySpan<char> span) =>
span.ToString();

public override bool Equals(string? x, string? y) => string.Equals(x, y);

bool IAlternateEqualityComparer<ReadOnlySpan<char>, string?>.Equals(ReadOnlySpan<char> alternate, string? other)
{
// See explanation in System.OrdinalComparer.Equals.
if (alternate.IsEmpty && other is null)
{
return false;
}

return alternate.SequenceEqual(other);
}

public override int GetHashCode(string? obj)
{
if (obj is null)
Expand All @@ -69,17 +84,37 @@ ref Unsafe.As<char, byte>(ref obj.GetRawStringData()),
(uint)obj.Length * 2,
_seed.p0, _seed.p1);
}

int IAlternateEqualityComparer<ReadOnlySpan<char>, string?>.GetHashCode(ReadOnlySpan<char> alternate) =>
Marvin.ComputeHash32(
ref Unsafe.As<char, byte>(ref MemoryMarshal.GetReference(alternate)),
(uint)alternate.Length * 2,
_seed.p0, _seed.p1);
}

private sealed class OrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer
private sealed class OrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer, IAlternateEqualityComparer<ReadOnlySpan<char>, string?>
{
internal OrdinalIgnoreCaseComparer(IEqualityComparer<string?> wrappedComparer)
: base(wrappedComparer)
{
}

string IAlternateEqualityComparer<ReadOnlySpan<char>, string?>.Create(ReadOnlySpan<char> span) =>
span.ToString();

public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y);

bool IAlternateEqualityComparer<ReadOnlySpan<char>, string?>.Equals(ReadOnlySpan<char> alternate, string? other)
{
// See explanation in System.OrdinalComparer.Equals.
if (alternate.IsEmpty && other is null)
{
return false;
}

return alternate.EqualsOrdinalIgnoreCase(other);
}

public override int GetHashCode(string? obj)
{
if (obj is null)
Expand All @@ -94,6 +129,12 @@ ref obj.GetRawStringData(),
obj.Length,
_seed.p0, _seed.p1);
}

int IAlternateEqualityComparer<ReadOnlySpan<char>, string?>.GetHashCode(ReadOnlySpan<char> alternate) =>
Marvin.ComputeHash32OrdinalIgnoreCase(
ref MemoryMarshal.GetReference(alternate),
alternate.Length,
_seed.p0, _seed.p1);
}
}
}
Loading