Skip to content

Commit

Permalink
Subquery in join condition progress
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkMpn committed May 17, 2024
1 parent 68cb0d6 commit 2fcfa61
Show file tree
Hide file tree
Showing 8 changed files with 270 additions and 23 deletions.
76 changes: 70 additions & 6 deletions MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6780,10 +6780,13 @@ AND [union. all].logicalname IN ('createdon')

var sort = AssertNode<SortNode>(select.Source);

var join1 = AssertNode<HashJoinNode>(sort.Source);
var filter1 = AssertNode<FilterNode>(sort.Source);
Assert.AreEqual("[union. all].logicalname = a2.logicalname", filter1.Filter.ToSql());

var join1 = AssertNode<HashJoinNode>(filter1.Source);
Assert.AreEqual("a2.entitylogicalname", join1.LeftAttribute.ToSql());
Assert.AreEqual("[union. all].eln", join1.RightAttribute.ToSql());
Assert.AreEqual("[union. all].logicalname = a2.logicalname", join1.AdditionalJoinCriteria.ToSql());
Assert.IsNull(join1.AdditionalJoinCriteria);

var mq1 = AssertNode<MetadataQueryNode>(join1.LeftSource);
Assert.AreEqual("french", mq1.DataSource);
Expand Down Expand Up @@ -7498,10 +7501,38 @@ from account a
Assert.AreEqual("name", select.ColumnSet[0].OutputColumn);
}

[TestMethod]
public void ScalarSubquery()
{
var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this);

var query = @"
select top 10 * from (
select fullname, (select name from account where accountid = parentcustomerid) from contact
) a";

var plans = planBuilder.Build(query, null, out _);

Assert.AreEqual(1, plans.Length);

var select = AssertNode<SelectNode>(plans[0]);
var fetch = AssertNode<FetchXmlScan>(select.Source);

AssertFetchXml(fetch, @"
<fetch top='10'>
<entity name='contact'>
<attribute name='fullname' />
<link-entity name='account' to='parentcustomerid' from='accountid' alias='Expr2' link-type='outer'>
<attribute name='name' alias='Expr3' />
</link-entity>
</entity>
</fetch>");
}

[TestMethod]
public void SubqueryInJoinCriteriaRHS()
{
using (_localDataSource.EnableJoinOperator(JoinOperator.Exists))
using (_localDataSource.EnableJoinOperator(JoinOperator.In))
{
var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this);

Expand All @@ -7520,11 +7551,44 @@ from account

AssertFetchXml(fetch, @"
<fetch>
<entity name='account'>
<entity name='contact'>
<all-attributes />
<link-entity name='new_customentity' to='firstname' from='new_name' link-type='in' />
<link-entity name='account' to='parentcustomerid' from='accountid' alias='account' link-type='inner'>
<all-attributes />
</link-entity>
</entity>
</fetch>");
}
}

