Skip to content

Commit

Permalink
[2.0.1] - Fix #9570 - Exception in Client vs. Server Evaluation with …
Browse files Browse the repository at this point in the history
…async/await in EF Core 2.0.0

Two issues:

1) CollectionNavigationSetOperatorSubqueryInjector would incorrectly introduce MaterializeCollectionNavigation calls around
subqueries. Fix is to reset the ShouldInject flag when visiting subqueries.
2) After fixing 1, we would deadlock due to blocking on the subquery call. Extended task lifting to deal with this case.
  • Loading branch information
anpete committed Sep 15, 2017
1 parent 119d30b commit 2accafb
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 9 deletions.
7 changes: 7 additions & 0 deletions src/EFCore.Specification.Tests/Query/AsyncQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ namespace Microsoft.EntityFrameworkCore.Query
public abstract class AsyncQueryTestBase<TFixture> : IClassFixture<TFixture>
where TFixture : NorthwindQueryFixtureBase, new()
{
[ConditionalFact]
public virtual async Task Projection_when_client_evald_subquery()
{
await AssertQuery<Customer>(
cs => cs.Select(c => string.Join(", ", c.Orders.Select(o => o.CustomerID))));
}

[ConditionalFact]
public virtual async Task ToArray_on_nav_subquery_in_projection()
{
Expand Down
7 changes: 7 additions & 0 deletions src/EFCore.Specification.Tests/Query/QueryTestBase.Select.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ public virtual void Projection_when_null_value()
cs => cs.Select(c => c.Region));
}

[ConditionalFact]
public virtual void Projection_when_client_evald_subquery()
{
AssertQuery<Customer>(
cs => cs.Select(c => string.Join(", ", c.Orders.Select(o => o.CustomerID))));
}

[ConditionalFact]
public virtual void Project_to_object_array()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Linq.Expressions;
using JetBrains.Annotations;
Expand Down Expand Up @@ -30,15 +31,30 @@ public CollectionNavigationSetOperatorSubqueryInjector([NotNull] EntityQueryMode
/// </summary>
protected override Expression VisitSubQuery(SubQueryExpression expression)
{
expression.QueryModel.TransformExpressions(Visit);
bool shouldInject;

if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue9570", out var isEnabled)
&& isEnabled)
{
expression.QueryModel.TransformExpressions(Visit);
}
else
{
shouldInject = ShouldInject;
ShouldInject = false;

expression.QueryModel.TransformExpressions(Visit);

ShouldInject = shouldInject;
}

foreach (var resultOperator in expression.QueryModel.ResultOperators.Where(
ro => ro is ConcatResultOperator
|| ro is UnionResultOperator
|| ro is IntersectResultOperator
|| ro is ExceptResultOperator))
{
var shouldInject = ShouldInject;
shouldInject = ShouldInject;
ShouldInject = true;

resultOperator.TransformExpressions(Visit);
Expand Down
49 changes: 42 additions & 7 deletions src/EFCore/Query/ExpressionVisitors/ProjectionExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ public ProjectionExpressionVisitor([NotNull] EntityQueryModelVisitor entityQuery
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue9128", out var isEnabled)
|| AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue9570", out isEnabled)
&& isEnabled)
{
return base.VisitMethodCall(methodCallExpression);
}

var newExpression = base.VisitMethodCall(methodCallExpression);

switch (newExpression)
{
case MethodCallExpression newMethodCallExpression
Expand Down Expand Up @@ -82,8 +83,8 @@ when newMethodCallExpression2.Method
&& innerMethodCallExpression2.Arguments[0].Type.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>)
? (Expression) Expression.Property(
ResultOperatorHandler.CallWithPossibleCancellationToken(
MaterializeCollectionNavigationAsyncMethodInfo.MakeGenericMethod(
newMethodCallExpression2.Method.GetGenericArguments()[0]),
_materializeCollectionNavigationAsyncMethodInfo
.MakeGenericMethod(newMethodCallExpression2.Method.GetGenericArguments()[0]),
newMethodCallExpression2.Arguments[0],
innerMethodCallExpression2.Arguments[0]),
nameof(Task<object>.Result))
Expand All @@ -92,23 +93,57 @@ when newMethodCallExpression2.Method
newMethodCallExpression2.Arguments[0],
innerMethodCallExpression2.Arguments[0]);
}
case MethodCallExpression newMethodCallExpression3
when newMethodCallExpression3.Method.Equals(methodCallExpression.Method)
&& newMethodCallExpression3.Arguments.Count > 0:
{
// Transforms a sync sequence argument to an async buffered version
// We call .Result here so that the types still line up (we remove the .Result in TaskLiftingExpressionVisitor).

var newArguments = newMethodCallExpression3.Arguments.ToList();

for (var i = 0; i < newArguments.Count; i++)
{
var argument = newArguments[i];

if (argument is MethodCallExpression argumentMethodCallExpression
&& argumentMethodCallExpression.Method.MethodIsClosedFormOf(QueryModelVisitor.LinqOperatorProvider.ToEnumerable))
{
newArguments[i]
= Expression.Property(
ResultOperatorHandler.CallWithPossibleCancellationToken(
_toArrayAsync.MakeGenericMethod(
argumentMethodCallExpression.Method.GetGenericArguments()),
argumentMethodCallExpression.Arguments.ToArray()),
nameof(Task<object>.Result));
}
}

return newMethodCallExpression3.Update(newMethodCallExpression3.Object, newArguments);
}
}

return newExpression;
}

private static readonly MethodInfo MaterializeCollectionNavigationAsyncMethodInfo
private static readonly MethodInfo _toArrayAsync
= typeof(AsyncEnumerable).GetTypeInfo()
.GetDeclaredMethods(nameof(AsyncEnumerable.ToArray))
.Single(mi => mi.GetParameters().Length == 2);

private static readonly MethodInfo _materializeCollectionNavigationAsyncMethodInfo
= typeof(ProjectionExpressionVisitor).GetTypeInfo()
.GetDeclaredMethod(nameof(MaterializeCollectionNavigationAsync));

[UsedImplicitly]
private static async Task<ICollection<TEntity>> MaterializeCollectionNavigationAsync<TEntity>(
INavigation navigation,
IAsyncEnumerable<object> elements)
IAsyncEnumerable<object> elements,
CancellationToken cancellationToken)
{
var collection = (ICollection<TEntity>) navigation.GetCollectionAccessor().Create();

await elements.ForEachAsync(e => collection.Add((TEntity) e));
await elements.ForEachAsync(e => collection.Add((TEntity) e), cancellationToken);

return collection;
}
Expand Down

0 comments on commit 2accafb

Please sign in to comment.