diff --git a/src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs index 798119e3afe..e067500505e 100644 --- a/src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs @@ -23,6 +23,13 @@ namespace Microsoft.EntityFrameworkCore.Query public abstract class AsyncQueryTestBase : IClassFixture where TFixture : NorthwindQueryFixtureBase, new() { + [ConditionalFact] + public virtual async Task Projection_when_client_evald_subquery() + { + await AssertQuery( + cs => cs.Select(c => string.Join(", ", c.Orders.Select(o => o.CustomerID)))); + } + [ConditionalFact] public virtual async Task ToArray_on_nav_subquery_in_projection() { diff --git a/src/EFCore.Specification.Tests/Query/QueryTestBase.Select.cs b/src/EFCore.Specification.Tests/Query/QueryTestBase.Select.cs index c8bef993bac..d68a0030b19 100644 --- a/src/EFCore.Specification.Tests/Query/QueryTestBase.Select.cs +++ b/src/EFCore.Specification.Tests/Query/QueryTestBase.Select.cs @@ -41,6 +41,13 @@ public virtual void Projection_when_null_value() cs => cs.Select(c => c.Region)); } + [ConditionalFact] + public virtual void Projection_when_client_evald_subquery() + { + AssertQuery( + cs => cs.Select(c => string.Join(", ", c.Orders.Select(o => o.CustomerID)))); + } + [ConditionalFact] public virtual void Project_to_object_array() { diff --git a/src/EFCore/Query/ExpressionVisitors/Internal/CollectionNavigationSetOperatorSubqueryInjector.cs b/src/EFCore/Query/ExpressionVisitors/Internal/CollectionNavigationSetOperatorSubqueryInjector.cs index df5bad41d6e..2a98bda92ec 100644 --- a/src/EFCore/Query/ExpressionVisitors/Internal/CollectionNavigationSetOperatorSubqueryInjector.cs +++ b/src/EFCore/Query/ExpressionVisitors/Internal/CollectionNavigationSetOperatorSubqueryInjector.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Linq; using System.Linq.Expressions; using JetBrains.Annotations; @@ -30,7 +31,22 @@ public CollectionNavigationSetOperatorSubqueryInjector([NotNull] EntityQueryMode /// protected override Expression VisitSubQuery(SubQueryExpression expression) { - expression.QueryModel.TransformExpressions(Visit); + bool shouldInject; + + if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue9570", out var isEnabled) + && isEnabled) + { + expression.QueryModel.TransformExpressions(Visit); + } + else + { + shouldInject = ShouldInject; + ShouldInject = false; + + expression.QueryModel.TransformExpressions(Visit); + + ShouldInject = shouldInject; + } foreach (var resultOperator in expression.QueryModel.ResultOperators.Where( ro => ro is ConcatResultOperator @@ -38,7 +54,7 @@ protected override Expression VisitSubQuery(SubQueryExpression expression) || ro is IntersectResultOperator || ro is ExceptResultOperator)) { - var shouldInject = ShouldInject; + shouldInject = ShouldInject; ShouldInject = true; resultOperator.TransformExpressions(Visit); diff --git a/src/EFCore/Query/ExpressionVisitors/ProjectionExpressionVisitor.cs b/src/EFCore/Query/ExpressionVisitors/ProjectionExpressionVisitor.cs index 1d4ebf3ec8c..03da4e7aac9 100644 --- a/src/EFCore/Query/ExpressionVisitors/ProjectionExpressionVisitor.cs +++ b/src/EFCore/Query/ExpressionVisitors/ProjectionExpressionVisitor.cs @@ -42,13 +42,14 @@ public ProjectionExpressionVisitor([NotNull] EntityQueryModelVisitor entityQuery protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression) { if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue9128", out var isEnabled) + || AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue9570", out isEnabled) && isEnabled) { return base.VisitMethodCall(methodCallExpression); } - + var newExpression = base.VisitMethodCall(methodCallExpression); - + switch (newExpression) { case MethodCallExpression newMethodCallExpression @@ -82,8 +83,8 @@ when newMethodCallExpression2.Method && innerMethodCallExpression2.Arguments[0].Type.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>) ? (Expression) Expression.Property( ResultOperatorHandler.CallWithPossibleCancellationToken( - MaterializeCollectionNavigationAsyncMethodInfo.MakeGenericMethod( - newMethodCallExpression2.Method.GetGenericArguments()[0]), + _materializeCollectionNavigationAsyncMethodInfo + .MakeGenericMethod(newMethodCallExpression2.Method.GetGenericArguments()[0]), newMethodCallExpression2.Arguments[0], innerMethodCallExpression2.Arguments[0]), nameof(Task.Result)) @@ -92,23 +93,57 @@ when newMethodCallExpression2.Method newMethodCallExpression2.Arguments[0], innerMethodCallExpression2.Arguments[0]); } + case MethodCallExpression newMethodCallExpression3 + when newMethodCallExpression3.Method.Equals(methodCallExpression.Method) + && newMethodCallExpression3.Arguments.Count > 0: + { + // Transforms a sync sequence argument to an async buffered version + // We call .Result here so that the types still line up (we remove the .Result in TaskLiftingExpressionVisitor). + + var newArguments = newMethodCallExpression3.Arguments.ToList(); + + for (var i = 0; i < newArguments.Count; i++) + { + var argument = newArguments[i]; + + if (argument is MethodCallExpression argumentMethodCallExpression + && argumentMethodCallExpression.Method.MethodIsClosedFormOf(QueryModelVisitor.LinqOperatorProvider.ToEnumerable)) + { + newArguments[i] + = Expression.Property( + ResultOperatorHandler.CallWithPossibleCancellationToken( + _toArrayAsync.MakeGenericMethod( + argumentMethodCallExpression.Method.GetGenericArguments()), + argumentMethodCallExpression.Arguments.ToArray()), + nameof(Task.Result)); + } + } + + return newMethodCallExpression3.Update(newMethodCallExpression3.Object, newArguments); + } } return newExpression; } - private static readonly MethodInfo MaterializeCollectionNavigationAsyncMethodInfo + private static readonly MethodInfo _toArrayAsync + = typeof(AsyncEnumerable).GetTypeInfo() + .GetDeclaredMethods(nameof(AsyncEnumerable.ToArray)) + .Single(mi => mi.GetParameters().Length == 2); + + private static readonly MethodInfo _materializeCollectionNavigationAsyncMethodInfo = typeof(ProjectionExpressionVisitor).GetTypeInfo() .GetDeclaredMethod(nameof(MaterializeCollectionNavigationAsync)); [UsedImplicitly] private static async Task> MaterializeCollectionNavigationAsync( INavigation navigation, - IAsyncEnumerable elements) + IAsyncEnumerable elements, + CancellationToken cancellationToken) { var collection = (ICollection) navigation.GetCollectionAccessor().Create(); - await elements.ForEachAsync(e => collection.Add((TEntity) e)); + await elements.ForEachAsync(e => collection.Add((TEntity) e), cancellationToken); return collection; }