Skip to content

Commit

Permalink
Allow using TDS Endpoint for non-recursive CTEs
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed Sep 5, 2023
1 parent f450304 commit c07a8cc
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 10 deletions.
5 changes: 4 additions & 1 deletion MarkMpn.Sql4Cds.Engine/ExecutionPlanBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public IRootExecutionPlanNode[] Build(string sql, IDictionary<string, DataTypeRe
var tdsEndpointCompatibilityVisitor = new TDSEndpointCompatibilityVisitor(con, DataSources[Options.PrimaryDataSource].Metadata);
fragment.Accept(tdsEndpointCompatibilityVisitor);

if (tdsEndpointCompatibilityVisitor.IsCompatible)
if (tdsEndpointCompatibilityVisitor.IsCompatible && !tdsEndpointCompatibilityVisitor.RequiresCteRewrite)
{
useTDSEndpointDirectly = true;
var sqlNode = new SqlNode
Expand Down Expand Up @@ -1663,6 +1663,9 @@ private IRootExecutionPlanNodeInternal ConvertSelectStatement(SelectStatement se

if (tdsEndpointCompatibilityVisitor.IsCompatible && hintCompatibilityVisitor.TdsCompatible)
{
if (tdsEndpointCompatibilityVisitor.RequiresCteRewrite)
select.Accept(new ReplaceCtesWithSubqueriesVisitor());

select.ScriptTokenStream = null;
var sql = new SqlNode
{
Expand Down
2 changes: 2 additions & 0 deletions MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.projitems
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
<Compile Include="$(MSBuildThisFileDirectory)SqlDateTypes.cs" />
<Compile Include="$(MSBuildThisFileDirectory)SqlVariant.cs" />
<Compile Include="$(MSBuildThisFileDirectory)StateTransitionLoader.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\CteValidatorVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\ExplicitCollationVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\JoinConditionVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)MetadataExtensions.cs" />
Expand All @@ -139,6 +140,7 @@
<Compile Include="$(MSBuildThisFileDirectory)Visitors\InSubqueryVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\OptimizerHintValidatingVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\ParameterlessCollectingVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\ReplaceCtesWithSubqueriesVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\ReplacePrimaryFunctionsVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\RewriteVisitor.cs" />
<Compile Include="$(MSBuildThisFileDirectory)Visitors\RewriteVisitorBase.cs" />
Expand Down
144 changes: 144 additions & 0 deletions MarkMpn.Sql4Cds.Engine/Visitors/CteValidatorVisitor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.SqlServer.TransactSql.ScriptDom;

namespace MarkMpn.Sql4Cds.Engine.Visitors
{
/// <summary>
/// Checks the properties of a CTE to ensure it is valid
/// </summary>
/// <remarks>
/// https://learn.microsoft.com/en-us/sql/t-sql/queries/with-common-table-expression-transact-sql?view=sql-server-ver16
/// </remarks>
class CteValidatorVisitor : TSqlFragmentVisitor
{
private int _cteReferenceCount;
private FunctionCall _scalarAggregate;
private ScalarSubquery _subquery;
private QualifiedJoin _outerJoin;

public string Name { get; private set; }

public bool IsRecursive { get; private set; }

public override void Visit(CommonTableExpression node)
{
Name = node.ExpressionName.Value;

base.Visit(node);
}

public override void Visit(FromClause node)
{
_cteReferenceCount = 0;

base.Visit(node);

// The FROM clause of a recursive member must refer only one time to the CTE expression_name.
if (_cteReferenceCount > 1)
throw new NotSupportedQueryFragmentException("Recursive CTEs can only be referenced once", node);
}

public override void Visit(NamedTableReference node)
{
if (node.SchemaObject.Identifiers.Count == 1 &&
node.SchemaObject.BaseIdentifier.Value.Equals(Name, StringComparison.OrdinalIgnoreCase))
{
IsRecursive = true;
_cteReferenceCount++;

// The following items aren't allowed in the CTE_query_definition of a recursive member:
// A hint applied to a recursive reference to a CTE inside a CTE_query_definition.
if (node.TableHints.Count > 0)
throw new NotSupportedQueryFragmentException("Table hints are not supported in CTEs", node.TableHints[0]);
}

base.Visit(node);
}

public override void Visit(QualifiedJoin node)
{
base.Visit(node);

if (node.QualifiedJoinType != QualifiedJoinType.Inner)
_outerJoin = node;
}

public override void ExplicitVisit(BinaryQueryExpression node)
{
base.ExplicitVisit(node);

// UNION ALL is the only set operator allowed between the last anchor member and first recursive member, and when combining multiple recursive members.
if (IsRecursive && (node.BinaryQueryExpressionType != BinaryQueryExpressionType.Union || !node.All))
throw new NotSupportedQueryFragmentException("Recursive CTEs must have a UNION ALL between the anchor and recursive parts", node);
}

public override void Visit(QuerySpecification node)
{
_scalarAggregate = null;
_subquery = null;
_outerJoin = null;

base.Visit(node);

// The following clauses can't be used in the CTE_query_definition:
// ORDER BY (except when a TOP clause is specified)
if (node.OrderByClause != null && node.TopRowFilter == null)
throw new NotSupportedQueryFragmentException("ORDER BY is not supported in CTEs", node.OrderByClause);

// FOR BROWSE
if (node.ForClause is BrowseForClause)
throw new NotSupportedQueryFragmentException("FOR BROWSE is not supported in CTEs", node.ForClause);

if (IsRecursive)
{
// The following items aren't allowed in the CTE_query_definition of a recursive member:
// SELECT DISTINCT
if (node.UniqueRowFilter == UniqueRowFilter.Distinct)
throw new NotSupportedQueryFragmentException("DISTINCT is not supported in CTEs", node);

// GROUP BY
if (node.GroupByClause != null)
throw new NotSupportedQueryFragmentException("GROUP BY is not supported in CTEs", node.GroupByClause);

// TODO: PIVOT

// HAVING
if (node.HavingClause != null)
throw new NotSupportedQueryFragmentException("HAVING is not supported in CTEs", node.HavingClause);

// Scalar aggregation
if (_scalarAggregate != null)
throw new NotSupportedQueryFragmentException("Scalar aggregation is not supported in CTEs", _scalarAggregate);

// TOP
if (node.TopRowFilter != null)
throw new NotSupportedQueryFragmentException("TOP is not supported in CTEs", node.TopRowFilter);

// LEFT, RIGHT, OUTER JOIN (INNER JOIN is allowed)
if (_outerJoin != null)
throw new NotSupportedQueryFragmentException("Outer joins are not supported in CTEs", _outerJoin);

// Subqueries
if (_subquery != null)
throw new NotSupportedQueryFragmentException("Subqueries are not supported in CTEs", _subquery);
}
}

public override void Visit(SelectStatement node)
{
base.Visit(node);

// The following clauses can't be used in the CTE_query_definition:
// INTO
if (node.Into != null)
throw new NotSupportedQueryFragmentException("INTO is not supported in CTEs", node.Into);

// OPTION clause with query hints
if (node.OptimizerHints.Count > 0)
throw new NotSupportedQueryFragmentException("Optimizer hints are not supported in CTEs", node.OptimizerHints[0]);

}
}
}
109 changes: 109 additions & 0 deletions MarkMpn.Sql4Cds.Engine/Visitors/ReplaceCtesWithSubqueriesVisitor.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.SqlServer.TransactSql.ScriptDom;

namespace MarkMpn.Sql4Cds.Engine.Visitors
{
/// <summary>
/// Finds references to non-recursive CTEs and replaces them with subqueries
/// </summary>
class ReplaceCtesWithSubqueriesVisitor : TSqlFragmentVisitor
{
private Dictionary<string, CommonTableExpression> _cteQueries;

public ReplaceCtesWithSubqueriesVisitor()
{
_cteQueries = new Dictionary<string, CommonTableExpression>(StringComparer.OrdinalIgnoreCase);
}

public override void Visit(CommonTableExpression node)
{
base.Visit(node);

_cteQueries[node.ExpressionName.Value] = node;
}

public override void Visit(SelectStatement node)
{
// Visit the CTEs first
if (node.WithCtesAndXmlNamespaces != null)
node.WithCtesAndXmlNamespaces.Accept(this);

base.Visit(node);

// Should have visited the CTEs now, so remove them from the query
node.WithCtesAndXmlNamespaces = null;
}

public override void Visit(FromClause node)
{
base.Visit(node);

for (var i = 0; i < node.TableReferences.Count; i++)
{
if (node.TableReferences[i] is NamedTableReference ntr &&
TryGetCteDefinition(ntr, out var subquery))
{
node.TableReferences[i] = subquery;
}
}
}

public override void Visit(QualifiedJoin node)
{
base.Visit(node);

if (node.FirstTableReference is NamedTableReference table1 &&
TryGetCteDefinition(table1, out var subquery1))
{
node.FirstTableReference = subquery1;
}

if (node.SecondTableReference is NamedTableReference table2 &&
TryGetCteDefinition(table2, out var subquery2))
{
node.SecondTableReference = subquery2;
}
}

public override void Visit(UnqualifiedJoin node)
{
base.Visit(node);

if (node.FirstTableReference is NamedTableReference table1 &&
TryGetCteDefinition(table1, out var subquery1))
{
node.FirstTableReference = subquery1;
}

if (node.SecondTableReference is NamedTableReference table2 &&
TryGetCteDefinition(table2, out var subquery2))
{
node.SecondTableReference = subquery2;
}
}

private bool TryGetCteDefinition(NamedTableReference table, out QueryDerivedTable subquery)
{
subquery = null;

if (table.SchemaObject.Identifiers.Count > 1)
return false;

if (!_cteQueries.TryGetValue(table.SchemaObject.BaseIdentifier.Value, out var cte))
return false;

subquery = new QueryDerivedTable
{
Alias = table.Alias ?? cte.ExpressionName,
QueryExpression = cte.QueryExpression
};

foreach (var col in cte.Columns)
subquery.Columns.Add(col);

return true;
}
}
}
Loading

0 comments on commit c07a8cc

Please sign in to comment.