Skip to content

Commit

Permalink
Progress
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Sep 10, 2023
1 parent 95226dd commit 2fe11e9
Show file tree
Hide file tree
Showing 4 changed files with 251 additions and 18 deletions.
59 changes: 53 additions & 6 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op
}

var anchorQuery = new AliasNode(plan, cte.ExpressionName, _nodeContext);
_cteSubplans.Add(cte.ExpressionName.Value, anchorQuery);

if (cteValidator.RecursiveQueries.Count > 0)
{
Expand Down Expand Up @@ -281,22 +282,56 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op
var recurseLoop = new NestedLoopNode
{
LeftSource = incrementRecursionDepthComputeScalar,
// TODO: Capture all CTE fields in the outer references
JoinType = QualifiedJoinType.Inner,
OuterReferences = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase)
};

// Capture all CTE fields in the outer references
var anchorSchema = anchorQuery.GetSchema(_nodeContext);

foreach (var col in anchorSchema.Schema)
recurseLoop.OuterReferences[col.Key] = "@" + _nodeContext.GetExpressionName();

if (cteValidator.RecursiveQueries.Count > 1)
{
// Combine the results of each recursive query with a concat node
var concat = new ConcatenateNode();
recurseLoop.RightSource = concat;

foreach (var qry in cteValidator.RecursiveQueries)
concat.Sources.Add(ConvertRecursiveCTEQuery(qry));
{
var rightSource = ConvertRecursiveCTEQuery(qry, anchorSchema, cteValidator, recurseLoop.OuterReferences);
concat.Sources.Add(rightSource.Source);

if (concat.Sources.Count == 1)
{
for (var i = 0; i < rightSource.ColumnSet.Count; i++)
{
var col = rightSource.ColumnSet[i];
var expr = _nodeContext.GetExpressionName();
concat.ColumnSet.Add(new ConcatenateColumn { OutputColumn = expr });
recurseLoop.DefinedValues.Add(expr, expr);
recurseConcat.ColumnSet[i].SourceColumns.Add(expr);
}
}

for (var i = 0; i < rightSource.ColumnSet.Count; i++)
concat.ColumnSet[i].SourceColumns.Add(rightSource.ColumnSet[i].SourceColumn);
}
}
else
{
recurseLoop.RightSource = ConvertRecursiveCTEQuery(cteValidator.RecursiveQueries[0]);
var rightSource = ConvertRecursiveCTEQuery(cteValidator.RecursiveQueries[0], anchorSchema, cteValidator, recurseLoop.OuterReferences);
recurseLoop.RightSource = rightSource.Source;

for (var i = 0; i < rightSource.ColumnSet.Count; i++)
{
var col = rightSource.ColumnSet[i];
var expr = _nodeContext.GetExpressionName();

recurseLoop.DefinedValues.Add(expr, col.SourceColumn);
recurseConcat.ColumnSet[i].SourceColumns.Add(expr);
}
}

// Ensure we don't get stuck in an infinite loop
Expand Down Expand Up @@ -335,13 +370,10 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op
recurseConcat.Sources.Add(recurseLoop);
}

// TODO: Update the sources for each field in the concat node
recurseConcat.ColumnSet.Last().SourceColumns.Add(incrementedDepthField);

anchorQuery.Source = incrementRecursionDepthComputeScalar;
}

_cteSubplans.Add(cte.ExpressionName.Value, new AliasNode(plan, cte.ExpressionName, _nodeContext));
}
}
}
Expand Down Expand Up @@ -402,6 +434,21 @@ private void ConvertStatement(TSqlStatement statement, ExecutionPlanOptimizer op
}
}

