Skip to content

Commit

Permalink
Visit arguments in QueryableMethodNormalizingExpressionVisitor after …
Browse files Browse the repository at this point in the history
…converting List.Contains (dotnet#32219)

Fixes dotnet#32215
Fixes dotnet#32218

(cherry picked from commit 08ee676)
  • Loading branch information
roji committed Nov 9, 2023
1 parent d7cd6f3 commit 5eee914
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMe
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly bool _subquery;

private static readonly bool UseOldBehavior32218 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue32218", out var enabled32218) && enabled32218;

/// <summary>
/// Creates a new instance of the <see cref="QueryableMethodTranslatingExpressionVisitor" /> class.
/// </summary>
Expand Down Expand Up @@ -288,7 +291,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// Server), we need to fall back to the previous IN translation.
if (method.IsGenericMethod
&& method.GetGenericMethodDefinition() == QueryableMethods.Contains
&& methodCallExpression.Arguments[0] is ParameterQueryRootExpression parameterSource
&& (UseOldBehavior32218
? methodCallExpression.Arguments[0]
: UnwrapAsQueryable(methodCallExpression.Arguments[0])) is ParameterQueryRootExpression parameterSource
&& TranslateExpression(methodCallExpression.Arguments[1]) is SqlExpression item
&& _sqlTranslator.Visit(parameterSource.ParameterExpression) is SqlParameterExpression sqlParameterExpression)
{
Expand All @@ -300,6 +305,12 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
.UpdateResultCardinality(ResultCardinality.Single);
return shapedQueryExpression;
}

static Expression UnwrapAsQueryable(Expression expression)
=> expression is MethodCallExpression { Method: { IsGenericMethod: true } method } methodCall
&& method.GetGenericMethodDefinition() == QueryableMethods.AsQueryable
? methodCall.Arguments[0]
: expression;
}

return translated;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ public class QueryableMethodNormalizingExpressionVisitor : ExpressionVisitor
private readonly SelectManyVerifyingExpressionVisitor _selectManyVerifyingExpressionVisitor = new();
private readonly GroupJoinConvertingExpressionVisitor _groupJoinConvertingExpressionVisitor = new();

private static readonly bool UseOldBehavior32215 =
AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue32215", out var enabled32215) && enabled32215;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
Expand Down Expand Up @@ -435,12 +438,14 @@ private Expression TryConvertListContainsToQueryableContains(MethodCallExpressio

var sourceType = methodCallExpression.Method.DeclaringType!.GetGenericArguments()[0];

return Expression.Call(
var converted = Expression.Call(
QueryableMethods.Contains.MakeGenericMethod(sourceType),
Expression.Call(
QueryableMethods.AsQueryable.MakeGenericMethod(sourceType),
methodCallExpression.Object!),
methodCallExpression.Arguments[0]);

return UseOldBehavior32215 ? converted : VisitMethodCall(converted);
}

private static bool CanConvertEnumerableToQueryable(Type enumerableType, Type queryableType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,34 @@ public virtual Task Project_primitive_collections_element(bool async)
},
assertOrder: true);

[ConditionalTheory] // #32208, #32215
[MemberData(nameof(IsAsyncData))]
public virtual Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool async)
{
var ints = new List<int> { 1, 2, 3 };
var strings = new List<string> { "one", "two", "three" };

// Note that in this query, the outer Contains really has no type mapping, neither for its source (collection parameter), nor
// for its item (the conditional expression returns constants). The default type mapping must be applied.
return AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(e => strings.Contains(ints.Contains(e.Int) ? "one" : "two")));
}

[ConditionalTheory] // #32208, #32215
[MemberData(nameof(IsAsyncData))]
public virtual Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
var ints = new[] { 1, 2, 3 };
var strings = new[] { "one", "two", "three" };

// Note that in this query, the outer Contains really has no type mapping, neither for its source (collection parameter), nor
// for its item (the conditional expression returns constants). The default type mapping must be applied.
return AssertQuery(
async,
ss => ss.Set<PrimitiveCollectionsEntity>().Where(e => strings.Contains(ints.Contains(e.Int) ? "one" : "two")));
}

