Skip to content

Commit

Permalink
Tweaks to ValueComparer nullability
Browse files Browse the repository at this point in the history
Following #24261
  • Loading branch information
roji committed Mar 16, 2021
1 parent 847bf3d commit a29ced6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
2 changes: 2 additions & 0 deletions src/EFCore/ChangeTracking/ValueComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Utilities;
using CA = System.Diagnostics.CodeAnalysis;

namespace Microsoft.EntityFrameworkCore.ChangeTracking
{
Expand Down Expand Up @@ -103,6 +104,7 @@ protected ValueComparer(
/// </summary>
/// <param name="instance"> The instance. </param>
/// <returns> The snapshot. </returns>
[return: CA.NotNullIfNotNull("instance")]
public abstract object? Snapshot([CanBeNull] object? instance);

/// <summary>
Expand Down
34 changes: 18 additions & 16 deletions src/EFCore/ChangeTracking/ValueComparer`.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Internal;
using CA = System.Diagnostics.CodeAnalysis;

namespace Microsoft.EntityFrameworkCore.ChangeTracking
{
Expand All @@ -31,7 +32,7 @@ public class ValueComparer<T> : ValueComparer, IEqualityComparer<T>
{
private Func<T?, T?, bool>? _equals;
private Func<T, int>? _hashCode;
private Func<T?, T?>? _snapshot;
private Func<T, T>? _snapshot;

/// <summary>
/// Creates a new <see cref="ValueComparer{T}" /> with a default comparison
Expand Down Expand Up @@ -80,7 +81,7 @@ public ValueComparer(
public ValueComparer(
[NotNull] Expression<Func<T?, T?, bool>> equalsExpression,
[NotNull] Expression<Func<T, int>> hashCodeExpression,
[NotNull] Expression<Func<T?, T?>> snapshotExpression)
[NotNull] Expression<Func<T, T>> snapshotExpression)
: base(equalsExpression, hashCodeExpression, snapshotExpression)
{
}
Expand Down Expand Up @@ -159,7 +160,7 @@ public ValueComparer(
/// Creates an expression for creating a snapshot of a value.
/// </summary>
/// <returns> The snapshot expression. </returns>
protected static Expression<Func<T?, T?>> CreateDefaultSnapshotExpression(bool favorStructuralComparisons)
protected static Expression<Func<T, T>> CreateDefaultSnapshotExpression(bool favorStructuralComparisons)
{
if (!favorStructuralComparisons
|| !typeof(T).IsArray)
Expand All @@ -176,7 +177,7 @@ public ValueComparer(
// var destination = new T[length];
// Array.Copy(source, destination, length);
// return destination;
return Expression.Lambda<Func<T?, T?>>(
return Expression.Lambda<Func<T, T>>(
Expression.Block(
new[] { lengthVariable, destinationVariable },
Expression.Assign(
Expand Down Expand Up @@ -243,20 +244,21 @@ var expression
/// <param name="right"> The second instance. </param>
/// <returns> <see langword="true" /> if they are equal; <see langword="false" /> otherwise. </returns>
public override bool Equals(object? left, object? right)
{
var v1Null = left == null;
var v2Null = right == null;

return v1Null || v2Null ? v1Null && v2Null : Equals((T?)left, (T?)right);
}
=> (left, right) switch
{
(null, null) => true,
(null, _) => false,
(_, null) => false,
_ => Equals((T)left, (T)right)
};

/// <summary>
/// Returns the hash code for the given instance.
/// </summary>
/// <param name="instance"> The instance. </param>
/// <returns> The hash code. </returns>
public override int GetHashCode(object? instance)
=> instance == null ? 0 : GetHashCode((T)instance);
public override int GetHashCode(object instance)
=> GetHashCode((T)instance);

/// <summary>
/// Compares the two instances to determine if they are equal.
Expand Down Expand Up @@ -291,7 +293,7 @@ public virtual int GetHashCode(T instance)
/// <param name="instance"> The instance. </param>
/// <returns> The snapshot. </returns>
public override object? Snapshot(object? instance)
=> instance == null ? null : (object?)Snapshot((T?)instance);
=> instance == null ? null : (object?)Snapshot((T)instance);

/// <summary>
/// <para>
Expand All @@ -306,7 +308,7 @@ public virtual int GetHashCode(T instance)
/// </summary>
/// <param name="instance"> The instance. </param>
/// <returns> The snapshot. </returns>
public virtual T? Snapshot([CanBeNull] T? instance)
public virtual T Snapshot([NotNull] T instance)
=> NonCapturingLazyInitializer.EnsureInitialized(
ref _snapshot, this, static c => c.SnapshotExpression.Compile())(instance);

Expand Down Expand Up @@ -339,7 +341,7 @@ public override Type Type
/// reference.
/// </para>
/// </summary>
public new virtual Expression<Func<T?, T?>> SnapshotExpression
=> (Expression<Func<T?, T?>>)base.SnapshotExpression;
public new virtual Expression<Func<T, T>> SnapshotExpression
=> (Expression<Func<T, T>>)base.SnapshotExpression;
}
}

0 comments on commit a29ced6

Please sign in to comment.