/// <summary>
/// Converts one recursive member of a CTE definition into a <see cref="SelectNode"/>.
/// </summary>
/// <param name="queryExpression">The recursive query to convert. It is rewritten in place by the visitor.</param>
/// <param name="anchorSchema">Schema of the anchor query; its column names are passed to the visitor and the schema itself to the final conversion</param>
/// <param name="cteValidator">Validator whose <c>Name</c> identifies the CTE being converted</param>
/// <param name="outerReferences">Outer reference mapping shared with the visitor and the final conversion</param>
/// <returns>The result of converting the rewritten query</returns>
private SelectNode ConvertRecursiveCTEQuery(QueryExpression queryExpression, INodeSchema anchorSchema, CteValidatorVisitor cteValidator, Dictionary<string, string> outerReferences)
{
    // Validation pass: convert the untouched query with the anchor query available as a subquery
    // so that ambiguous column names are detected. The result is deliberately discarded.
    ConvertSelectStatement(queryExpression, null, null, null, _nodeContext);

    // Rewrite pass: strip recursive references out of the FROM clause, shifting join predicates into the WHERE clause.
    // A recursive reference inside an unqualified join becomes (SELECT @Expr1, @Expr2) AS cte (field1, field2);
    // any other recursive reference is dropped and its columns are expected to be supplied as variables.
    var cteName = cteValidator.Name;
    var anchorColumnNames = anchorSchema.Schema.Keys.ToArray();
    var recursiveReferenceRemover = new RemoveRecursiveCTETableReferencesVisitor(cteName, anchorColumnNames, outerReferences);
    queryExpression.Accept(recursiveReferenceRemover);

    // Final pass: convert the rewritten query.
    var convertedQuery = ConvertSelectStatement(queryExpression, null, anchorSchema, outerReferences, _nodeContext);
    return convertedQuery;
}

private IRootExecutionPlanNodeInternal[] ConvertExecuteStatement(ExecuteStatement execute)
{
var nodes = new List<IRootExecutionPlanNodeInternal>();
Expand Down
1 change: 1 addition & 0 deletions MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Visitors\InSubqueryVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\OptimizerHintValidatingVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\ParameterlessCollectingVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\RemoveRecursiveCTETableReferencesVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\ReplaceCtesWithSubqueriesVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\ReplacePrimaryFunctionsVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\RewriteVisitor.cs" />
Expand Down
45 changes: 33 additions & 12 deletions MarkMpn.Sql4Cds.Engine/Visitors/CteValidatorVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,16 +81,6 @@ public override void ExplicitVisit(BinaryQueryExpression node)
throw new NotSupportedQueryFragmentException($"Recursive common table expression '{Name}' does not contain a top-level UNION ALL operator", node);
}

public override void ExplicitVisit(QuerySpecification node)
{
base.ExplicitVisit(node);

if (!IsRecursive)
AnchorQuery = node;
else
RecursiveQueries.Add(node);
}

public override void ExplicitVisit(QueryParenthesisExpression node)
{
base.ExplicitVisit(node);
Expand All @@ -101,13 +91,13 @@ public override void ExplicitVisit(QueryParenthesisExpression node)
RecursiveQueries.Add(node);
}

public override void Visit(QuerySpecification node)
public override void ExplicitVisit(QuerySpecification node)
{
_scalarAggregate = null;
_subquery = null;
_outerJoin = null;

base.Visit(node);
base.ExplicitVisit(node);

// The following clauses can't be used in the CTE_query_definition:
// ORDER BY (except when a TOP clause is specified)
Expand Down Expand Up @@ -151,6 +141,11 @@ public override void Visit(QuerySpecification node)
if (_subquery != null)
throw new NotSupportedQueryFragmentException("Recursive references are not allowed in subqueries", _subquery);
}

if (!IsRecursive)
AnchorQuery = node;
else
RecursiveQueries.Add(node);
}

public override void Visit(SelectStatement node)
Expand All @@ -176,5 +171,31 @@ public override void ExplicitVisit(ScalarSubquery node)
if (_cteReferenceCount > count)
_subquery = node;
}

