diff --git a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs index a9152527..6b5dfc30 100644 --- a/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs +++ b/MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs @@ -1507,6 +1507,60 @@ SELECT name FROM account "); } + [TestMethod] + public void UnionMultiple() + { + var planBuilder = new ExecutionPlanBuilder(_localDataSource.Values, this); + + var query = @" + SELECT name FROM account + UNION + SELECT fullname FROM contact + UNION + SELECT domainname FROM systemuser"; + + var plans = planBuilder.Build(query, null, out _); + + Assert.AreEqual(1, plans.Length); + + var select = AssertNode(plans[0]); + Assert.AreEqual(1, select.ColumnSet.Count); + Assert.AreEqual("name", select.ColumnSet[0].OutputColumn); + Assert.AreEqual("Expr2", select.ColumnSet[0].SourceColumn); + var distinct = AssertNode(select.Source); + Assert.AreEqual("Expr2", distinct.Columns.Single()); + var concat = AssertNode(distinct.Source); + Assert.AreEqual(3, concat.Sources.Count); + Assert.AreEqual("Expr2", concat.ColumnSet[0].OutputColumn); + Assert.AreEqual("account.name", concat.ColumnSet[0].SourceColumns[0]); + Assert.AreEqual("contact.fullname", concat.ColumnSet[0].SourceColumns[1]); + Assert.AreEqual("systemuser.domainname", concat.ColumnSet[0].SourceColumns[2]); + var accountFetch = AssertNode(concat.Sources[0]); + AssertFetchXml(accountFetch, @" + + + + + + "); + var contactFetch = AssertNode(concat.Sources[1]); + AssertFetchXml(contactFetch, @" + + + + + + "); + var systemuserFetch = AssertNode(concat.Sources[2]); + AssertFetchXml(systemuserFetch, @" + + + + + + "); + } + [TestMethod] public void UnionSort() { diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs index 07689c40..058aeca5 100644 --- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs +++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs @@ -62,6 +62,10 @@ public override IEnumerable GetSources() public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext context, IList hints) { + // We can have a sequence of Distinct - Concatenate - Distinct - Concatenate when we have multiple UNION statements + // We can collapse this to a single Distinct - Concatenate with all the sources from the various Concatenate nodes + CombineConcatenateSources(); + Source = Source.FoldQuery(context, hints); Source.Parent = this; @@ -178,6 +182,60 @@ public override IDataExecutionPlanNodeInternal FoldQuery(NodeCompilationContext return aggregate; } + private void CombineConcatenateSources() + { + if (!(Source is ConcatenateNode concat)) + return; + + for (var i = 0; i < concat.Sources.Count; i++) + { + bool folded; + + do + { + folded = false; + + if (concat.Sources[i] is DistinctNode distinct) + { + concat.Sources[i] = distinct.Source; + folded = true; + } + + if (concat.Sources[i] is ConcatenateNode subConcat) + { + for (var j = 0; j < subConcat.Sources.Count; j++) + { + concat.Sources.Insert(i + j + 1, subConcat.Sources[j]); + + foreach (var col in concat.ColumnSet) + { + foreach (var subCol in subConcat.ColumnSet) + { + if (col.SourceColumns[i] == subCol.OutputColumn) + { + col.SourceColumns.Insert(i + j + 1, subCol.SourceColumns[j]); + col.SourceExpressions.Insert(i + j + 1, subCol.SourceExpressions[j]); + break; + } + } + } + } + + concat.Sources.RemoveAt(i); + + foreach (var col in concat.ColumnSet) + { + col.SourceColumns.RemoveAt(i); + col.SourceExpressions.RemoveAt(i); + } + + i--; + folded = false; + } + } while (folded); + } + } + public override void AddRequiredColumns(NodeCompilationContext context, IList requiredColumns) { foreach (var col in Columns)