Skip to content

Commit

Permalink
Tweaks to ValueComparer nullability (#24410)
Browse files Browse the repository at this point in the history
* Make Snapshot accept/receive non-nullable (nulls are sanitized
  externally).
* Make ValueComparer<T>.GetHashCode accept non-nullable object.
  • Loading branch information
roji committed Aug 20, 2021
1 parent 43e023c commit c3a71be
Show file tree
Hide file tree
Showing 15 changed files with 48 additions and 79 deletions.
9 changes: 2 additions & 7 deletions src/EFCore.Cosmos/ChangeTracking/Internal/ListComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
return hash.ToHashCode();
}

private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
{
if (source is null)
{
return null;
}

if (readOnly)
{
return source;
Expand All @@ -92,7 +87,7 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
var snapshot = new List<TElement>(((IReadOnlyList<TElement>)source).Count);
foreach (var e in source)
{
snapshot.Add(elementComparer.Snapshot(e)!);
snapshot.Add(e is null ? default! : elementComparer.Snapshot(e));
}

return (TCollection)(object)snapshot;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
return hash.ToHashCode();
}

private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
{
if (source is null)
{
return null;
}

if (readOnly)
{
return source;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,8 @@ private static int GetHashCode(TElement?[] source, ValueComparer<TElement> eleme
}

[return: NotNullIfNotNull("source")]
private static TElement?[]? Snapshot(TElement?[]? source, ValueComparer<TElement> elementComparer)
private static TElement?[] Snapshot(TElement?[] source, ValueComparer<TElement> elementComparer)
{
if (source is null)
{
return null;
}

var snapshot = new TElement?[source.Length];
for (var i = 0; i < source.Length; i++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
return hash.ToHashCode();
}

private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
{
if (source is null)
{
return null;
}

if (readOnly)
{
return source;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,13 @@ private static int GetHashCode(TElement[] source, ValueComparer<TElement> elemen
}

[return: NotNullIfNotNull("source")]
private static TElement[]? Snapshot(TElement[]? source, ValueComparer<TElement> elementComparer)
private static TElement[] Snapshot(TElement[] source, ValueComparer<TElement> elementComparer)
{
if (source is null)
{
return null;
}

var snapshot = new TElement[source.Length];
for (var i = 0; i < source.Length; i++)
{
snapshot[i] = elementComparer.Snapshot(source[i])!;
var element = source[i];
snapshot[i] = element is null ? default! : elementComparer.Snapshot(source[i]);
}
return snapshot;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,8 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
return hash.ToHashCode();
}

private static TCollection? Snapshot(TCollection? source, ValueComparer<TElement> elementComparer, bool readOnly)
private static TCollection Snapshot(TCollection source, ValueComparer<TElement> elementComparer, bool readOnly)
{
if (source is null)
{
return null;
}

if (readOnly)
{
return source;
Expand All @@ -94,7 +89,7 @@ private static int GetHashCode(TCollection source, ValueComparer<TElement> eleme
var snapshot = new Dictionary<string, TElement>(((IReadOnlyDictionary<string, TElement>)source).Count);
foreach (var e in source)
{
snapshot.Add(e.Key, elementComparer.Snapshot(e.Value)!);
snapshot.Add(e.Key, e.Value is null ? default! : elementComparer.Snapshot(e.Value));
}

return (TCollection)(object)snapshot;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ private readonly SqlServerByteArrayTypeMapping _rowversion
comparer: new ValueComparer<byte[]>(
(v1, v2) => StructuralComparisons.StructuralEqualityComparer.Equals(v1, v2),
v => StructuralComparisons.StructuralEqualityComparer.GetHashCode(v),
v => v == null ? null : v.ToArray()),
v => v.ToArray()),
storeTypePostfix: StoreTypePostfix.None);

private readonly IntTypeMapping _int
Expand Down
2 changes: 1 addition & 1 deletion src/EFCore/ChangeTracking/ArrayStructuralComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public ArrayStructuralComparer()
: base(
CreateDefaultEqualsExpression(),
CreateDefaultHashCodeExpression(favorStructuralComparisons: true),
v => v == null ? null : v.ToArray())
v => v.ToArray())
{
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/EFCore/ChangeTracking/GeometryValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public GeometryValueComparer()
right);
}

private static Expression<Func<TGeometry?, TGeometry?>> GetSnapshotExpression()
private static Expression<Func<TGeometry, TGeometry>> GetSnapshotExpression()
{
var instance = Expression.Parameter(typeof(TGeometry), "instance");

Expand All @@ -71,7 +71,7 @@ public GeometryValueComparer()
body = Expression.Convert(body, typeof(TGeometry));
}

return Expression.Lambda<Func<TGeometry?, TGeometry?>>(body, instance);
return Expression.Lambda<Func<TGeometry, TGeometry>>(body, instance);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ public NoNullsCustomEqualityComparer(ValueComparer comparer)
public bool Equals(TKey? x, TKey? y)
=> _equals(x, y);

public int GetHashCode(TKey obj)
public int GetHashCode([DisallowNull] TKey obj)
=> _hashCode(obj);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public NonNullNullableValueComparer(
: base(
(Expression<Func<T?, T?, bool>>)equalsExpression,
(Expression<Func<T, int>>)hashCodeExpression,
(Expression<Func<T?, T?>>)snapshotExpression)
(Expression<Func<T, T>>)snapshotExpression)
{
}
}
Expand Down
26 changes: 14 additions & 12 deletions src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -106,6 +107,7 @@ protected ValueComparer(
/// </summary>
/// <param name="instance"> The instance. </param>
/// <returns> The snapshot. </returns>
[return: NotNullIfNotNull("instance")]
public abstract object? Snapshot(object? instance);

/// <summary>
Expand Down Expand Up @@ -196,31 +198,31 @@ public virtual Expression ExtractSnapshotBody(Expression expression)
/// <returns> The <see cref="ValueComparer{T}" />. </returns>
public static ValueComparer CreateDefault(Type type, bool favorStructuralComparisons)
{
var nonNullabletype = type.UnwrapNullableType();
var nonNullableType = type.UnwrapNullableType();

// The equality operator returns false for NaNs, but the Equals methods returns true
if (nonNullabletype == typeof(double))
if (nonNullableType == typeof(double))
{
return new DefaultDoubleValueComparer(favorStructuralComparisons);
}

if (nonNullabletype == typeof(float))
if (nonNullableType == typeof(float))
{
return new DefaultFloatValueComparer(favorStructuralComparisons);
}

if (nonNullabletype == typeof(DateTimeOffset))
if (nonNullableType == typeof(DateTimeOffset))
{
return new DefaultDateTimeOffsetValueComparer(favorStructuralComparisons);
}

var comparerType = nonNullabletype.IsInteger()
|| nonNullabletype == typeof(decimal)
|| nonNullabletype == typeof(bool)
|| nonNullabletype == typeof(string)
|| nonNullabletype == typeof(DateTime)
|| nonNullabletype == typeof(Guid)
|| nonNullabletype == typeof(TimeSpan)
var comparerType = nonNullableType.IsInteger()
|| nonNullableType == typeof(decimal)
|| nonNullableType == typeof(bool)
|| nonNullableType == typeof(string)
|| nonNullableType == typeof(DateTime)
|| nonNullableType == typeof(Guid)
|| nonNullableType == typeof(TimeSpan)
? typeof(DefaultValueComparer<>)
: typeof(ValueComparer<>);

Expand Down Expand Up @@ -253,7 +255,7 @@ public override Expression ExtractSnapshotBody(Expression expression)
public override object? Snapshot(object? instance)
=> instance;

public override T? Snapshot(T? instance)
public override T Snapshot(T instance)
=> instance;
}

Expand Down
20 changes: 10 additions & 10 deletions src/EFCore/ChangeTracking/ValueComparer`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class ValueComparer<T> : ValueComparer, IEqualityComparer<T>
{
private Func<T?, T?, bool>? _equals;
private Func<T, int>? _hashCode;
private Func<T?, T?>? _snapshot;
private Func<T, T>? _snapshot;

/// <summary>
/// Creates a new <see cref="ValueComparer{T}" /> with a default comparison
Expand Down Expand Up @@ -82,7 +82,7 @@ public ValueComparer(
public ValueComparer(
Expression<Func<T?, T?, bool>> equalsExpression,
Expression<Func<T, int>> hashCodeExpression,
Expression<Func<T?, T?>> snapshotExpression)
Expression<Func<T, T>> snapshotExpression)
: base(equalsExpression, hashCodeExpression, snapshotExpression)
{
}
Expand Down Expand Up @@ -161,7 +161,7 @@ public ValueComparer(
/// Creates an expression for creating a snapshot of a value.
/// </summary>
/// <returns> The snapshot expression. </returns>
protected static Expression<Func<T?, T?>> CreateDefaultSnapshotExpression(bool favorStructuralComparisons)
protected static Expression<Func<T, T>> CreateDefaultSnapshotExpression(bool favorStructuralComparisons)
{
if (!favorStructuralComparisons
|| !typeof(T).IsArray)
Expand All @@ -178,7 +178,7 @@ public ValueComparer(
// var destination = new T[length];
// Array.Copy(source, destination, length);
// return destination;
return Expression.Lambda<Func<T?, T?>>(
return Expression.Lambda<Func<T, T>>(
Expression.Block(
new[] { lengthVariable, destinationVariable },
Expression.Assign(
Expand Down Expand Up @@ -257,8 +257,8 @@ public override bool Equals(object? left, object? right)
/// </summary>
/// <param name="instance"> The instance. </param>
/// <returns> The hash code. </returns>
public override int GetHashCode(object? instance)
=> instance == null ? 0 : GetHashCode((T)instance);
public override int GetHashCode(object instance)
=> instance is null ? 0 : GetHashCode((T)instance);

/// <summary>
/// Compares the two instances to determine if they are equal.
Expand Down Expand Up @@ -293,7 +293,7 @@ public virtual int GetHashCode(T instance)
/// <param name="instance"> The instance. </param>
/// <returns> The snapshot. </returns>
public override object? Snapshot(object? instance)
=> instance == null ? null : Snapshot((T?)instance);
=> instance == null ? null : Snapshot((T)instance);

/// <summary>
/// <para>
Expand All @@ -308,7 +308,7 @@ public virtual int GetHashCode(T instance)
/// </summary>
/// <param name="instance"> The instance. </param>
/// <returns> The snapshot. </returns>
public virtual T? Snapshot(T? instance)
public virtual T Snapshot(T instance)
=> NonCapturingLazyInitializer.EnsureInitialized(
ref _snapshot, this, static c => c.SnapshotExpression.Compile())(instance);

Expand Down Expand Up @@ -341,7 +341,7 @@ public override Type Type
/// reference.
/// </para>
/// </summary>
public new virtual Expression<Func<T?, T?>> SnapshotExpression
=> (Expression<Func<T?, T?>>)base.SnapshotExpression;
public new virtual Expression<Func<T, T>> SnapshotExpression
=> (Expression<Func<T, T>>)base.SnapshotExpression;
}
}
18 changes: 8 additions & 10 deletions test/EFCore.Cosmos.FunctionalTests/EndToEndCosmosTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -606,14 +606,13 @@ await Can_add_update_delete_with_collection(
public async Task Can_add_update_delete_with_nested_collections()
{
await Can_add_update_delete_with_collection(
new List<List<short>> { new List<short> { 1, 2 } },
new List<List<short>> { new() { 1, 2 } },
c =>
{
c.Collection.Clear();
c.Collection.Add(new List<short> { 3 });
},
new List<List<short>> { new List<short> { 3 } });

new List<List<short>> { new() { 3 } });
await Can_add_update_delete_with_collection<IList<byte?[]>>(
new List<byte?[]>(),
c =>
Expand All @@ -622,30 +621,29 @@ await Can_add_update_delete_with_collection(
c.Collection.Add(null);
},
new List<byte?[]> { new byte?[] { 3, null }, null });

await Can_add_update_delete_with_collection<IReadOnlyList<Dictionary<string, string>>>(
new Dictionary<string, string>[] { new Dictionary<string, string> { { "1", null } } },
new Dictionary<string, string>[] { new() { { "1", null } } },
c =>
{
var dictionary = c.Collection[0]["3"] = "2";
},
new List<Dictionary<string, string>> { new Dictionary<string, string> { { "1", null }, { "3", "2" } } });
new List<Dictionary<string, string>> { new() { { "1", null }, { "3", "2" } } });

await Can_add_update_delete_with_collection(
new List<float>[] { new List<float> { 1f }, new List<float> { 2 } },
new List<float>[] { new() { 1f }, new() { 2 } },
c =>
{
c.Collection[1][0] = 3f;
},
new List<float>[] { new List<float> { 1f }, new List<float> { 3f } });
new List<float>[] { new() { 1f }, new() { 3f } });

await Can_add_update_delete_with_collection(
new decimal?[][] { new decimal?[] { 1, null } },
new[] { new decimal?[] { 1, null } },
c =>
{
c.Collection[0][1] = 3;
},
new decimal?[][] { new decimal?[] { 1, 3 } });
new[] { new decimal?[] { 1, 3 } });

await Can_add_update_delete_with_collection(
new Dictionary<string, List<int>> { { "1", new List<int> { 1 } } },
Expand Down
2 changes: 0 additions & 2 deletions test/EFCore.Tests/Storage/ValueComparerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ private static ValueComparer CompareTest(Type type, object value1, object value2
Assert.False(comparer.Equals(null, value2));
Assert.True(comparer.Equals(null, null));

Assert.Equal(0, comparer.GetHashCode(null));
Assert.Equal(hashCode ?? value1.GetHashCode(), comparer.GetHashCode(value1));

var keyComparer = (ValueComparer)Activator.CreateInstance(typeof(ValueComparer<>).MakeGenericType(type), new object[] { true });
Expand All @@ -102,7 +101,6 @@ private static ValueComparer CompareTest(Type type, object value1, object value2
Assert.False(keyComparer.Equals(null, value2));
Assert.True(keyComparer.Equals(null, null));

Assert.Equal(0, keyComparer.GetHashCode(null));
Assert.Equal(hashCode ?? value1.GetHashCode(), keyComparer.GetHashCode(value1));

return comparer;
Expand Down

0 comments on commit c3a71be

Please sign in to comment.