From e825c742f19baaace10c1ecde4b89d17255c20a0 Mon Sep 17 00:00:00 2001 From: Andrew Peters Date: Sat, 24 Jun 2017 13:58:01 -0700 Subject: [PATCH] Query: Fix bug in IncludeCompiler where we would resolve an incorrect QSRE for grouped results. --- .../Query/OwnedQueryTestBase.cs | 35 ++++++++++++++++ .../QuerySourceTracingExpressionVisitor.cs | 14 ++----- src/EFCore/Query/Internal/IncludeCompiler.cs | 7 ++-- .../Query/Internal/QueryModelExtensions.cs | 41 ++++++++++++++++++- .../Query/OwnedQuerySqlServerTest.cs | 12 ++++++ .../Query/OwnedQuerySqliteTest.cs | 12 ++++++ 6 files changed, 106 insertions(+), 15 deletions(-) diff --git a/src/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs b/src/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs index 196e4250d5f..5a33102bd24 100644 --- a/src/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs +++ b/src/EFCore.Specification.Tests/Query/OwnedQueryTestBase.cs @@ -52,6 +52,41 @@ public virtual void Query_for_leaf_type_loads_all_owned_navs() } } + [Fact] + public virtual void Query_when_group_by() + { + using (var context = CreateContext()) + { + var people = context.Set().GroupBy(op => op.Id).ToList(); + + Assert.Equal(4, people.Count); + Assert.True(people.SelectMany(p => p).All(p => p.PersonAddress != null)); + Assert.True(people.SelectMany(p => p).OfType().All(b => b.BranchAddress != null)); + Assert.True(people.SelectMany(p => p).OfType().All(a => a.LeafAAddress != null)); + Assert.True(people.SelectMany(p => p).OfType().All(b => b.LeafBAddress != null)); + } + } + + [Fact] + public virtual void Query_when_subquery() + { + using (var context = CreateContext()) + { + var people + = context.Set() + .Distinct() + .Take(5) + .Select(op => new { op }) + .ToList(); + + Assert.Equal(4, people.Count); + Assert.True(people.All(p => p.op.PersonAddress != null)); + Assert.True(people.Select(p => p.op).OfType().All(b => b.BranchAddress != null)); + Assert.True(people.Select(p => p.op).OfType().All(a => a.LeafAAddress != null)); + Assert.True(people.Select(p => p.op).OfType().All(b => b.LeafBAddress != null)); + } + } + protected abstract DbContext CreateContext(); } } diff --git a/src/EFCore/Query/ExpressionVisitors/Internal/QuerySourceTracingExpressionVisitor.cs b/src/EFCore/Query/ExpressionVisitors/Internal/QuerySourceTracingExpressionVisitor.cs index 2cc1273f062..f3064f439c5 100644 --- a/src/EFCore/Query/ExpressionVisitors/Internal/QuerySourceTracingExpressionVisitor.cs +++ b/src/EFCore/Query/ExpressionVisitors/Internal/QuerySourceTracingExpressionVisitor.cs @@ -58,23 +58,17 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr } else { - var fromClauseBase = expression.ReferencedQuerySource as FromClauseBase; - - if (fromClauseBase != null) + if (expression.ReferencedQuerySource is FromClauseBase fromClauseBase) { Visit(fromClauseBase.FromExpression); } - - var joinClause = expression.ReferencedQuerySource as JoinClause; - - if (joinClause != null) + + if (expression.ReferencedQuerySource is JoinClause joinClause) { Visit(joinClause.InnerSequence); } - var groupJoinClause = expression.ReferencedQuerySource as GroupJoinClause; - - if (groupJoinClause != null) + if (expression.ReferencedQuerySource is GroupJoinClause groupJoinClause) { if (groupJoinClause.JoinClause.Equals(_targetQuerySource)) { diff --git a/src/EFCore/Query/Internal/IncludeCompiler.cs b/src/EFCore/Query/Internal/IncludeCompiler.cs index 9e6b4c34c57..982ed23152b 100644 --- a/src/EFCore/Query/Internal/IncludeCompiler.cs +++ b/src/EFCore/Query/Internal/IncludeCompiler.cs @@ -74,7 +74,7 @@ public virtual void CompileIncludes( return; } - foreach (var includeLoadTree in CreateIncludeLoadTrees(queryModel.SelectClause.Selector)) + foreach (var includeLoadTree in CreateIncludeLoadTrees(queryModel)) { includeLoadTree.Compile( _queryCompilationContext, @@ -111,8 +111,7 @@ public virtual void LogIgnoredIncludes() } } - private IEnumerable CreateIncludeLoadTrees( - Expression targetExpression) + private IEnumerable CreateIncludeLoadTrees(QueryModel queryModel) { var querySourceTracingExpressionVisitor = _querySourceTracingExpressionVisitorFactory.Create(); @@ -126,7 +125,7 @@ var querySourceTracingExpressionVisitor var querySourceReferenceExpression = querySourceTracingExpressionVisitor .FindResultQuerySourceReferenceExpression( - targetExpression, + queryModel.GetOutputExpression(), includeResultOperator.QuerySource); if (querySourceReferenceExpression == null diff --git a/src/EFCore/Query/Internal/QueryModelExtensions.cs b/src/EFCore/Query/Internal/QueryModelExtensions.cs index 1728e1729ee..24fdbb110e7 100644 --- a/src/EFCore/Query/Internal/QueryModelExtensions.cs +++ b/src/EFCore/Query/Internal/QueryModelExtensions.cs @@ -1,13 +1,17 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; +using System.Linq; using System.Linq.Expressions; using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Query.ExpressionVisitors; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; namespace Microsoft.EntityFrameworkCore.Query.Internal { @@ -16,7 +20,42 @@ namespace Microsoft.EntityFrameworkCore.Query.Internal /// directly from your code. This API may change or be removed in future releases. /// public static class QueryModelExtensions - { + { + /// + /// This API supports the Entity Framework Core infrastructure and is not intended to be used + /// directly from your code. This API may change or be removed in future releases. + /// + public static Expression GetOutputExpression([NotNull] this QueryModel queryModel) + { + var outputExpression = queryModel.SelectClause.Selector; + + var groupResultOperator + = queryModel.ResultOperators.OfType().LastOrDefault(); + + if (groupResultOperator != null) + { + outputExpression = groupResultOperator.ElementSelector; + } + else if (queryModel.SelectClause.Selector.Type.IsGrouping()) + { + var subqueryExpression + = (queryModel.SelectClause.Selector + .TryGetReferencedQuerySource() as MainFromClause)?.FromExpression as SubQueryExpression; + + var nestedGroupResultOperator + = subqueryExpression?.QueryModel?.ResultOperators + ?.OfType() + .LastOrDefault(); + + if (nestedGroupResultOperator != null) + { + outputExpression = nestedGroupResultOperator.ElementSelector; + } + } + + return outputExpression; + } + /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used /// directly from your code. This API may change or be removed in future releases. diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs index 019279d5bc9..b3f58dba60f 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/OwnedQuerySqlServerTest.cs @@ -38,6 +38,18 @@ public override void Query_for_leaf_type_loads_all_owned_navs() AssertSql(""); } + [Fact(Skip = "#8907")] + public override void Query_when_group_by() + { + base.Query_when_group_by(); + } + + [Fact(Skip = "#8907")] + public override void Query_when_subquery() + { + base.Query_when_subquery(); + } + protected override DbContext CreateContext() => _fixture.CreateContext(); private void AssertSql(params string[] expected) diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/OwnedQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/OwnedQuerySqliteTest.cs index a8398d1790d..88a7b556546 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/OwnedQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/OwnedQuerySqliteTest.cs @@ -32,6 +32,18 @@ public override void Query_for_leaf_type_loads_all_owned_navs() base.Query_for_leaf_type_loads_all_owned_navs(); } + [Fact(Skip = "#8907")] + public override void Query_when_group_by() + { + base.Query_when_group_by(); + } + + [Fact(Skip = "#8907")] + public override void Query_when_subquery() + { + base.Query_when_subquery(); + } + protected override DbContext CreateContext() => _fixture.CreateContext(); } }