Skip to content

Commit

Permalink
CSHARP-4468: LINQ V3 SelectMany + GroupBy results with redundant $pus…
Browse files Browse the repository at this point in the history
…h within $group.
  • Loading branch information
rstam committed Jan 12, 2023
1 parent 70ed174 commit 396830c
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ public static AstPipeline Optimize(AstPipeline pipeline)
#endregion

private readonly AccumulatorSet _accumulators = new AccumulatorSet();
private AstExpression _element; // normally either "$$ROOT" or "$_v"

private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStage groupStage)
{
try
{
if (IsOptimizableGroupStage(groupStage))
if (IsOptimizableGroupStage(groupStage, out _element))
{
var followingStages = GetFollowingStagesToOptimize(pipeline, i + 1);
if (followingStages == null)
Expand All @@ -71,22 +72,22 @@ private AstPipeline OptimizeGroupStage(AstPipeline pipeline, int i, AstGroupStag

return pipeline;

static bool IsOptimizableGroupStage(AstGroupStage groupStage)
static bool IsOptimizableGroupStage(AstGroupStage groupStage, out AstExpression element)
{
// { $group : { _id : ?, _elements : { $push : "$$ROOT" } } }
// { $group : { _id : ?, _elements : { $push : element } } }
if (groupStage.Fields.Count == 1)
{
var field = groupStage.Fields[0];
if (field.Path == "_elements" &&
field.Value is AstUnaryAccumulatorExpression unaryAccumulatorExpression &&
unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push &&
unaryAccumulatorExpression.Arg is AstVarExpression varExpression &&
varExpression.Name == "ROOT")
unaryAccumulatorExpression.Operator == AstUnaryAccumulatorOperator.Push)
{
element = unaryAccumulatorExpression.Arg;
return true;
}
}

element = null;
return false;
}

Expand Down Expand Up @@ -173,7 +174,7 @@ private AstStage OptimizeLimitStage(AstLimitStage stage)

private AstStage OptimizeMatchStage(AstMatchStage stage)
{
var optimizedFilter = AccumulatorMover.MoveAccumulators(_accumulators, stage.Filter);
var optimizedFilter = AccumulatorMover.MoveAccumulators(_accumulators, _element, stage.Filter);
return stage.Update(optimizedFilter);
}

Expand Down Expand Up @@ -201,7 +202,7 @@ private AstProjectStageSpecification OptimizeProjectStageSpecification(AstProjec

private AstProjectStageSpecification OptimizeProjectStageSetFieldSpecification(AstProjectStageSetFieldSpecification specification)
{
var optimizedValue = AccumulatorMover.MoveAccumulators(_accumulators, specification.Value);
var optimizedValue = AccumulatorMover.MoveAccumulators(_accumulators, _element, specification.Value);
return specification.Update(optimizedValue);
}

Expand Down Expand Up @@ -249,27 +250,29 @@ public string AddAccumulatorExpression(AstAccumulatorExpression value)
private class AccumulatorMover : AstNodeVisitor
{
#region static
public static TNode MoveAccumulators<TNode>(AccumulatorSet accumulators, TNode node)
public static TNode MoveAccumulators<TNode>(AccumulatorSet accumulators, AstExpression element, TNode node)
where TNode : AstNode
{
var mover = new AccumulatorMover(accumulators);
var mover = new AccumulatorMover(accumulators, element);
return mover.VisitAndConvert(node);
}
#endregion

private readonly AccumulatorSet _accumulators;
private readonly AstExpression _element;

private AccumulatorMover(AccumulatorSet accumulator)
private AccumulatorMover(AccumulatorSet accumulator, AstExpression element)
{
_accumulators = accumulator;
_element = element;
}

public override AstNode VisitFilterField(AstFilterField node)
{
// "_elements.0.X" => { __agg0 : { $first : "$$ROOT" } } + "__agg0.X"
// "_elements.0.X" => { __agg0 : { $first : element } } + "__agg0.X"
if (node.Path.StartsWith("_elements.0."))
{
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.First, AstExpression.Var("ROOT"));
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.First, _element);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
var restOfPath = node.Path.Substring("_elements.0.".Length);
var rewrittenPath = $"{accumulatorFieldName}.{restOfPath}";
Expand All @@ -288,9 +291,7 @@ public override AstNode VisitGetFieldExpression(AstGetFieldExpression node)
{
if (node.FieldName is AstConstantExpression constantFieldName &&
constantFieldName.Value.IsString &&
constantFieldName.Value.AsString == "_elements" &&
node.Input is AstVarExpression varExpression &&
varExpression.Name == "ROOT")
constantFieldName.Value.AsString == "_elements")
{
throw new UnableToRemoveReferenceToElementsException();
}
Expand All @@ -300,18 +301,18 @@ node.Input is AstVarExpression varExpression &&

public override AstNode VisitMapExpression(AstMapExpression node)
{
// { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => root) } } + "$__agg0"
// { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0"
if (node.Input is AstGetFieldExpression mapInputGetFieldExpression &&
mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFieldExpression &&
mapInputconstantFieldExpression.Value.IsString &&
mapInputconstantFieldExpression.Value.AsString == "_elements" &&
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
mapInputGetFieldVarExpression.Name == "ROOT")
{
var root = AstExpression.Var("ROOT", isCurrent: true);
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, root));
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element));
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
var root = AstExpression.Var("ROOT", isCurrent: true);
return AstExpression.GetField(root, accumulatorFieldName);
}

Expand All @@ -321,7 +322,7 @@ mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpressi
public override AstNode VisitPickExpression(AstPickExpression node)
{
// { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } }
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => root) } } } + "$__agg0"
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0"
if (node.Source is AstGetFieldExpression getFieldExpression &&
getFieldExpression.Input is AstVarExpression varExpression &&
varExpression.Name == "ROOT" &&
Expand All @@ -330,10 +331,10 @@ getFieldExpression.FieldName is AstConstantExpression constantFieldNameExpressio
constantFieldNameExpression.Value.AsString == "_elements")
{
var @operator = node.Operator.ToAccumulatorOperator();
var root = AstExpression.Var("ROOT", isCurrent: true);
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, root));
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element));
var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
var root = AstExpression.Var("ROOT", isCurrent: true);
return AstExpression.GetField(root, accumulatorFieldName);
}

