Skip to content

Commit

Permalink
Fix to #35239 - EF9: SaveChanges() is significantly slower in .NET9 v…
Browse files Browse the repository at this point in the history
…s. .NET8 when using .ToJson() Mapping vs. PostgreSQL Legacy POCO mapping

Problem was that as part of AOT refactoring we changed way that we build comparers. Specifically, comparers of collections - ListOfValueTypesComparer, ListOfNullableValueTypesComparer and ListOfReferenceTypesComparer. Before those list comparer Compare, Hashcode and Snapshot methods would take as argument element comparer, which was responsible for comparing elements. We need to be able to express these in code for AOT but we are not able to generate constant of type ValueComparer (or ValueComparer) that was needed. As a solution, each comparer now stores expression describing how it can be constructed, so we use that instead (as we are perfectly capable to expressing that in code form). Problem is that now every time compare, snapshot or hashcode method is called for array type, we construct new ValueComparer for the element type. As a result in the reported case we would generate 1000s of comparers which all have to be compiled and that causes huge overhead.

Fix is to pass relevant func from the element comparer to the outer comparer. We only passed the element comparer object to the outer Compare/Hashcode/Snapshot function to call that relevant func. This way we avoid constructing redundant comparers.

For ListOfReferenceTypesComparer it's a bit trickier - TElement of the outer comparer doesn't always match TElement of the element comparer (presumably to allow support of nested array types in the future), we we can't blindly copy over Equals/Hashcode/Snapshot method lambdas in case types are not matching. For now we check if the both type arguments match - if they do, we copy the funcs, otherwise we fallback to the old behavior, i.e. construct new element comparer so that it can deal with type argument discrepancies.

Fixes #35239
  • Loading branch information
maumar committed Dec 14, 2024
1 parent c099cef commit 9b4fbbf
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 93 deletions.
179 changes: 158 additions & 21 deletions src/EFCore.Cosmos/ChangeTracking/Internal/StringDictionaryComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,21 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.ChangeTracking.Internal;
public sealed class StringDictionaryComparer<TDictionary, TElement> : ValueComparer<object>, IInfrastructure<ValueComparer>
{
private static readonly MethodInfo CompareMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(Func<TElement, TElement, bool>)])!;

private static readonly MethodInfo LegacyCompareMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(object), typeof(ValueComparer)])!;

private static readonly MethodInfo GetHashCodeMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(Func<TElement, int>)])!;

private static readonly MethodInfo LegacyGetHashCodeMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic, [typeof(IEnumerable), typeof(ValueComparer)])!;

private static readonly MethodInfo SnapshotMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(Func<TElement, TElement>)])!;

private static readonly MethodInfo LegacySnapshotMethod = typeof(StringDictionaryComparer<TDictionary, TElement>).GetMethod(
nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic, [typeof(object), typeof(ValueComparer)])!;

/// <summary>
Expand Down Expand Up @@ -52,46 +61,134 @@ ValueComparer IInfrastructure<ValueComparer>.Instance
var prm1 = Expression.Parameter(typeof(object), "a");
var prm2 = Expression.Parameter(typeof(object), "b");

return Expression.Lambda<Func<object?, object?, bool>>(
Expression.Call(
CompareMethod,
if (elementComparer is ValueComparer<TElement>)
{
// (a, b) => Compare(a, b, elementComparer.Equals)
return Expression.Lambda<Func<object?, object?, bool>>(
Expression.Call(
CompareMethod,
prm1,
prm2,
elementComparer.EqualsExpression),
prm1,
prm2,
prm2);
}
else
{
// (a, b) => Compare(a, b, new Comparer(...))
return Expression.Lambda<Func<object?, object?, bool>>(
Expression.Call(
LegacyCompareMethod,
prm1,
prm2,
#pragma warning disable EF9100
elementComparer.ConstructorExpression),
elementComparer.ConstructorExpression),
#pragma warning restore EF9100
prm1,
prm2);
prm1,
prm2);
}
}

