Skip to content

Commit

Permalink
Query: Translate to SQL GROUP BY when aggregate operator is applied a…
Browse files Browse the repository at this point in the history
…fter GroupBy

Resolves #12826
Resolves #6658
Part of #15711
Resolves #15853
Resolves #12799
Resolves #12476
Resolves #11976

There are way too many existing issues are resolved by this PR. I haven't added regression test or verified each of them so I have put Verify-Fixed label on them for now.
  • Loading branch information
smitpatel committed Jul 2, 2019
1 parent fd9ac43 commit 8497da5
Show file tree
Hide file tree
Showing 17 changed files with 741 additions and 301 deletions.
14 changes: 14 additions & 0 deletions src/EFCore.Relational/Query/Pipeline/QuerySqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ protected virtual void GenerateSelect(SelectExpression selectExpression)
Visit(selectExpression.Predicate);
}

if (selectExpression.GroupBy.Count > 0)
{
_relationalCommandBuilder.AppendLine().Append("GROUP BY ");

GenerateList(selectExpression.GroupBy, e => Visit(e));
}

if (selectExpression.HavingExpression != null)
{
_relationalCommandBuilder.AppendLine().Append("HAVING ");

Visit(selectExpression.HavingExpression);
}

GenerateOrderings(selectExpression);
GenerateLimitOffset(selectExpression);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.NavigationExpansion;
using Microsoft.EntityFrameworkCore.Query.Pipeline;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,105 @@ protected override ShapedQueryExpression TranslateFirstOrDefault(ShapedQueryExpr
return source;
}

protected override ShapedQueryExpression TranslateGroupBy(ShapedQueryExpression source, LambdaExpression keySelector, LambdaExpression elementSelector, LambdaExpression resultSelector) => throw new NotImplementedException();
protected override ShapedQueryExpression TranslateGroupBy(
ShapedQueryExpression source,
LambdaExpression keySelector,
LambdaExpression elementSelector,
LambdaExpression resultSelector)
{
var selectExpression = (SelectExpression)source.QueryExpression;
selectExpression.PrepareForAggregate();

var remappedKeySelector = RemapLambdaBody(source.ShaperExpression, keySelector);

var translatedKey = TranslateGroupingKey(remappedKeySelector)
?? (remappedKeySelector is ConstantExpression ? remappedKeySelector : null);
if (translatedKey != null)
{
if (elementSelector != null)
{
source = TranslateSelect(source, elementSelector);
}

var sqlKeySelector = translatedKey is ConstantExpression
? _sqlExpressionFactory.ApplyDefaultTypeMapping(_sqlExpressionFactory.Constant(1))
: translatedKey;

var appliedKeySelector = selectExpression.ApplyGrouping(sqlKeySelector);
translatedKey = translatedKey is ConstantExpression ? translatedKey : appliedKeySelector;

source.ShaperExpression = new GroupByShaperExpression(translatedKey, source.ShaperExpression);

if (resultSelector == null)
{
return source;
}

var keyAccessExpression = Expression.MakeMemberAccess(
source.ShaperExpression,
source.ShaperExpression.Type.GetTypeInfo().GetMember(nameof(IGrouping<int, int>.Key))[0]);

var newResultSelectorBody = ReplacingExpressionVisitor.Replace(
resultSelector.Parameters[0], keyAccessExpression,
resultSelector.Parameters[1], source.ShaperExpression,
resultSelector.Body);

source.ShaperExpression = _projectionBindingExpressionVisitor.Translate(selectExpression, newResultSelectorBody);

return source;
}

throw new InvalidOperationException();
}

private Expression TranslateGroupingKey(Expression expression)
{
if (expression is NewExpression newExpression)
{
if (newExpression.Arguments.Count == 0)
{
return newExpression;
}

var newArguments = new Expression[newExpression.Arguments.Count];
for (var i = 0; i < newArguments.Length; i++)
{
newArguments[i] = TranslateGroupingKey(newExpression.Arguments[i]);
if (newArguments[i] == null)
{
return null;
}
}

return newExpression.Update(newArguments);
}

if (expression is MemberInitExpression memberInitExpression)
{
var updatedNewExpression = (NewExpression)TranslateGroupingKey(memberInitExpression.NewExpression);
if (updatedNewExpression == null)
{
return null;
}

var newBindings = new MemberAssignment[memberInitExpression.Bindings.Count];
for (var i = 0; i < newBindings.Length; i++)
{
var memberAssignment = (MemberAssignment)memberInitExpression.Bindings[i];
var visitedExpression = TranslateGroupingKey(memberAssignment.Expression);
if (visitedExpression == null)
{
return null;
}

newBindings[i] = memberAssignment.Update(visitedExpression);
}

return memberInitExpression.Update(updatedNewExpression, newBindings);
}

return _sqlTranslator.Translate(expression);
}

protected override ShapedQueryExpression TranslateGroupJoin(ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector)
{
Expand Down Expand Up @@ -589,17 +687,16 @@ protected override ShapedQueryExpression TranslateSelect(ShapedQueryExpression s
}

var newSelectorBody = ReplacingExpressionVisitor.Replace(selector.Parameters.Single(), source.ShaperExpression, selector.Body);

source.ShaperExpression = _projectionBindingExpressionVisitor
.Translate(selectExpression, newSelectorBody);
source.ShaperExpression = _projectionBindingExpressionVisitor.Translate(selectExpression, newSelectorBody);

return source;
}