Expand Down Expand Up @@ -384,7 +385,7 @@ argGetFieldExpression.FieldName is AstConstantExpression constantFieldNameExpres

bool TryOptimizeAccumulatorOfElements(out AstExpression optimizedExpression)
{
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : "$$ROOT" } } + "$__agg0"
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
node.Arg is AstGetFieldExpression getFieldExpression &&
getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameExpression &&
Expand All @@ -393,7 +394,7 @@ getFieldExpression.FieldName is AstConstantExpression getFieldConstantFieldNameE
getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&
getFieldInputVarExpression.Name == "ROOT")
{
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, root);
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
return true;
Expand All @@ -406,7 +407,7 @@ getFieldExpression.Input is AstVarExpression getFieldInputVarExpression &&

bool TryOptimizeAccumulatorOfMappedElements(out AstExpression optimizedExpression)
{
// { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => root) } } + "$__agg0"
// { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0"
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
node.Arg is AstMapExpression mapExpression &&
mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression &&
Expand All @@ -416,7 +417,7 @@ mapInputGetFieldExpression.FieldName is AstConstantExpression mapInputconstantFi
mapInputGetFieldExpression.Input is AstVarExpression mapInputGetFieldVarExpression &&
mapInputGetFieldVarExpression.Name == "ROOT")
{
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, root));
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element));
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg);
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
optimizedExpression = AstExpression.GetField(root, accumulatorFieldName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1537,8 +1537,8 @@ group f by f.D into g
4,
"{ $project : { _v : '$G', _id : 0 } }",
"{ $unwind : '$_v' }",
"{ $group : { _id : '$_v.D', _elements : { $push : '$_v' } } }",
"{ $project : { Key : '$_id', SumF : { $sum : '$_elements.E.F' }, _id : 0 } }");
"{ $group : { _id : '$_v.D', __agg0 : { $sum : '$_v.E.F' } } }",
"{ $project : { Key : '$_id', SumF : '$__agg0', _id : 0 } }");
}

[Fact]
Expand Down Expand Up @@ -1567,8 +1567,8 @@ group s by s.D into g
"{ $unwind : '$_v' }",
"{ $project : { '_v' : '$_v.S', '_id' : 0 } }",
"{ $unwind : '$_v' }",
"{ $group : { _id : '$_v.D', _elements : { $push : '$_v' } } }",
"{ $project : { Key : '$_id', SumF : { $sum : '$_elements.E.F' }, _id : 0 } }");
"{ $group : { _id : '$_v.D', __agg0 : { $sum : '$_v.E.F' } } }",
"{ $project : { Key : '$_id', SumF : '$__agg0', _id : 0 } }");
}

[Fact]
Expand Down
Loading

0 comments on commit 396830c

Please sign in to comment.