private static Expression<Func<object, int>> GetHashCodeLambda(ValueComparer elementComparer)
{
var prm = Expression.Parameter(typeof(object), "o");

return Expression.Lambda<Func<object, int>>(
Expression.Call(
GetHashCodeMethod,
Expression.Convert(
prm,
typeof(IEnumerable)),
if (elementComparer is ValueComparer<TElement>)
{
// o => GetHashCode((IEnumerable)o, elementComparer.GetHashCode)
return Expression.Lambda<Func<object, int>>(
Expression.Call(
GetHashCodeMethod,
Expression.Convert(
prm,
typeof(IEnumerable)),
elementComparer.HashCodeExpression),
prm);
}
else
{
// o => GetHashCode((IEnumerable)o, new Comparer(...))
return Expression.Lambda<Func<object, int>>(
Expression.Call(
LegacyGetHashCodeMethod,
Expression.Convert(
prm,
typeof(IEnumerable)),
#pragma warning disable EF9100
elementComparer.ConstructorExpression),
elementComparer.ConstructorExpression),
#pragma warning restore EF9100
prm);
prm);
}
}

private static Expression<Func<object, object>> SnapshotLambda(ValueComparer elementComparer)
{
var prm = Expression.Parameter(typeof(object), "source");

return Expression.Lambda<Func<object, object>>(
Expression.Call(
SnapshotMethod,
prm,
if (elementComparer is ValueComparer<TElement>)
{
// source => Snapshot(source, elementComparer.Snapshot)
return Expression.Lambda<Func<object, object>>(
Expression.Call(
SnapshotMethod,
prm,
elementComparer.SnapshotExpression),
prm);
}
else
{
// source => Snapshot(source, new Comparer(..))
return Expression.Lambda<Func<object, object>>(
Expression.Call(
LegacySnapshotMethod,
prm,
#pragma warning disable EF9100
elementComparer.ConstructorExpression),
elementComparer.ConstructorExpression),
#pragma warning restore EF9100
prm);
prm);
}
}

private static bool Compare(object? a, object? b, Func<TElement?, TElement?, bool> elementCompare)
{
if (ReferenceEquals(a, b))
{
return true;
}

if (a is null)
{
return b is null;
}

if (b is null)
{
return false;
}

if (a is IReadOnlyDictionary<string, TElement?> aDictionary && b is IReadOnlyDictionary<string, TElement?> bDictionary)
{
if (aDictionary.Count != bDictionary.Count)
{
return false;
}

foreach (var pair in aDictionary)
{
if (!bDictionary.TryGetValue(pair.Key, out var bValue)
|| !elementCompare(pair.Value, bValue))
{
return false;
}
}

return true;
}

throw new InvalidOperationException(
CosmosStrings.BadDictionaryType(
(a is IDictionary<string, TElement?> ? b : a).GetType().ShortDisplayName(),
typeof(IDictionary<,>).MakeGenericType(typeof(string), typeof(TElement)).ShortDisplayName()));
}

private static bool Compare(object? a, object? b, ValueComparer elementComparer)
Expand Down Expand Up @@ -136,6 +233,27 @@ private static bool Compare(object? a, object? b, ValueComparer elementComparer)
typeof(IDictionary<,>).MakeGenericType(typeof(string), elementComparer.Type).ShortDisplayName()));
}

private static int GetHashCode(IEnumerable source, Func<TElement?, int> elementGetHashCode)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
{
throw new InvalidOperationException(
CosmosStrings.BadDictionaryType(
source.GetType().ShortDisplayName(),
typeof(IList<>).MakeGenericType(typeof(TElement)).ShortDisplayName()));
}

var hash = new HashCode();

foreach (var pair in sourceDictionary)
{
hash.Add(pair.Key);
hash.Add(pair.Value == null ? 0 : elementGetHashCode(pair.Value));
}

return hash.ToHashCode();
}

private static int GetHashCode(IEnumerable source, ValueComparer elementComparer)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
Expand All @@ -157,6 +275,25 @@ private static int GetHashCode(IEnumerable source, ValueComparer elementComparer
return hash.ToHashCode();
}