[TestMethod]
public void SubqueryInJoinCriteriaRHSCorrelatedExists()
{
using (_localDataSource.EnableJoinOperator(JoinOperator.In))
{
var planBuilder = new ExecutionPlanBuilder(_localDataSources.Values, this);

var query = @"
select
*
from account
inner join contact ON account.accountid = contact.parentcustomerid AND EXISTS(SELECT * FROM new_customentity WHERE new_name = contact.firstname)";

var plans = planBuilder.Build(query, null, out _);

Assert.AreEqual(1, plans.Length);

var select = AssertNode<SelectNode>(plans[0]);
var fetch = AssertNode<FetchXmlScan>(select.Source);

AssertFetchXml(fetch, @"
<fetch>
<entity name='contact'>
<all-attributes />
<link-entity name='contact' to='accountid' from='parentcustomerid' alias='contact' link-type='inner'>
<link-entity name='new_customentity' to='firstname' from='new_name' link-type='in' />
<link-entity name='account' to='parentcustomerid' from='accountid' alias='account' link-type='inner'>
<all-attributes />
<link-entity name='new_customentity' to='firstname' from='new_name' link-type='exists' >
</link-entity>
</entity>
</fetch>");
Expand Down
2 changes: 1 addition & 1 deletion MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseJoinNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ protected virtual INodeSchema GetSchema(NodeCompilationContext context, bool inc
foreach (var definedValue in DefinedValues)
{
innerSchema.ContainsColumn(definedValue.Value, out var innerColumn);
schema[definedValue.Key] = innerSchema.Schema[innerColumn];
schema[definedValue.Key] = innerSchema.Schema[innerColumn].Invisible().Calculated();
}

_lastLeftSchema = outerSchema;
Expand Down
78 changes: 78 additions & 0 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/FilterNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ private void AddNotNullColumn(NodeSchema schema, ScalarExpression expr)

public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList<OptimizerHint> hints)
{
// Swap filter to come after sort
if (Source is SortNode sort)
{
Source = sort.Source;
sort.Source = this;
return sort.FoldQuery(context, hints);
}

Filter = FoldNotIsNullToIsNotNull(Filter);

// If we have a filter which implies a non-null value for a column that is generated by an outer join,
Expand All @@ -156,6 +164,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
foldedFilters |= FoldTableSpoolToIndexSpool(context, hints);
foldedFilters |= ExpandFiltersOnColumnComparisons(context);
foldedFilters |= FoldFiltersToDataSources(context, hints, subqueryConditions);
foldedFilters |= FoldFiltersToInnerJoinSources(context, hints);

foreach (var addedLink in addedLinks)
{
Expand All @@ -181,6 +190,75 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
return this;
}

private bool FoldFiltersToInnerJoinSources(NodeCompilationContext context, IList<OptimizerHint> hints)
{
if (Filter == null)
return false;

if (!(Source is BaseJoinNode join) || join.JoinType != QualifiedJoinType.Inner)
return false;

var folded = false;
var leftSchema = join.LeftSource.GetSchema(context);
Filter = ExtractChildFilters(Filter, leftSchema, col => leftSchema.ContainsColumn(col, out _), out var leftFilter);

if (leftFilter != null)
{
join.LeftSource = new FilterNode
{
Source = join.LeftSource,
Filter = leftFilter
}.FoldQuery(context, hints);
join.LeftSource.Parent = join;

folded = true;
}

if (Filter == null)
return true;

var rightContext = context;

if (join is NestedLoopNode loop)
{
var innerParameterTypes = context.ParameterTypes
.Concat(loop.OuterReferences.Select(or => new KeyValuePair<string, DataTypeReference>(or.Value, leftSchema.Schema[or.Key].Type)))
.ToDictionary(p => p.Key, p => p.Value, StringComparer.OrdinalIgnoreCase);

rightContext = new NodeCompilationContext(context, innerParameterTypes);
}

var rightSchema = join.RightSource.GetSchema(rightContext);
Filter = ExtractChildFilters(Filter, rightSchema, col => rightSchema.ContainsColumn(col, out _) || join.DefinedValues.ContainsKey(col), out var rightFilter);

if (rightFilter != null)
{
if (join.DefinedValues.Count > 0)
{
var rewrite = new RewriteVisitor(join.DefinedValues.ToDictionary(kvp => (ScalarExpression)kvp.Key.ToColumnReference(), kvp => (ScalarExpression)kvp.Value.ToColumnReference()));
rightFilter.Accept(rewrite);
}

join.RightSource = new FilterNode
{
Source = join.RightSource,
Filter = rightFilter
}.FoldQuery(rightContext, hints);
join.RightSource.Parent = join;

folded = true;
}

if (folded)
{
// Re-fold the join
Source = Source.FoldQuery(context, hints);
Source.Parent = this;
}

return folded;
}

private bool CheckStartupExpression()
{
// We only need to apply the filter expression to individual rows if it references any fields
Expand Down
82 changes: 81 additions & 1 deletion MarkMpn.Sql4Cds.Engine/ExecutionPlan/FoldableJoinNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ abstract class FoldableJoinNode : BaseJoinNode

public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList<OptimizerHint> hints)
{
// For inner joins, additional join criteria are eqivalent to doing the join without them and then applying the filter
// We've already got logic in the Filter node for efficiently folding those queries, so split them out and let it do
// what it can
if (JoinType == QualifiedJoinType.Inner && AdditionalJoinCriteria != null)
{
var filter = new FilterNode
{
Source = this,
Filter = AdditionalJoinCriteria
};
AdditionalJoinCriteria = null;
return filter.FoldQuery(context, hints);
}

LeftSource = LeftSource.FoldQuery(context, hints);
LeftSource.Parent = this;
RightSource = RightSource.FoldQuery(context, hints);
Expand Down Expand Up @@ -428,7 +442,73 @@ private bool FoldFetchXmlJoin(NodeCompilationContext context, IList<OptimizerHin
// We might have previously folded a sort to the FetchXML that is no longer valid as we require custom paging
if (leftFetch.RequiresCustomPaging(context.DataSources))
leftFetch.RemoveSorts();


// We might have previously folded a not-null condition that is no longer required as it is implicit in the join
var unnecessaryNotNullColumns = new List<string>();
if (JoinType == QualifiedJoinType.Inner || JoinType == QualifiedJoinType.RightOuter)
unnecessaryNotNullColumns.Add(LeftAttribute.GetColumnName());
if (JoinType == QualifiedJoinType.Inner || JoinType == QualifiedJoinType.LeftOuter)
unnecessaryNotNullColumns.Add(RightAttribute.GetColumnName());

if (unnecessaryNotNullColumns != null)
{
var finalSchema = leftFetch.GetSchema(context);

foreach (var col in unnecessaryNotNullColumns)
{
if (!finalSchema.ContainsColumn(col, out var normalizedCol))
continue;

var parts = normalizedCol.SplitMultiPartIdentifier();

if (parts[0] == leftFetch.Alias)
{
foreach (var entityFilter in leftFetch.Entity.Items.OfType<filter>())
{
if (entityFilter.type != filterType.and)
continue;

foreach (var condition in entityFilter.Items.OfType<condition>().ToList())
{
if (condition.entityname == null && condition.attribute == parts[1] && condition.@operator == @operator.notnull)
entityFilter.Items = entityFilter.Items.Except(new[] { condition }).ToArray();
}
}
}
else
{
foreach (var entityFilter in leftFetch.Entity.Items.OfType<filter>())
{
if (entityFilter.type != filterType.and)
continue;

foreach (var condition in entityFilter.Items.OfType<condition>().ToList())
{
if (condition.entityname == parts[0] && condition.attribute == parts[1] && condition.@operator == @operator.notnull)
entityFilter.Items = entityFilter.Items.Except(new[] { condition }).ToArray();
}
}

var link = leftFetch.Entity.FindLinkEntity(parts[0]);

if (link?.Items != null)
{
foreach (var linkFilter in link.Items.OfType<filter>())
{
if (linkFilter.type != filterType.and)
continue;

foreach (var condition in linkFilter.Items.OfType<condition>().ToList())
{
if (condition.attribute == parts[1] && condition.@operator == @operator.notnull)
linkFilter.Items = linkFilter.Items.Except(new[] { condition }).ToArray();
}
}
}
}
}
}

return true;
}

Expand Down
38 changes: 26 additions & 12 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/MergeJoinNode.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
/// </summary>
class MergeJoinNode : FoldableJoinNode
{
private SortNode _leftSort;
private SortNode _rightSort;

[Description("Many to Many")]
[Category("Merge Join")]
public bool ManyToMany { get; private set; }
Expand Down Expand Up @@ -233,6 +236,19 @@ private bool Done(bool hasLeft, bool hasRight)

public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList<OptimizerHint> hints)
{
// If we've previously added sort nodes to either input, remove them before trying to fold the query again
if (LeftSource == _leftSort)
{
LeftSource = _leftSort.Source;
_leftSort = null;
}

if (RightSource == _rightSort)
{
RightSource = _rightSort.Source;
_rightSort = null;
}

var folded = base.FoldQuery(context, hints);

if (folded != this)
Expand Down Expand Up @@ -289,7 +305,7 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
return this;

// Can't fold the join down into the FetchXML, so add a sort and try to fold that in instead
LeftSource = new SortNode
_leftSort = new SortNode
{
Source = LeftSource,
Sorts =
Expand All @@ -300,10 +316,11 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
SortOrder = SortOrder.Ascending
}
}
}.FoldQuery(context, hints);
};
LeftSource = _leftSort.FoldQuery(context, hints);
LeftSource.Parent = this;

RightSource = new SortNode
_rightSort = new SortNode
{
Source = RightSource,
Sorts =
Expand All @@ -314,24 +331,21 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext
SortOrder = SortOrder.Ascending
}
}
}.FoldQuery(context, hints);
};
RightSource = _rightSort.FoldQuery(context, hints);
RightSource.Parent = this;

// If we couldn't fold the sorts, it's probably faster to use a hash join instead if we only want partial results
var leftSort = LeftSource as SortNode;
var rightSort = RightSource as SortNode;

if (leftSort == null && rightSort == null)
if (LeftSource != _leftSort && RightSource != _rightSort)
return this;

hashJoin.LeftSource = leftSort?.Source ?? LeftSource;
hashJoin.RightSource = rightSort?.Source ?? RightSource;

hashJoin.LeftSource = (LeftSource == _leftSort) ? _leftSort.Source : LeftSource;
hashJoin.RightSource = (RightSource == _rightSort) ? _rightSort.Source : RightSource;

var foldedHashJoin = hashJoin.FoldQuery(context, hints);

if (Parent is TopNode ||
leftSort != null && rightSort != null)
LeftSource == _leftSort && RightSource == _rightSort)
return foldedHashJoin;

LeftSource.Parent = this;
Expand Down
Loading

0 comments on commit 2fcfa61

Please sign in to comment.