private static readonly MethodInfo _defaultIfEmptyWithoutArgMethodInfo = typeof(Enumerable).GetTypeInfo()
.GetDeclaredMethods(nameof(Enumerable.DefaultIfEmpty)).Single(mi => mi.GetParameters().Length == 1);

protected override ShapedQueryExpression TranslateSelectMany(ShapedQueryExpression source, LambdaExpression collectionSelector, LambdaExpression resultSelector)
protected override ShapedQueryExpression TranslateSelectMany(
ShapedQueryExpression source, LambdaExpression collectionSelector, LambdaExpression resultSelector)
{
var collectionSelectorBody = collectionSelector.Body;
//var defaultIfEmpty = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ protected override Expression VisitExtension(Expression extensionExpression)
var collectionId = _collectionId++;
var selectExpression = (SelectExpression)collectionShaperExpression.Projection.QueryExpression;
// Do pushdown beforehand so it updates all pending collections first
if (selectExpression.IsDistinct || selectExpression.Limit != null || selectExpression.Offset != null)
if (selectExpression.IsDistinct
|| selectExpression.Limit != null
|| selectExpression.Offset != null
|| selectExpression.GroupBy.Count > 1)
{
selectExpression.PushdownIntoSubquery();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ public virtual SqlExpression Translate(Expression expression)

translation = _sqlExpressionFactory.ApplyDefaultTypeMapping(translation);

if (translation is SqlConstantExpression
&& translation.TypeMapping == null)
{
// Non-mappable constant
return null;
}

_sqlVerifyingExpressionVisitor.Visit(translation);

return translation;
Expand Down Expand Up @@ -246,13 +253,66 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
return null;
}

private Expression GetSelector(MethodCallExpression methodCallExpression, GroupByShaperExpression groupByShaperExpression)
{
if (methodCallExpression.Arguments.Count == 1)
{
return groupByShaperExpression.ElementSelector;
}

if (methodCallExpression.Arguments.Count == 2)
{
var selectorLambda = methodCallExpression.Arguments[1].UnwrapLambdaFromQuote();
return ReplacingExpressionVisitor.Replace(
selectorLambda.Parameters[0],
groupByShaperExpression.ElementSelector,
selectorLambda.Body);
}

throw new InvalidOperationException();
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
// EF.Property case
if (methodCallExpression.TryGetEFPropertyArguments(out var source, out var propertyName))
{
return BindProperty(source, propertyName);
}

// GroupBy Aggregate case
if (methodCallExpression.Object == null
&& methodCallExpression.Method.DeclaringType == typeof(Enumerable)
&& methodCallExpression.Arguments.Count > 0
&& methodCallExpression.Arguments[0] is GroupByShaperExpression groupByShaperExpression)
{
switch (methodCallExpression.Method.Name)
{
case nameof(Enumerable.Average):
return TranslateAverage(GetSelector(methodCallExpression, groupByShaperExpression));

case nameof(Enumerable.Count):
return TranslateCount();

case nameof(Enumerable.LongCount):
return TranslateLongCount();

case nameof(Enumerable.Max):
return TranslateMax(GetSelector(methodCallExpression, groupByShaperExpression));

case nameof(Enumerable.Min):
return TranslateMin(GetSelector(methodCallExpression, groupByShaperExpression));

case nameof(Enumerable.Sum):
return TranslateSum(GetSelector(methodCallExpression, groupByShaperExpression));

default:
throw new InvalidOperationException("Unknown aggregate operator encountered.");
}

}

// Subquery case
var subqueryTranslation = _queryableMethodTranslatingExpressionVisitor.TranslateSubquery(methodCallExpression);
if (subqueryTranslation != null)
{
Expand Down Expand Up @@ -280,6 +340,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
return new SubSelectExpression(subquery);
}

// MethodCall translators
var @object = Visit(methodCallExpression.Object);
if (TranslationFailed(methodCallExpression.Object, @object))
{
Expand Down Expand Up @@ -413,29 +474,25 @@ protected override Expression VisitParameter(ParameterExpression parameterExpres

protected override Expression VisitExtension(Expression extensionExpression)
{
if (extensionExpression is EntityShaperExpression)
switch (extensionExpression)
{
return extensionExpression;
}
case EntityShaperExpression _:
case SqlExpression _:
return extensionExpression;

if (extensionExpression is ProjectionBindingExpression projectionBindingExpression)
{
var selectExpression = (SelectExpression)projectionBindingExpression.QueryExpression;
case NullConditionalExpression nullConditionalExpression:
return Visit(nullConditionalExpression.AccessOperation);

return selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember);
}
case CorrelationPredicateExpression correlationPredicateExpression:
return Visit(correlationPredicateExpression.EqualExpression);

if (extensionExpression is NullConditionalExpression nullConditionalExpression)
{
return Visit(nullConditionalExpression.AccessOperation);
}
case ProjectionBindingExpression projectionBindingExpression:
var selectExpression = (SelectExpression)projectionBindingExpression.QueryExpression;
return selectExpression.GetMappedProjection(projectionBindingExpression.ProjectionMember);

if (extensionExpression is CorrelationPredicateExpression correlationPredicateExpression)
{
return Visit(correlationPredicateExpression.EqualExpression);
default:
return null;
}

return base.VisitExtension(extensionExpression);
}

protected override Expression VisitConditional(ConditionalExpression conditionalExpression)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ public SelectExpression Select(SqlExpression projection)
alias: null,
new List<ProjectionExpression>(),
new List<TableExpressionBase>(),
new List<SqlExpression>(),
new List<OrderingExpression>());

if (projection != null)
Expand Down
Loading

0 comments on commit 8497da5

Please sign in to comment.