private static IReadOnlyDictionary<string, TElement?> Snapshot(object source, Func<TElement?, TElement?> elementSnapshot)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
{
throw new InvalidOperationException(
CosmosStrings.BadDictionaryType(
source.GetType().ShortDisplayName(),
typeof(IDictionary<,>).MakeGenericType(typeof(string), typeof(TElement)).ShortDisplayName()));
}

var snapshot = new Dictionary<string, TElement?>();
foreach (var pair in sourceDictionary)
{
snapshot[pair.Key] = pair.Value == null ? default : (TElement?)elementSnapshot(pair.Value);
}

return snapshot;
}

private static IReadOnlyDictionary<string, TElement?> Snapshot(object source, ValueComparer elementComparer)
{
if (source is not IReadOnlyDictionary<string, TElement?> sourceDictionary)
Expand Down
42 changes: 18 additions & 24 deletions src/EFCore/ChangeTracking/ListOfNullableValueTypesComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ public sealed class ListOfNullableValueTypesComparer<TConcreteList, TElement> :

private static readonly MethodInfo CompareMethod = typeof(ListOfNullableValueTypesComparer<TConcreteList, TElement>).GetMethod(
nameof(Compare), BindingFlags.Static | BindingFlags.NonPublic,
[typeof(IEnumerable<TElement?>), typeof(IEnumerable<TElement?>), typeof(ValueComparer<TElement?>)])!;
[typeof(IEnumerable<TElement?>), typeof(IEnumerable<TElement?>), typeof(Func<TElement?, TElement?, bool>)])!;

private static readonly MethodInfo GetHashCodeMethod = typeof(ListOfNullableValueTypesComparer<TConcreteList, TElement>).GetMethod(
nameof(GetHashCode), BindingFlags.Static | BindingFlags.NonPublic,
[typeof(IEnumerable<TElement?>), typeof(ValueComparer<TElement?>)])!;
[typeof(IEnumerable<TElement?>), typeof(Func<TElement?, int>)])!;

private static readonly MethodInfo SnapshotMethod = typeof(ListOfNullableValueTypesComparer<TConcreteList, TElement>).GetMethod(
nameof(Snapshot), BindingFlags.Static | BindingFlags.NonPublic,
[typeof(IEnumerable<TElement?>), typeof(ValueComparer<TElement?>)])!;
[typeof(IEnumerable<TElement?>), typeof(Func<TElement?, TElement?>)])!;

/// <summary>
/// Creates a new instance of the list comparer.
Expand All @@ -67,15 +67,13 @@ ValueComparer IInfrastructure<ValueComparer>.Instance
var prm1 = Expression.Parameter(typeof(IEnumerable<TElement?>), "a");
var prm2 = Expression.Parameter(typeof(IEnumerable<TElement?>), "b");

//(a, b) => Compare(a, b, (ValueComparer<TElement?>)elementComparer)
//(a, b) => Compare(a, b, elementComparer.Equals)
return Expression.Lambda<Func<IEnumerable<TElement?>?, IEnumerable<TElement?>?, bool>>(
Expression.Call(
CompareMethod,
prm1,
prm2,
Expression.Convert(
elementComparer.ConstructorExpression,
typeof(ValueComparer<TElement?>))),
elementComparer.EqualsExpression),
prm1,
prm2);
}
Expand All @@ -84,33 +82,29 @@ ValueComparer IInfrastructure<ValueComparer>.Instance
{
var prm = Expression.Parameter(typeof(IEnumerable<TElement?>), "o");

//o => GetHashCode(o, (ValueComparer<TElement?>)elementComparer)
//o => GetHashCode(o, elementComparer.GetHashCode)
return Expression.Lambda<Func<IEnumerable<TElement?>, int>>(
Expression.Call(
GetHashCodeMethod,
prm,
Expression.Convert(
elementComparer.ConstructorExpression,
typeof(ValueComparer<TElement?>))),
elementComparer.HashCodeExpression),
prm);
}

