diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs
index fb8cdf9e..788fc4a9 100644
--- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs
+++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs
@@ -84,11 +84,27 @@ public virtual void NextRecord(Entity entity)
Update(value);
}
+ ///
+ /// Updates the aggregate function state based on the aggregate values for a partition
+ ///
+ /// The that contains aggregated values from a partition of the available records
+ public virtual void NextPartition(Entity entity)
+ {
+ var value = _selector(entity);
+ UpdatePartition(value);
+ }
+
///
/// Updates the aggregation state based on a value extracted from the source
///
///
protected abstract void Update(object value);
+
+ ///
+ /// Updates the aggregation state based on a value extracted from the partition
+ ///
+ ///
+ protected abstract void UpdatePartition(object value);
///
/// Returns the current value of this aggregation
@@ -160,6 +176,11 @@ protected override void Update(object value)
Value = _valueSelector(_sum / _count);
}
+ protected override void UpdatePartition(object value)
+ {
+ throw new InvalidOperationException();
+ }
+
public override Type Type { get; }
public override void Reset()
@@ -188,6 +209,11 @@ protected override void Update(object value)
Value = (SqlInt32)Value + 1;
}
+ protected override void UpdatePartition(object value)
+ {
+ Value = (SqlInt32)Value + (SqlInt32)value;
+ }
+
public override Type Type => typeof(SqlInt32);
public override void Reset()
@@ -217,6 +243,11 @@ protected override void Update(object value)
Value = (SqlInt32)Value + 1;
}
+ protected override void UpdatePartition(object value)
+ {
+ Value = (SqlInt32)Value + (SqlInt32)value;
+ }
+
public override Type Type => typeof(int);
public override void Reset()
@@ -251,6 +282,11 @@ protected override void Update(object value)
Value = value;
}
+ protected override void UpdatePartition(object value)
+ {
+ Update(value);
+ }
+
public override Type Type { get; }
}
@@ -280,6 +316,11 @@ protected override void Update(object value)
Value = value;
}
+ protected override void UpdatePartition(object value)
+ {
+ Update(value);
+ }
+
public override Type Type { get; }
}
@@ -317,6 +358,11 @@ protected override void Update(object value)
Value = _valueSelector(_sumDecimal);
}
+ protected override void UpdatePartition(object value)
+ {
+ Update(value);
+ }
+
public override Type Type { get; }
public override void Reset()
@@ -348,6 +394,11 @@ protected override void Update(object value)
Value = value;
}
+ protected override void UpdatePartition(object value)
+ {
+ throw new InvalidOperationException();
+ }
+
public override Type Type { get; }
public override void Reset()
@@ -387,6 +438,11 @@ public override void NextRecord(Entity entity)
}
}
+ protected override void UpdatePartition(object value)
+ {
+ throw new InvalidOperationException();
+ }
+
protected override void Update(object value)
{
throw new NotImplementedException();
diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs
index d3558282..c232811a 100644
--- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs
+++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs
@@ -290,7 +290,11 @@ Source is FetchXmlScan fetch &&
if (Source is FetchXmlScan || Source is ComputeScalarNode computeScalar && computeScalar.Source is FetchXmlScan)
{
// Check if all the aggregates & groupings can be done in FetchXML. Can only convert them if they can ALL
- // be handled - if any one needs to be calculated manually, we need to calculate them all
+ // be handled - if any one needs to be calculated manually, we need to calculate them all. Also track if
+ // we can partition the query for larger source data sets. We can't partition DISTINCT aggregates, and need
+ // to transform AVG(field) to SUM(field) / COUNT(field)
+ var canPartition = true;
+
foreach (var agg in Aggregates)
{
if (agg.Value.SqlExpression != null && !(agg.Value.SqlExpression is ColumnReferenceExpression))
@@ -301,6 +305,9 @@ Source is FetchXmlScan fetch &&
if (agg.Value.AggregateType == AggregateType.First)
return this;
+
+ if (agg.Value.Distinct)
+ canPartition = false;
}
var fetchXml = Source as FetchXmlScan;
@@ -518,10 +525,47 @@ Source is FetchXmlScan fetch &&
// FoldQuery can be called again in some circumstances. Don't repeat the folding operation and create another try/catch
_folded = true;
+ IDataExecutionPlanNode fallback = this;
+
+ if (canPartition)
+ {
+ // Create a clone of the aggregate FetchXML query
+ var partitionedFetchXml = new FetchXmlScan
+ {
+ DataSource = fetchXml.DataSource,
+ Alias = fetchXml.Alias,
+ AllPages = fetchXml.AllPages,
+ FetchXml = (FetchXml.FetchType)serializer.Deserialize(new StringReader(fetchXml.FetchXmlString)),
+ ReturnFullSchema = fetchXml.ReturnFullSchema
+ };
+
+ var partitionedAggregates = new PartitionedFetchXmlAggregateNode
+ {
+ Source = partitionedFetchXml
+ };
+
+ var tryPartitioned = new TryCatchNode
+ {
+ TrySource = partitionedAggregates,
+ CatchSource = fallback,
+ ExceptionFilter = IsAggregateQueryRetryableException
+ };
+ fallback = tryPartitioned;
+
+ partitionedAggregates.GroupBy.AddRange(GroupBy);
+
+ foreach (var aggregate in Aggregates)
+ {
+ // TODO: Clone the aggregate
+ // TODO: Rewrite AVG
+ partitionedAggregates.Aggregates[aggregate.Key] = aggregate.Value;
+ }
+ }
+
return new TryCatchNode
{
TrySource = fetchXml,
- CatchSource = this,
+ CatchSource = fallback,
ExceptionFilter = IsAggregateQueryRetryableException
};
}
diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedFetchXmlAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedFetchXmlAggregateNode.cs
new file mode 100644
index 00000000..c799eb36
--- /dev/null
+++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedFetchXmlAggregateNode.cs
@@ -0,0 +1,483 @@
+using System;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Data.SqlTypes;
+using System.IO;
+using System.Linq;
+using System.Numerics;
+using System.ServiceModel;
+using System.Text;
+using System.Threading.Tasks;
+using System.Xml.Serialization;
+using MarkMpn.Sql4Cds.Engine.FetchXml;
+using Microsoft.SqlServer.TransactSql.ScriptDom;
+using Microsoft.Xrm.Sdk;
+
+namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
+{
+ ///
+ /// Produces aggregate values by repeatedly executing a related FetchXML query over multiple partitions
+ /// and combining the results
+ ///
+ class PartitionedFetchXmlAggregateNode : BaseDataNode, ISingleSourceExecutionPlanNode
+ {
+ class GroupingKey
+ {
+ private readonly int _hashCode;
+
+ public GroupingKey(Entity entity, List columns)
+ {
+ Values = columns.Select(col => entity[col]).ToList();
+ _hashCode = 0;
+
+ foreach (var value in Values)
+ {
+ if (value == null)
+ continue;
+
+ _hashCode ^= StringComparer.OrdinalIgnoreCase.GetHashCode(value);
+ }
+ }
+
+ public List