From 2bf7d2117769922c60ae6019385e37d3635775c0 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Mon, 1 Jul 2024 11:41:38 -0400 Subject: [PATCH] Make randomized string equality comparers implement IAlternateEqualityComparer --- .../OutOfBoundsRegression.cs | 38 ++++++++++++++++ .../RandomizedStringEqualityComparer.cs | 45 ++++++++++++++++++- 2 files changed, 81 insertions(+), 2 deletions(-) diff --git a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs index 6fe06dd432e74..a5c5a01b85fd4 100644 --- a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs +++ b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs @@ -16,6 +16,8 @@ public class InternalHashCodeTests_Dictionary_NullComparer : InternalHashCodeTes protected override Dictionary CreateCollection() => new Dictionary(); protected override void AddKey(Dictionary collection, string key) => collection.Add(key, key); protected override bool ContainsKey(Dictionary collection, string key) => collection.ContainsKey(key); + protected override bool ContainsKey(Dictionary collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().ContainsKey(key); protected override IEqualityComparer GetComparer(Dictionary collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType; @@ -58,6 +60,8 @@ public class InternalHashCodeTests_Dictionary_DefaultComparer : InternalHashCode protected override Dictionary CreateCollection() => new Dictionary(EqualityComparer.Default); protected override void AddKey(Dictionary collection, string key) => collection.Add(key, key); protected override bool ContainsKey(Dictionary collection, string key) => collection.ContainsKey(key); + protected override bool ContainsKey(Dictionary collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().ContainsKey(key); protected override IEqualityComparer GetComparer(Dictionary collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType; @@ -70,6 +74,8 @@ public class InternalHashCodeTests_Dictionary_OrdinalComparer : InternalHashCode protected override Dictionary CreateCollection() => new Dictionary(StringComparer.Ordinal); protected override void AddKey(Dictionary collection, string key) => collection.Add(key, key); protected override bool ContainsKey(Dictionary collection, string key) => collection.ContainsKey(key); + protected override bool ContainsKey(Dictionary collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().ContainsKey(key); protected override IEqualityComparer GetComparer(Dictionary collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType; @@ -82,6 +88,8 @@ public class InternalHashCodeTests_Dictionary_OrdinalIgnoreCaseComparer : Intern protected override Dictionary CreateCollection() => new Dictionary(StringComparer.OrdinalIgnoreCase); protected override void AddKey(Dictionary collection, string key) => collection.Add(key, key); protected override bool ContainsKey(Dictionary collection, string key) => collection.ContainsKey(key); + protected override bool ContainsKey(Dictionary collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().ContainsKey(key); protected override IEqualityComparer GetComparer(Dictionary collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalIgnoreCaseComparerType; @@ -94,6 +102,8 @@ public class InternalHashCodeTests_Dictionary_LinguisticComparer : InternalHashC protected override Dictionary CreateCollection() => new Dictionary(StringComparer.InvariantCulture); protected override void AddKey(Dictionary collection, string key) => collection.Add(key, key); protected override bool ContainsKey(Dictionary collection, string key) => collection.ContainsKey(key); + protected override bool ContainsKey(Dictionary collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().ContainsKey(key); protected override IEqualityComparer GetComparer(Dictionary collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => StringComparer.InvariantCulture.GetType(); @@ -108,6 +118,8 @@ public class InternalHashCodeTests_Dictionary_GetValueRefOrAddDefault : Internal protected override Dictionary CreateCollection() => new Dictionary(StringComparer.Ordinal); protected override void AddKey(Dictionary collection, string key) => CollectionsMarshal.GetValueRefOrAddDefault(collection, key, out _) = null; protected override bool ContainsKey(Dictionary collection, string key) => collection.ContainsKey(key); + protected override bool ContainsKey(Dictionary collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().ContainsKey(key); protected override IEqualityComparer GetComparer(Dictionary collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType; @@ -122,6 +134,8 @@ public class InternalHashCodeTests_HashSet_NullComparer : InternalHashCodeTests< protected override HashSet CreateCollection() => new HashSet(); protected override void AddKey(HashSet collection, string key) => collection.Add(key); protected override bool ContainsKey(HashSet collection, string key) => collection.Contains(key); + protected override bool ContainsKey(HashSet collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().Contains(key); protected override IEqualityComparer GetComparer(HashSet collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType; @@ -134,6 +148,8 @@ public class InternalHashCodeTests_HashSet_DefaultComparer : InternalHashCodeTes protected override HashSet CreateCollection() => new HashSet(EqualityComparer.Default); protected override void AddKey(HashSet collection, string key) => collection.Add(key); protected override bool ContainsKey(HashSet collection, string key) => collection.Contains(key); + protected override bool ContainsKey(HashSet collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().Contains(key); protected override IEqualityComparer GetComparer(HashSet collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType; @@ -146,6 +162,8 @@ public class InternalHashCodeTests_HashSet_OrdinalComparer : InternalHashCodeTes protected override HashSet CreateCollection() => new HashSet(StringComparer.Ordinal); protected override void AddKey(HashSet collection, string key) => collection.Add(key); protected override bool ContainsKey(HashSet collection, string key) => collection.Contains(key); + protected override bool ContainsKey(HashSet collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().Contains(key); protected override IEqualityComparer GetComparer(HashSet collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType; @@ -158,6 +176,8 @@ public class InternalHashCodeTests_HashSet_OrdinalIgnoreCaseComparer : InternalH protected override HashSet CreateCollection() => new HashSet(StringComparer.OrdinalIgnoreCase); protected override void AddKey(HashSet collection, string key) => collection.Add(key); protected override bool ContainsKey(HashSet collection, string key) => collection.Contains(key); + protected override bool ContainsKey(HashSet collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().Contains(key); protected override IEqualityComparer GetComparer(HashSet collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalIgnoreCaseComparerType; @@ -170,6 +190,8 @@ public class InternalHashCodeTests_HashSet_LinguisticComparer : InternalHashCode protected override HashSet CreateCollection() => new HashSet(StringComparer.InvariantCulture); protected override void AddKey(HashSet collection, string key) => collection.Add(key); protected override bool ContainsKey(HashSet collection, string key) => collection.Contains(key); + protected override bool ContainsKey(HashSet collection, ReadOnlySpan key) => + collection.GetAlternateLookup>().Contains(key); protected override IEqualityComparer GetComparer(HashSet collection) => collection.Comparer; protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => StringComparer.InvariantCulture.GetType(); @@ -189,6 +211,8 @@ public class InternalHashCodeTests_OrderedDictionary_NullComparer : InternalHash protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType; protected override IEqualityComparer ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer.Default; protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => EqualityComparer.Default.GetType(); + + protected override bool SupportsAlternateLookup(OrderedDictionary collection) => false; } public class InternalHashCodeTests_OrderedDictionary_DefaultComparer : InternalHashCodeTests> @@ -201,6 +225,8 @@ public class InternalHashCodeTests_OrderedDictionary_DefaultComparer : InternalH protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType; protected override IEqualityComparer ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer.Default; protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => EqualityComparer.Default.GetType(); + + protected override bool SupportsAlternateLookup(OrderedDictionary collection) => false; } public class InternalHashCodeTests_OrderedDictionary_OrdinalComparer : InternalHashCodeTests> @@ -213,6 +239,8 @@ public class InternalHashCodeTests_OrderedDictionary_OrdinalComparer : InternalH protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType; protected override IEqualityComparer ExpectedPublicComparerBeforeCollisionThreshold => StringComparer.Ordinal; protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => StringComparer.Ordinal.GetType(); + + protected override bool SupportsAlternateLookup(OrderedDictionary collection) => false; } public class InternalHashCodeTests_OrderedDictionary_OrdinalIgnoreCaseComparer : InternalHashCodeTests> @@ -225,6 +253,8 @@ public class InternalHashCodeTests_OrderedDictionary_OrdinalIgnoreCaseComparer : protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalIgnoreCaseComparerType; protected override IEqualityComparer ExpectedPublicComparerBeforeCollisionThreshold => StringComparer.OrdinalIgnoreCase; protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => StringComparer.OrdinalIgnoreCase.GetType(); + + protected override bool SupportsAlternateLookup(OrderedDictionary collection) => false; } public class InternalHashCodeTests_OrderedDictionary_LinguisticComparer : InternalHashCodeTests> // (not optimized) @@ -237,6 +267,8 @@ public class InternalHashCodeTests_OrderedDictionary_LinguisticComparer : Intern protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => StringComparer.InvariantCulture.GetType(); protected override IEqualityComparer ExpectedPublicComparerBeforeCollisionThreshold => StringComparer.InvariantCulture; protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => StringComparer.InvariantCulture.GetType(); + + protected override bool SupportsAlternateLookup(OrderedDictionary collection) => false; } #endregion @@ -252,6 +284,8 @@ public abstract class InternalHashCodeTests protected abstract void AddKey(TCollection collection, string key); protected abstract bool ContainsKey(TCollection collection, string key); protected abstract IEqualityComparer GetComparer(TCollection collection); + protected virtual bool SupportsAlternateLookup(TCollection collection) => true; + protected virtual bool ContainsKey(TCollection collection, ReadOnlySpan key) => throw new NotSupportedException(); protected abstract Type ExpectedInternalComparerTypeBeforeCollisionThreshold { get; } protected abstract IEqualityComparer ExpectedPublicComparerBeforeCollisionThreshold { get; } @@ -305,6 +339,10 @@ public void ComparerImplementations_Dictionary_WithWellKnownStringComparers() foreach (string key in allKeys) { Assert.True(ContainsKey(collection, key)); + if (SupportsAlternateLookup(collection)) + { + Assert.True(ContainsKey(collection, key.AsSpan())); + } } // Also make sure we didn't accidentally put the internal comparer in the serialized object data. diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs index 45fd297e5af27..659ac09c7259e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; namespace System.Collections.Generic { @@ -46,15 +47,29 @@ private struct MarvinSeed internal uint p1; } - private sealed class OrdinalComparer : RandomizedStringEqualityComparer + private sealed class OrdinalComparer : RandomizedStringEqualityComparer, IAlternateEqualityComparer, string?> { internal OrdinalComparer(IEqualityComparer wrappedComparer) : base(wrappedComparer) { } + string IAlternateEqualityComparer, string?>.Create(ReadOnlySpan span) => + span.ToString(); + public override bool Equals(string? x, string? y) => string.Equals(x, y); + bool IAlternateEqualityComparer, string?>.Equals(ReadOnlySpan alternate, string? other) + { + // See explanation in System.OrdinalComparer.Equals. + if (alternate.IsEmpty && other is null) + { + return false; + } + + return alternate.SequenceEqual(other); + } + public override int GetHashCode(string? obj) { if (obj is null) @@ -69,17 +84,37 @@ ref Unsafe.As(ref obj.GetRawStringData()), (uint)obj.Length * 2, _seed.p0, _seed.p1); } + + int IAlternateEqualityComparer, string?>.GetHashCode(ReadOnlySpan alternate) => + Marvin.ComputeHash32( + ref Unsafe.As(ref MemoryMarshal.GetReference(alternate)), + (uint)alternate.Length * 2, + _seed.p0, _seed.p1); } - private sealed class OrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer + private sealed class OrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer, IAlternateEqualityComparer, string?> { internal OrdinalIgnoreCaseComparer(IEqualityComparer wrappedComparer) : base(wrappedComparer) { } + string IAlternateEqualityComparer, string?>.Create(ReadOnlySpan span) => + span.ToString(); + public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y); + bool IAlternateEqualityComparer, string?>.Equals(ReadOnlySpan alternate, string? other) + { + // See explanation in System.OrdinalComparer.Equals. + if (alternate.IsEmpty && other is null) + { + return false; + } + + return alternate.EqualsOrdinalIgnoreCase(other); + } + public override int GetHashCode(string? obj) { if (obj is null) @@ -94,6 +129,12 @@ ref obj.GetRawStringData(), obj.Length, _seed.p0, _seed.p1); } + + int IAlternateEqualityComparer, string?>.GetHashCode(ReadOnlySpan alternate) => + Marvin.ComputeHash32OrdinalIgnoreCase( + ref MemoryMarshal.GetReference(alternate), + alternate.Length, + _seed.p0, _seed.p1); } } }