private static Expression<Func<IEnumerable<TElement?>, IEnumerable<TElement?>>> SnapshotLambda(ValueComparer elementComparer)
{
var prm = Expression.Parameter(typeof(IEnumerable<TElement?>), "source");

//source => Snapshot(source, (ValueComparer<TElement?>)elementComparer)
//source => Snapshot(source, elementComparer.Snapshot)
return Expression.Lambda<Func<IEnumerable<TElement?>, IEnumerable<TElement?>>>(
Expression.Call(
SnapshotMethod,
prm,
Expression.Convert(
elementComparer.ConstructorExpression,
typeof(ValueComparer<TElement?>))),
elementComparer.SnapshotExpression),
prm);
}

private static bool Compare(IEnumerable<TElement?>? a, IEnumerable<TElement?>? b, ValueComparer<TElement?> elementComparer)
private static bool Compare(IEnumerable<TElement?>? a, IEnumerable<TElement?>? b, Func<TElement?, TElement?, bool> elementCompare)
{
if (ReferenceEquals(a, b))
{
Expand Down Expand Up @@ -152,7 +146,7 @@ private static bool Compare(IEnumerable<TElement?>? a, IEnumerable<TElement?>? b
return false;
}

if (!elementComparer.Equals(el1, el2))
if (!elementCompare(el1, el2))
{
return false;
}
Expand All @@ -164,29 +158,29 @@ private static bool Compare(IEnumerable<TElement?>? a, IEnumerable<TElement?>? b
throw new InvalidOperationException(
CoreStrings.BadListType(
(a is IList<TElement?> ? b : a).GetType().ShortDisplayName(),
typeof(IList<>).MakeGenericType(elementComparer.Type.MakeNullable()).ShortDisplayName()));
typeof(IList<>).MakeGenericType(typeof(TElement).MakeNullable()).ShortDisplayName()));
}

private static int GetHashCode(IEnumerable<TElement?> source, ValueComparer<TElement?> elementComparer)
private static int GetHashCode(IEnumerable<TElement?> source, Func<TElement?, int> elementGetHashCode)
{
var hash = new HashCode();

foreach (var el in source)
{
hash.Add(el == null ? 0 : elementComparer.GetHashCode(el));
hash.Add(el == null ? 0 : elementGetHashCode(el));
}

return hash.ToHashCode();
}

private static IList<TElement?> Snapshot(IEnumerable<TElement?> source, ValueComparer<TElement?> elementComparer)
private static IList<TElement?> Snapshot(IEnumerable<TElement?> source, Func<TElement?, TElement?> elementSnapshot)
{
if (source is not IList<TElement?> sourceList)
{
throw new InvalidOperationException(
CoreStrings.BadListType(
source.GetType().ShortDisplayName(),
typeof(IList<>).MakeGenericType(elementComparer.Type.MakeNullable()).ShortDisplayName()));
typeof(IList<>).MakeGenericType(typeof(TElement).MakeNullable()).ShortDisplayName()));
}

if (IsArray)
Expand All @@ -195,7 +189,7 @@ private static int GetHashCode(IEnumerable<TElement?> source, ValueComparer<TEle
for (var i = 0; i < sourceList.Count; i++)
{
var instance = sourceList[i];
snapshot[i] = instance == null ? null : elementComparer.Snapshot(instance);
snapshot[i] = instance == null ? null : elementSnapshot(instance);
}

return snapshot;
Expand All @@ -205,7 +199,7 @@ private static int GetHashCode(IEnumerable<TElement?> source, ValueComparer<TEle
var snapshot = IsReadOnly ? new List<TElement?>() : (IList<TElement?>)Activator.CreateInstance<TConcreteList>()!;
foreach (var e in sourceList)
{
snapshot.Add(e == null ? null : elementComparer.Snapshot(e));
snapshot.Add(e == null ? null : elementSnapshot(e));
}

return IsReadOnly
Expand Down
Loading

0 comments on commit 9b4fbbf

Please sign in to comment.