/// <summary>
/// Remembers the most recently visited call to an aggregate function so the enclosing
/// query can be checked for aggregates afterwards.
/// </summary>
/// <param name="node">The function call being visited</param>
public override void Visit(FunctionCall node)
{
    base.Visit(node);

    // Function names are matched case-insensitively by lower-casing them first
    var functionName = node.FunctionName.Value.ToLowerInvariant();

    var isAggregate =
        functionName == "approx_count_distinct" ||
        functionName == "avg" ||
        functionName == "checksum_agg" ||
        functionName == "count" ||
        functionName == "count_big" ||
        functionName == "grouping" ||
        functionName == "grouping_id" ||
        functionName == "max" ||
        functionName == "min" ||
        functionName == "stdev" ||
        functionName == "stdevp" ||
        functionName == "string_agg" ||
        functionName == "sum" ||
        functionName == "var" ||
        functionName == "varp";

    if (!isAggregate)
        return;

    _scalarAggregate = node;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.SqlServer.TransactSql.ScriptDom;

namespace MarkMpn.Sql4Cds.Engine.Visitors
{
/// <summary>
/// Finds recursive references in a CTE definition and removes them, moving join predicates to the WHERE clause
/// </summary>
/// <remarks>
/// A recursive reference is a table reference with a single-part name equal to the CTE name.
/// Outside of the right-hand side of an unqualified join the reference is removed from the FROM clause and the
/// predicate of the join it took part in (if any) is AND-ed onto the WHERE clause of the enclosing query.
/// Inside the right-hand side of an unqualified join the reference is replaced by an inline derived table
/// with a single row of variable references, one per CTE column.
/// </remarks>
class RemoveRecursiveCTETableReferencesVisitor : TSqlConcreteFragmentVisitor
{
    private readonly string _name;
    private readonly string[] _columnNames;
    private readonly Dictionary<string, string> _outerReferences;

    // Join predicate captured for the query specification currently being visited, or null if none
    private BooleanExpression _joinPredicate;

    // Number of unqualified joins whose right-hand side is currently being visited
    private int _inUnqualifiedJoin;

    /// <summary>
    /// Creates a new <see cref="RemoveRecursiveCTETableReferencesVisitor"/>
    /// </summary>
    /// <param name="name">The name of the CTE whose recursive references should be removed</param>
    /// <param name="columnNames">The names of the columns exposed by the replacement inline derived table</param>
    /// <param name="outerReferences">Lookup from each entry in <paramref name="columnNames"/> to the name of the variable that supplies its value</param>
    public RemoveRecursiveCTETableReferencesVisitor(string name, string[] columnNames, Dictionary<string, string> outerReferences)
    {
        _name = name;
        _columnNames = columnNames;
        _outerReferences = outerReferences;
    }

    /// <summary>
    /// Checks if a table reference is a single-part named table matching the CTE name (case-insensitive)
    /// </summary>
    private bool IsRecursiveReference(TableReference tableReference)
    {
        if (!(tableReference is NamedTableReference namedTable))
            return false;

        if (namedTable.SchemaObject.Identifiers.Count != 1)
            return false;

        return namedTable.SchemaObject.BaseIdentifier.Value.Equals(_name, StringComparison.OrdinalIgnoreCase);
    }

    /// <summary>
    /// Builds the single-row inline derived table, aliased with the CTE name, that stands in for a
    /// recursive reference. Each column value is a reference to the variable registered for that column.
    /// </summary>
    private InlineDerivedTable CreateInlineDerivedTable()
    {
        var table = new InlineDerivedTable
        {
            Alias = new Identifier { Value = _name },
            RowValues = { new RowValue() }
        };

        foreach (var col in _columnNames)
        {
            table.Columns.Add(new Identifier { Value = col });
            table.RowValues[0].ColumnValues.Add(new VariableReference { Name = _outerReferences[col] });
        }

        return table;
    }

    /// <summary>
    /// Handles a join that has the recursive reference as one of its two direct inputs
    /// </summary>
    /// <param name="tableReference">The table reference to inspect</param>
    /// <param name="removed">When the method returns <c>true</c>, the other side of the join, which should replace the join in its parent</param>
    /// <returns>
    /// <c>true</c> if the caller should replace <paramref name="tableReference"/> with <paramref name="removed"/>.
    /// <c>false</c> if nothing needs replacing, including the case where the recursive reference was swapped
    /// for an inline derived table in place.
    /// </returns>
    private bool RemoveRecursiveJoin(TableReference tableReference, out TableReference removed)
    {
        removed = null;

        if (!(tableReference is JoinTableReference join))
            return false;

        if (IsRecursiveReference(join.FirstTableReference))
        {
            if (_inUnqualifiedJoin > 0)
            {
                join.FirstTableReference = CreateInlineDerivedTable();
                return false;
            }

            // Only qualified joins carry a predicate; for any other join type this records null
            _joinPredicate = (join as QualifiedJoin)?.SearchCondition;
            removed = join.SecondTableReference;
            return true;
        }

        if (IsRecursiveReference(join.SecondTableReference))
        {
            if (_inUnqualifiedJoin > 0)
            {
                join.SecondTableReference = CreateInlineDerivedTable();
                return false;
            }

            _joinPredicate = (join as QualifiedJoin)?.SearchCondition;
            removed = join.FirstTableReference;
            return true;
        }

        return false;
    }

    /// <summary>
    /// Removes or replaces recursive references that appear directly in the FROM clause
    /// </summary>
    public override void Visit(FromClause node)
    {
        base.Visit(node);

        for (var i = 0; i < node.TableReferences.Count; i++)
        {
            if (IsRecursiveReference(node.TableReferences[i]))
            {
                if (_inUnqualifiedJoin > 0)
                {
                    node.TableReferences[i] = CreateInlineDerivedTable();
                }
                else
                {
                    node.TableReferences.RemoveAt(i);

                    // The following entries have shifted down by one - step back so the entry
                    // that now occupies this index is not skipped
                    i--;
                }
            }
            else if (RemoveRecursiveJoin(node.TableReferences[i], out var removed))
            {
                node.TableReferences[i] = removed;
            }
        }
    }

    /// <summary>
    /// Collapses child joins of a qualified join that have the recursive reference as a direct input
    /// </summary>
    public override void Visit(QualifiedJoin node)
    {
        base.Visit(node);

        if (RemoveRecursiveJoin(node.FirstTableReference, out var removed))
            node.FirstTableReference = removed;

        if (RemoveRecursiveJoin(node.SecondTableReference, out removed))
            node.SecondTableReference = removed;
    }

    /// <summary>
    /// Collapses the left-hand child join of an unqualified join if it has the recursive reference as a direct input
    /// </summary>
    public override void Visit(UnqualifiedJoin node)
    {
        base.Visit(node);

        if (RemoveRecursiveJoin(node.FirstTableReference, out var removed))
            node.FirstTableReference = removed;
    }

    /// <summary>
    /// Visits both sides of an unqualified join, tracking when the right-hand side is being processed so
    /// recursive references found there are replaced with an inline derived table rather than removed
    /// </summary>
    public override void ExplicitVisit(UnqualifiedJoin node)
    {
        // This override replaces the default traversal, so Visit(node) has to be invoked explicitly -
        // otherwise the left-hand side handling in Visit(UnqualifiedJoin) would never run
        Visit(node);

        node.FirstTableReference.Accept(this);

        _inUnqualifiedJoin++;

        if (RemoveRecursiveJoin(node.SecondTableReference, out var removed))
            node.SecondTableReference = removed;

        node.SecondTableReference.Accept(this);
        _inUnqualifiedJoin--;
    }

    /// <summary>
    /// Visits a query specification and appends any join predicate captured from its own FROM clause to its WHERE clause
    /// </summary>
    public override void ExplicitVisit(QuerySpecification node)
    {
        // Scope the captured predicate to this query specification. Without this, a nested query
        // (e.g. a subquery in the WHERE clause) would consume a predicate captured by the enclosing
        // query, and the enclosing query would then apply the same predicate a second time.
        var outerJoinPredicate = _joinPredicate;
        _joinPredicate = null;

        base.ExplicitVisit(node);

        if (_joinPredicate != null)
        {
            if (node.WhereClause == null)
            {
                node.WhereClause = new WhereClause { SearchCondition = _joinPredicate };
            }
            else
            {
                node.WhereClause.SearchCondition = new BooleanBinaryExpression
                {
                    FirstExpression = node.WhereClause.SearchCondition,
                    BinaryExpressionType = BooleanBinaryExpressionType.And,
                    SecondExpression = _joinPredicate
                };
            }
        }

        // Hand back whatever the enclosing query had captured before this one was visited
        _joinPredicate = outerJoinPredicate;
    }
}
}

0 comments on commit 2fe11e9

Please sign in to comment.