Skip to content

Commit

Permalink
Translate array.Where(i => i != x) to array_remove (#3328)
Browse files Browse the repository at this point in the history
Closes #3078
  • Loading branch information
roji authored Oct 26, 2024
1 parent a8677e4 commit 1e3b9b2
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,61 @@ [new PgUnnestExpression(tableAlias, sliceExpression, "value")],
return base.TranslateTake(source, count);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
protected override ShapedQueryExpression? TranslateWhere(ShapedQueryExpression source, LambdaExpression predicate)
{
// Simplify x.Array.Where(i => i != 3) => array_remove(x.Array, 3) instead of subquery
if (predicate.Body is BinaryExpression
{
NodeType: ExpressionType.NotEqual,
Left: var left,
Right: var right
}
&& (left == predicate.Parameters[0] ? right : right == predicate.Parameters[0] ? left : null) is Expression itemToFilterOut
&& source.TryExtractArray(out var array, out var projectedColumn)
&& TranslateExpression(itemToFilterOut) is SqlExpression translatedItemToFilterOut)
{
var simplifiedTranslation = _sqlExpressionFactory.Function(
"array_remove",
[array, translatedItemToFilterOut],
nullable: true,
argumentsPropagateNullability: TrueArrays[2],
array.Type,
array.TypeMapping);

#pragma warning disable EF1001 // SelectExpression constructors are currently internal
var tableAlias = ((SelectExpression)source.QueryExpression).Tables[0].Alias!;
var selectExpression = new SelectExpression(
[new PgUnnestExpression(tableAlias, simplifiedTranslation, "value")],
new ColumnExpression("value", tableAlias, projectedColumn.Type, projectedColumn.TypeMapping, projectedColumn.IsNullable),
[GenerateOrdinalityIdentifier(tableAlias)],
_queryCompilationContext.SqlAliasManager);
#pragma warning restore EF1001 // Internal EF Core API usage.

// TODO: Simplify by using UpdateQueryExpression after https://github.com/dotnet/efcore/issues/31511
Expression shaperExpression = new ProjectionBindingExpression(
selectExpression, new ProjectionMember(), source.ShaperExpression.Type.MakeNullable());

if (source.ShaperExpression.Type != shaperExpression.Type)
{
Check.DebugAssert(
source.ShaperExpression.Type.MakeNullable() == shaperExpression.Type,
"expression.Type must be nullable of targetType");

shaperExpression = Expression.Convert(shaperExpression, source.ShaperExpression.Type);
}

return new ShapedQueryExpression(selectExpression, shaperExpression);
}

return base.TranslateWhere(source, predicate);
}

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,24 @@ await Assert.ThrowsAsync<InvalidOperationException>(

public override async Task Contains_with_local_enumerable_inline_closure_mix(bool async)
{
// Issue #31776
await Assert.ThrowsAsync<InvalidOperationException>(
async () =>
await base.Contains_with_local_enumerable_inline_closure_mix(async));
await base.Contains_with_local_enumerable_inline_closure_mix(async);

AssertSql();
AssertSql(
"""
@__p_0={ 'ABCDE', 'ALFKI' } (DbType = Object)
SELECT c."CustomerID", c."Address", c."City", c."CompanyName", c."ContactName", c."ContactTitle", c."Country", c."Fax", c."Phone", c."PostalCode", c."Region"
FROM "Customers" AS c
WHERE c."CustomerID" = ANY (array_remove(@__p_0, NULL))
""",
//
"""
@__p_0={ 'ABCDE', 'ANATR' } (DbType = Object)
SELECT c."CustomerID", c."Address", c."City", c."CompanyName", c."ContactName", c."ContactTitle", c."Country", c."Fax", c."Phone", c."PostalCode", c."Region"
FROM "Customers" AS c
WHERE c."CustomerID" = ANY (array_remove(@__p_0, NULL))
""");
}

public override async Task Contains_with_local_non_primitive_list_closure_mix(bool async)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2027,6 +2027,23 @@ WHERE CASE
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual async Task Array_remove(bool async)
{
await AssertQuery(
async,
// ReSharper disable once ReplaceWithSingleCallToCount
ss => ss.Set<PrimitiveCollectionsEntity>().Where(e => e.Ints.Where(i => i != 1).Count() == 1));

AssertSql(
"""
SELECT p."Id", p."Bool", p."Bools", p."DateTime", p."DateTimes", p."Enum", p."Enums", p."Int", p."Ints", p."NullableInt", p."NullableInts", p."NullableString", p."NullableStrings", p."String", p."Strings"
FROM "PrimitiveCollectionsEntity" AS p
WHERE cardinality(array_remove(p."Ints", 1)) = 1
""");
}

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down

0 comments on commit 1e3b9b2

Please sign in to comment.