public abstract class PrimitiveCollectionsQueryFixtureBase : SharedStoreFixtureBase<PrimitiveCollectionsContext>, IQueryFixtureBase
{
private PrimitiveArrayData? _expectedData;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,36 @@ ORDER BY [p].[Id]
""");
}

public override async Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_Lists_and_no_inferred_type_mapping(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE CASE
WHEN [p].[Int] IN (1, 2, 3) THEN N'one'
ELSE N'two'
END IN (N'one', N'two', N'three')
""");
}

public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);

AssertSql(
"""
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE CASE
WHEN [p].[Int] IN (1, 2, 3) THEN N'one'
ELSE N'two'
END IN (N'one', N'two', N'three')
""");
}

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1227,6 +1227,54 @@ ORDER BY [p].[Id]
""");
}

public override async Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_Lists_and_no_inferred_type_mapping(async);

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 4000)
@__strings_0='["one","two","three"]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE CASE
WHEN [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_1) WITH ([value] int '$') AS [i]
) THEN N'one'
ELSE N'two'
END IN (
SELECT [s].[value]
FROM OPENJSON(@__strings_0) WITH ([value] nvarchar(max) '$') AS [s]
)
""");
}

public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 4000)
@__strings_0='["one","two","three"]' (Size = 4000)
SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings]
FROM [PrimitiveCollectionsEntity] AS [p]
WHERE CASE
WHEN [p].[Int] IN (
SELECT [i].[value]
FROM OPENJSON(@__ints_1) WITH ([value] int '$') AS [i]
) THEN N'one'
ELSE N'two'
END IN (
SELECT [s].[value]
FROM OPENJSON(@__strings_0) WITH ([value] nvarchar(max) '$') AS [s]
)
""");
}

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down
6 changes: 5 additions & 1 deletion test/EFCore.SqlServer.FunctionalTests/Query/QueryBugsTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3982,13 +3982,17 @@ public virtual async Task Nested_contains_with_enum()

AssertSql(
"""
@__todoTypes_1='[0]' (Size = 4000)
@__key_2='5f221fb9-66f4-442a-92c9-d97ed5989cc7'
@__keys_0='["0a47bcb7-a1cb-4345-8944-c58f82d6aac7","5f221fb9-66f4-442a-92c9-d97ed5989cc7"]' (Size = 4000)
SELECT [t].[Id], [t].[Type]
FROM [Todos] AS [t]
WHERE CASE
WHEN [t].[Type] = 0 THEN @__key_2
WHEN [t].[Type] IN (
SELECT [t0].[value]
FROM OPENJSON(@__todoTypes_1) WITH ([value] int '$') AS [t0]
) THEN @__key_2
ELSE @__key_2
END IN (
SELECT [k].[value]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,54 @@ public override async Task Project_empty_collection_of_nullables_and_collection_
(await Assert.ThrowsAsync<InvalidOperationException>(
() => base.Project_empty_collection_of_nullables_and_collection_only_containing_nulls(async))).Message);

public override async Task Nested_contains_with_Lists_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_Lists_and_no_inferred_type_mapping(async);

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 7)
@__strings_0='["one","two","three"]' (Size = 21)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE CASE
WHEN "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@__ints_1) AS "i"
) THEN 'one'
ELSE 'two'
END IN (
SELECT "s"."value"
FROM json_each(@__strings_0) AS "s"
)
""");
}

public override async Task Nested_contains_with_arrays_and_no_inferred_type_mapping(bool async)
{
await base.Nested_contains_with_arrays_and_no_inferred_type_mapping(async);

AssertSql(
"""
@__ints_1='[1,2,3]' (Size = 7)
@__strings_0='["one","two","three"]' (Size = 21)
SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings"
FROM "PrimitiveCollectionsEntity" AS "p"
WHERE CASE
WHEN "p"."Int" IN (
SELECT "i"."value"
FROM json_each(@__ints_1) AS "i"
) THEN 'one'
ELSE 'two'
END IN (
SELECT "s"."value"
FROM json_each(@__strings_0) AS "s"
)
""");
}

[ConditionalFact]
public virtual void Check_all_tests_overridden()
=> TestHelpers.AssertAllMethodsOverridden(GetType());
Expand Down

0 comments on commit 5eee914

Please sign in to comment.