Use StreamAggregate for Distinct if input is already sorted
MarkMpn committed Dec 31, 2021
1 parent be81e39 commit a52cb99
Showing 10 changed files with 207 additions and 126 deletions.
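
For context: a hash-based distinct (the DistinctNode below) has to remember every key it has ever emitted, whereas if the input already arrives sorted on the distinct columns, equal rows are adjacent and each row only needs comparing with its predecessor. A minimal standalone sketch of that streaming idea, independent of the engine's types (the names here are illustrative, not the engine's API):

    using System.Collections.Generic;

    static class StreamingDistinct
    {
        // Yields each distinct value once, assuming equal values are adjacent
        // (i.e. the input is already sorted on the distinct key). Only the
        // previous value is held in memory, unlike a hash-based distinct,
        // which must retain every key seen so far.
        public static IEnumerable<T> DistinctSorted<T>(IEnumerable<T> sorted, IComparer<T> comparer)
        {
            var first = true;
            var previous = default(T);

            foreach (var value in sorted)
            {
                if (first || comparer.Compare(previous, value) != 0)
                    yield return value;

                previous = value;
                first = false;
            }
        }
    }

The commit teaches DistinctNode.FoldQuery to detect the sorted case and replace itself with a StreamAggregateNode, reusing the aggregate machinery's handling of adjacent equal rows.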
56 changes: 56 additions & 0 deletions MarkMpn.Sql4Cds.Engine.Tests/ExecutionPlanTests.cs
@@ -3250,5 +3250,61 @@ public void EntityReferenceInQuery()
</entity>
</fetch>");
}

[TestMethod]
public void OrderBySelectExpression()
{
var metadata = new AttributeMetadataCache(_service);
var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), this);

var query = "SELECT name + 'foo' FROM account ORDER BY 1";

var plans = planBuilder.Build(query);

Assert.AreEqual(1, plans.Length);

var select = AssertNode<SelectNode>(plans[0]);
Assert.AreEqual("Expr1", select.ColumnSet.Single().SourceColumn);
var sort = AssertNode<SortNode>(select.Source);
Assert.AreEqual("Expr1", sort.Sorts.Single().ToSql());
var compute = AssertNode<ComputeScalarNode>(sort.Source);
Assert.AreEqual("name + 'foo'", compute.Columns["Expr1"].ToSql());
var fetch = AssertNode<FetchXmlScan>(compute.Source);
AssertFetchXml(fetch, @"
<fetch>
<entity name='account'>
<attribute name='name' />
</entity>
</fetch>");
}

[TestMethod]
public void DistinctOrderByUsesScalarAggregate()
{
var metadata = new AttributeMetadataCache(_service);
var planBuilder = new ExecutionPlanBuilder(metadata, new StubTableSizeCache(), this);

var query = "SELECT DISTINCT name + 'foo' FROM account ORDER BY 1";

var plans = planBuilder.Build(query);

Assert.AreEqual(1, plans.Length);

var select = AssertNode<SelectNode>(plans[0]);
Assert.AreEqual("Expr1", select.ColumnSet.Single().SourceColumn);
var aggregate = AssertNode<StreamAggregateNode>(select.Source);
Assert.AreEqual("Expr1", aggregate.GroupBy.Single().ToSql());
var sort = AssertNode<SortNode>(aggregate.Source);
Assert.AreEqual("Expr1", sort.Sorts.Single().ToSql());
var compute = AssertNode<ComputeScalarNode>(sort.Source);
Assert.AreEqual("name + 'foo'", compute.Columns["Expr1"].ToSql());
var fetch = AssertNode<FetchXmlScan>(compute.Source);
AssertFetchXml(fetch, @"
<fetch>
<entity name='account'>
<attribute name='name' />
</entity>
</fetch>");
}
}
}
48 changes: 0 additions & 48 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/BaseAggregateNode.cs
@@ -21,54 +21,6 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
/// </summary>
abstract class BaseAggregateNode : BaseDataNode, ISingleSourceExecutionPlanNode
{
protected class GroupingKey
{
private readonly Lazy<int> _hashCode;

public GroupingKey(Entity entity, List<string> columns)
{
Values = columns.Select(col => entity[col]).ToList();

_hashCode = new Lazy<int>(() =>
{
var hashCode = 0;

foreach (var value in Values)
{
if (value == null)
continue;

hashCode ^= value.GetHashCode();
}

return hashCode;
});
}

public List<object> Values { get; }

public override int GetHashCode() => _hashCode.Value;

public override bool Equals(object obj)
{
var other = (GroupingKey)obj;

for (var i = 0; i < Values.Count; i++)
{
if (Values[i] == null && other.Values[i] == null)
continue;

if (Values[i] == null || other.Values[i] == null)
return false;

if (!StringComparer.OrdinalIgnoreCase.Equals(Values[i], other.Values[i]))
return false;
}

return true;
}
}

protected class AggregateFunctionState
{
public AggregateFunction AggregateFunction { get; set; }
71 changes: 71 additions & 0 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/CompoundKey.cs
@@ -0,0 +1,71 @@
using System;
using System.Collections.Generic;
using System.Data.SqlTypes;
using System.Linq;
using System.Text;
using Microsoft.Xrm.Sdk;

namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
{
/// <summary>
/// Compares key fields for sorting, grouping or distinctness. null values are treated as equal
/// </summary>
class CompoundKey
{
private readonly Lazy<int> _hashCode;

/// <summary>
/// Extracts a compound key from a <see cref="Entity"/>
/// </summary>
/// <param name="entity">The <see cref="Entity"/> to extract the compound key from</param>
/// <param name="columns">The columns that form the compound key</param>
public CompoundKey(Entity entity, List<string> columns)
{
Values = new object[columns.Count];

for (var i = 0; i < columns.Count; i++)
Values[i] = entity[columns[i]];

_hashCode = new Lazy<int>(() =>
{
var hashCode = 0;

foreach (var value in Values)
{
if (value == null)
continue;

hashCode ^= value.GetHashCode();
}

return hashCode;
});
}

public object[] Values { get; }

public override int GetHashCode() => _hashCode.Value;

public override bool Equals(object obj)
{
var other = (CompoundKey)obj;

for (var i = 0; i < Values.Length; i++)
{
var xNullable = (INullable)Values[i];
var yNullable = (INullable)other.Values[i];
if (xNullable.IsNull && yNullable.IsNull)
continue;

if (xNullable.IsNull || yNullable.IsNull)
return false;

var xComparable = (IComparable)Values[i];
if (xComparable.CompareTo(other.Values[i]) != 0)
return false;
}

return true;
}
}
}
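
CompoundKey unifies the DistinctKey and GroupingKey classes deleted elsewhere in this commit. Note that Equals casts each value to INullable and then IComparable, so attribute values are expected to be System.Data.SqlTypes values (SqlString, SqlInt32, and so on) rather than raw CLR objects. A rough usage sketch, written as if compiled inside the engine assembly (the class is internal) and with an illustrative column name:

    using System.Collections.Generic;
    using Microsoft.Xrm.Sdk;

    static class CompoundKeyExample
    {
        // Deduplicates entities on their "name" attribute. Keys with equal
        // values (including all-null keys) hash and compare as equal, so
        // HashSet<T>.Add returns false for duplicates.
        static IEnumerable<Entity> DistinctByName(IEnumerable<Entity> source)
        {
            var seen = new HashSet<CompoundKey>();
            var columns = new List<string> { "name" };

            foreach (var entity in source)
            {
                if (seen.Add(new CompoundKey(entity, columns)))
                    yield return entity;
            }
        }
    }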
80 changes: 32 additions & 48 deletions MarkMpn.Sql4Cds.Engine/ExecutionPlan/DistinctNode.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data.SqlTypes;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
@@ -15,51 +16,6 @@ namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
/// </summary>
class DistinctNode : BaseDataNode, ISingleSourceExecutionPlanNode
{
class DistinctKey
{
private List<object> _values;
private readonly int _hashCode;

public DistinctKey(Entity entity, List<string> columns)
{
_values = columns.Select(col => entity[col]).ToList();

_hashCode = 0;

foreach (var val in _values)
{
if (val == null)
continue;

_hashCode ^= StringComparer.CurrentCultureIgnoreCase.GetHashCode(val);
}
}

public override int GetHashCode()
{
return _hashCode;
}

public override bool Equals(object obj)
{
var other = (DistinctKey)obj;

for (var i = 0; i < _values.Count; i++)
{
if (_values[i] == null && other._values[i] == null)
continue;

if (_values[i] == null || other._values[i] == null)
return false;

if (StringComparer.CurrentCultureIgnoreCase.Compare(_values[i], other._values[i]) != 0)
return false;
}

return true;
}
}

/// <summary>
/// The columns to consider
/// </summary>
@@ -75,11 +31,11 @@ public override bool Equals(object obj)

protected override IEnumerable<Entity> ExecuteInternal(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes, IDictionary<string, object> parameterValues)
{
var distinct = new HashSet<DistinctKey>();
var distinct = new HashSet<CompoundKey>();

foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues))
{
var key = new DistinctKey(entity, Columns);
var key = new CompoundKey(entity, Columns);

if (distinct.Add(key))
yield return entity;
@@ -176,7 +132,35 @@ public override IDataExecutionPlanNode FoldQuery(IDictionary<string, DataSource>
return fetch;
}

return this;
// If the data is already sorted by all the distinct columns we can use a stream aggregate instead.
// We don't mind what order the columns are sorted in though, so long as the distinct columns form a
// prefix of the sort order.
var requiredSorts = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

foreach (var col in Columns)
{
if (!schema.ContainsColumn(col, out var column))
return this;

requiredSorts.Add(column);
}

if (requiredSorts.Count > schema.SortOrder.Count)
return this;

for (var i = 0; i < requiredSorts.Count; i++)
{
if (!requiredSorts.Contains(schema.SortOrder[i]))
return this;
}

var aggregate = new StreamAggregateNode { Source = Source };
Source.Parent = aggregate;

for (var i = 0; i < requiredSorts.Count; i++)
aggregate.GroupBy.Add(schema.SortOrder[i].ToColumnReference());

return aggregate;
}

public override void AddRequiredColumns(IDictionary<string, DataSource> dataSources, IDictionary<string, Type> parameterTypes, IList<string> requiredColumns)
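
The prefix test in FoldQuery above is order-insensitive across the distinct columns themselves: distinct over (b, a) folds when the input is sorted by (a, b, ...), but not when only a later sort key matches. A standalone restatement of the check with worked inputs (column names are illustrative):

    using System;
    using System.Collections.Generic;
    using System.Linq;

    static class SortPrefixExample
    {
        // Mirrors the folding rule: the distinct columns must cover a prefix
        // of the sort order, in any order among themselves.
        static bool IsSortPrefix(IEnumerable<string> distinctColumns, IList<string> sortOrder)
        {
            var required = new HashSet<string>(distinctColumns, StringComparer.OrdinalIgnoreCase);
            return required.Count <= sortOrder.Count
                && Enumerable.Range(0, required.Count).All(i => required.Contains(sortOrder[i]));
        }

        static void Main()
        {
            var sort = new[] { "account.name", "account.accountid" };
            Console.WriteLine(IsSortPrefix(new[] { "account.name" }, sort));                      // True: first sort key
            Console.WriteLine(IsSortPrefix(new[] { "account.accountid" }, sort));                 // False: not a prefix
            Console.WriteLine(IsSortPrefix(new[] { "account.accountid", "account.name" }, sort)); // True: column order is irrelevant
        }
    }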
MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs
@@ -25,7 +25,7 @@ class HashMatchAggregateNode : BaseAggregateNode

protected override IEnumerable<Entity> ExecuteInternal(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes, IDictionary<string, object> parameterValues)
{
var groups = new Dictionary<GroupingKey, Dictionary<string, AggregateFunctionState>>();
var groups = new Dictionary<CompoundKey, Dictionary<string, AggregateFunctionState>>();
var schema = Source.GetSchema(dataSources, parameterTypes);
var groupByCols = GetGroupingColumns(schema);

@@ -34,7 +34,7 @@

foreach (var entity in Source.Execute(dataSources, options, parameterTypes, parameterValues))
{
var key = new GroupingKey(entity, groupByCols);
var key = new CompoundKey(entity, groupByCols);

if (!groups.TryGetValue(key, out var values))
{
MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedAggregateNode.cs
@@ -50,7 +50,7 @@ public override void AddRequiredColumns(IDictionary<string, DataSource> dataSour

protected override IEnumerable<Entity> ExecuteInternal(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes, IDictionary<string, object> parameterValues)
{
var groups = new Dictionary<GroupingKey, Dictionary<string, AggregateFunctionState>>();
var groups = new Dictionary<CompoundKey, Dictionary<string, AggregateFunctionState>>();
var schema = Source.GetSchema(dataSources, parameterTypes);
var groupByCols = GetGroupingColumns(schema);

@@ -181,7 +181,7 @@ private void SplitPartition(Partition partition)
});
}

private void ExecuteAggregate(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes, IDictionary<string, object> parameterValues, Dictionary<string, AggregateFunction> aggregates, Dictionary<GroupingKey, Dictionary<string, AggregateFunctionState>> groups, List<string> groupByCols, FetchXmlScan fetchXmlNode, SqlDateTime minValue, SqlDateTime maxValue)
private void ExecuteAggregate(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes, IDictionary<string, object> parameterValues, Dictionary<string, AggregateFunction> aggregates, Dictionary<CompoundKey, Dictionary<string, AggregateFunctionState>> groups, List<string> groupByCols, FetchXmlScan fetchXmlNode, SqlDateTime minValue, SqlDateTime maxValue)
{
parameterValues["@PartitionStart"] = minValue;
parameterValues["@PartitionEnd"] = maxValue;
@@ -191,7 +191,7 @@ private void ExecuteAggregate(IDictionary<string, DataSource> dataSources, IQuer
foreach (var entity in results)
{
// Update aggregates
var key = new GroupingKey(entity, groupByCols);
var key = new CompoundKey(entity, groupByCols);

if (!groups.TryGetValue(key, out var values))
{