diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs
index fb8cdf9e..788fc4a9 100644
--- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs
+++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/Aggregate.cs
@@ -84,11 +84,27 @@ public virtual void NextRecord(Entity entity)
             Update(value);
         }

+        /// <summary>
+        /// Updates the aggregate function state based on the aggregate values for a partition
+        /// </summary>
+        /// <param name="entity">The <see cref="Entity"/> that contains aggregated values from a partition of the available records</param>
+        public virtual void NextPartition(Entity entity)
+        {
+            var value = _selector(entity);
+            UpdatePartition(value);
+        }
+
         /// <summary>
         /// Updates the aggregation state based on a value extracted from the source
         /// </summary>
         /// <param name="value"></param>
         protected abstract void Update(object value);
+
+        /// <summary>
+        /// Updates the aggregation state based on a value extracted from the partition
+        /// </summary>
+        /// <param name="value"></param>
+        protected abstract void UpdatePartition(object value);

         /// <summary>
         /// Returns the current value of this aggregation
@@ -160,6 +176,11 @@ protected override void Update(object value)
             Value = _valueSelector(_sum / _count);
         }

+        protected override void UpdatePartition(object value)
+        {
+            throw new InvalidOperationException();
+        }
+
         public override Type Type { get; }

         public override void Reset()
@@ -188,6 +209,11 @@ protected override void Update(object value)
             Value = (SqlInt32)Value + 1;
         }

+        protected override void UpdatePartition(object value)
+        {
+            Value = (SqlInt32)Value + (SqlInt32)value;
+        }
+
         public override Type Type => typeof(SqlInt32);

         public override void Reset()
@@ -217,6 +243,11 @@ protected override void Update(object value)
             Value = (SqlInt32)Value + 1;
         }

+        protected override void UpdatePartition(object value)
+        {
+            Value = (SqlInt32)Value + (SqlInt32)value;
+        }
+
         public override Type Type => typeof(int);

         public override void Reset()
@@ -251,6 +282,11 @@ protected override void Update(object value)
             Value = value;
         }

+        protected override void UpdatePartition(object value)
+        {
+            Update(value);
+        }
+
         public override Type Type { get; }
     }

@@ -280,6 +316,11 @@ protected override void Update(object value)
             Value = value;
         }

+        protected override void UpdatePartition(object value)
+        {
+            Update(value);
+        }
+
         public override Type Type { get; }
     }

@@ -317,6 +358,11 @@ protected override void Update(object value)
             Value = _valueSelector(_sumDecimal);
         }

+        protected override void UpdatePartition(object value)
+        {
+            Update(value);
+        }
+
         public override Type Type { get; }

         public override void Reset()
@@ -348,6 +394,11 @@ protected override void Update(object value)
             Value = value;
         }

+        protected override void UpdatePartition(object value)
+        {
+            throw new InvalidOperationException();
+        }
+
         public override Type Type { get; }

         public override void Reset()
@@ -387,6 +438,11 @@ public override void NextRecord(Entity entity)
             }
         }

+        protected override void UpdatePartition(object value)
+        {
+            throw new InvalidOperationException();
+        }
+
         protected override void Update(object value)
         {
             throw new NotImplementedException();
diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs
index d3558282..c232811a 100644
--- a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs
+++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/HashMatchAggregateNode.cs
@@ -290,7 +290,11 @@ Source is FetchXmlScan fetch &&
             if (Source is FetchXmlScan || Source is ComputeScalarNode computeScalar && computeScalar.Source is FetchXmlScan)
             {
                 // Check if all the aggregates & groupings can be done in FetchXML. Can only convert them if they can ALL
-                // be handled - if any one needs to be calculated manually, we need to calculate them all
+                // be handled - if any one needs to be calculated manually, we need to calculate them all. Also track if
+                // we can partition the query for larger source data sets. We can't partition DISTINCT aggregates, and need
+                // to transform AVG(field) to SUM(field) / COUNT(field)
+                var canPartition = true;
+
                 foreach (var agg in Aggregates)
                 {
                     if (agg.Value.SqlExpression != null && !(agg.Value.SqlExpression is ColumnReferenceExpression))
@@ -301,6 +305,9 @@ Source is FetchXmlScan fetch &&

                     if (agg.Value.AggregateType == AggregateType.First)
                         return this;
+
+                    if (agg.Value.Distinct)
+                        canPartition = false;
                 }

                 var fetchXml = Source as FetchXmlScan;
@@ -518,10 +525,47 @@ Source is FetchXmlScan fetch &&
                 // FoldQuery can be called again in some circumstances. Don't repeat the folding operation and create another try/catch
                 _folded = true;

+                IDataExecutionPlanNode fallback = this;
+
+                if (canPartition)
+                {
+                    // Create a clone of the aggregate FetchXML query
+                    var partitionedFetchXml = new FetchXmlScan
+                    {
+                        DataSource = fetchXml.DataSource,
+                        Alias = fetchXml.Alias,
+                        AllPages = fetchXml.AllPages,
+                        FetchXml = (FetchXml.FetchType)serializer.Deserialize(new StringReader(fetchXml.FetchXmlString)),
+                        ReturnFullSchema = fetchXml.ReturnFullSchema
+                    };
+
+                    var partitionedAggregates = new PartitionedFetchXmlAggregateNode
+                    {
+                        Source = partitionedFetchXml
+                    };
+
+                    var tryPartitioned = new TryCatchNode
+                    {
+                        TrySource = partitionedAggregates,
+                        CatchSource = fallback,
+                        ExceptionFilter = IsAggregateQueryRetryableException
+                    };
+                    fallback = tryPartitioned;
+
+                    partitionedAggregates.GroupBy.AddRange(GroupBy);
+
+                    foreach (var aggregate in Aggregates)
+                    {
+                        // TODO: Clone the aggregate
+                        // TODO: Rewrite AVG
+                        partitionedAggregates.Aggregates[aggregate.Key] = aggregate.Value;
+                    }
+                }
+
                 return new TryCatchNode
                 {
                     TrySource = fetchXml,
-                    CatchSource = this,
+                    CatchSource = fallback,
                     ExceptionFilter = IsAggregateQueryRetryableException
                 };
             }
diff --git a/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedFetchXmlAggregateNode.cs b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedFetchXmlAggregateNode.cs
new file mode 100644
index 00000000..c799eb36
--- /dev/null
+++ b/MarkMpn.Sql4Cds.Engine/ExecutionPlan/PartitionedFetchXmlAggregateNode.cs
@@ -0,0 +1,483 @@
+using System;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Data.SqlTypes;
+using System.IO;
+using System.Linq;
+using System.Numerics;
+using System.ServiceModel;
+using System.Text;
+using System.Threading.Tasks;
+using System.Xml.Serialization;
+using MarkMpn.Sql4Cds.Engine.FetchXml;
+using Microsoft.SqlServer.TransactSql.ScriptDom;
+using Microsoft.Xrm.Sdk;
+
+namespace MarkMpn.Sql4Cds.Engine.ExecutionPlan
+{
+    /// <summary>
+    /// Produces aggregate values by repeatedly executing a related FetchXML query over multiple partitions
+    /// and combining the results
+    /// </summary>
+    class PartitionedFetchXmlAggregateNode : BaseDataNode, ISingleSourceExecutionPlanNode
+    {
+        class GroupingKey
+        {
+            private readonly int _hashCode;
+
+            public GroupingKey(Entity entity, List<string> columns)
+            {
+                Values = columns.Select(col => entity[col]).ToList();
+                _hashCode = 0;
+
+                foreach (var value in Values)
+                {
+                    if (value == null)
+                        continue;
+
+                    _hashCode ^= StringComparer.OrdinalIgnoreCase.GetHashCode(value);
+                }
+            }
+
+            public List<object> Values { get; }
+
+            public override int GetHashCode() => _hashCode;
+
+            public override bool Equals(object obj)
+            {
+                var other = (GroupingKey)obj;
+
+                for (var i = 0; i < Values.Count; i++)
+                {
+                    if (Values[i] == null && other.Values[i] == null)
+                        continue;
+
+                    if (Values[i] == null || other.Values[i] == null)
+                        return false;
+
+                    if (!StringComparer.OrdinalIgnoreCase.Equals(Values[i], other.Values[i]))
+                        return false;
+                }
+
+                return true;
+            }
+        }
+
+        /// <summary>
+        /// The list of columns to group the results by
+        /// </summary>
+        [Category("Partitioned FetchXML Aggregate")]
+        [Description("The list of columns to group the results by")]
+        [DisplayName("Group By")]
+        public List<ColumnReferenceExpression> GroupBy { get; } = new List<ColumnReferenceExpression>();
+
+        /// <summary>
+        /// The list of aggregate values to produce
+        /// </summary>
+        [Category("Partitioned FetchXML Aggregate")]
+        [Description("The list of aggregate values to produce")]
+        public Dictionary<string, Aggregate> Aggregates { get; } = new Dictionary<string, Aggregate>();
+
+        public IDataExecutionPlanNode Source { get; set; }
+
+        public override void AddRequiredColumns(IDictionary<string, DataSource> dataSources, IDictionary<string, Type> parameterTypes, IList<string> requiredColumns)
+        {
+            // Columns required by previous nodes must be derived from this node, so no need to pass them through.
+            // Just calculate the columns that are required to calculate the groups & aggregates
+            var scalarRequiredColumns = new List<string>();
+            if (GroupBy != null)
+                scalarRequiredColumns.AddRange(GroupBy.Select(g => g.GetColumnName()));
+
+            scalarRequiredColumns.AddRange(Aggregates.Where(agg => agg.Value.SqlExpression != null).SelectMany(agg => agg.Value.SqlExpression.GetColumns()).Distinct());
+
+            Source.AddRequiredColumns(dataSources, parameterTypes, scalarRequiredColumns);
+        }
+
+        public override int EstimateRowsOut(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes)
+        {
+            if (GroupBy.Count == 0)
+                return 1;
+
+            return Source.EstimateRowsOut(dataSources, options, parameterTypes) * 4 / 10;
+        }
+
+        public override IDataExecutionPlanNode FoldQuery(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes)
+        {
+            Source = Source.FoldQuery(dataSources, options, parameterTypes);
+            Source.Parent = this;
+            return this;
+        }
+
+        public override NodeSchema GetSchema(IDictionary<string, DataSource> dataSources, IDictionary<string, Type> parameterTypes)
+        {
+            var sourceSchema = Source.GetSchema(dataSources, parameterTypes);
+            var schema = new NodeSchema();
+
+            foreach (var group in GroupBy)
+            {
+                var colName = group.GetColumnName();
+                sourceSchema.ContainsColumn(colName, out var normalized);
+                schema.Schema[normalized] = sourceSchema.Schema[normalized];
+
+                foreach (var alias in sourceSchema.Aliases.Where(a => a.Value.Contains(normalized)))
+                {
+                    if (!schema.Aliases.TryGetValue(alias.Key, out var aliases))
+                    {
+                        aliases = new List<string>();
+                        schema.Aliases[alias.Key] = aliases;
+                    }
+
+                    aliases.Add(normalized);
+                }
+
+                if (GroupBy.Count == 1)
+                    schema.PrimaryKey = normalized;
+            }
+
+            foreach (var aggregate in Aggregates)
+            {
+                Type aggregateType;
+
+                switch (aggregate.Value.AggregateType)
+                {
+                    case AggregateType.Count:
+                    case AggregateType.CountStar:
+                        aggregateType = typeof(SqlInt32);
+                        break;
+
+                    default:
+                        aggregateType = aggregate.Value.SqlExpression.GetType(sourceSchema, null, parameterTypes);
+                        break;
+                }
+
+                schema.Schema[aggregate.Key] = aggregateType;
+            }
+
+            return schema;
+        }
+
+        public override IEnumerable<IExecutionPlanNode> GetSources()
+        {
+            yield return Source;
+        }
+
+        protected override IEnumerable<Entity> ExecuteInternal(IDictionary<string, DataSource> dataSources, IQueryExecutionOptions options, IDictionary<string, Type> parameterTypes, IDictionary<string, object> parameterValues)
+        {
+            var groups = new Dictionary<GroupingKey, Dictionary<string, AggregateFunction>>();
+            var schema = Source.GetSchema(dataSources, parameterTypes);
+            var groupByCols = GroupBy
+                .Select(col =>
+                {
+                    var colName = col.GetColumnName();
+                    schema.ContainsColumn(colName, out colName);
+                    return colName;
+                })
+                .ToList();
+
+            foreach (var aggregate in Aggregates.Where(agg => agg.Value.SqlExpression != null))
+            {
+                var sourceExpression = aggregate.Value.SqlExpression;
+
+                // Sum and Average need to have Decimal values as input for their calculations to work correctly
+                if (aggregate.Value.AggregateType == AggregateType.Average || aggregate.Value.AggregateType == AggregateType.Sum)
+                    sourceExpression = new ConvertCall { Parameter = sourceExpression, DataType = new SqlDataTypeReference { SqlDataTypeOption = SqlDataTypeOption.Decimal } };
+
+                aggregate.Value.Expression = sourceExpression.Compile(schema, parameterTypes);
+
+                aggregate.Value.ReturnType = aggregate.Value.SqlExpression.GetType(schema, null, parameterTypes);
+
+                if (aggregate.Value.AggregateType == AggregateType.Average)
+                {
+                    if (aggregate.Value.ReturnType == typeof(SqlByte) || aggregate.Value.ReturnType == typeof(SqlInt16))
+                        aggregate.Value.ReturnType = typeof(SqlInt32);
+                }
+            }
+
+            var fetchXmlNode = (FetchXmlScan)Source;
+
+            // Get the minimum and maximum primary keys from the source
+            var minKey = GetMinMaxKey(fetchXmlNode, dataSources, options, parameterTypes, parameterValues, false);
+            var maxKey = GetMinMaxKey(fetchXmlNode, dataSources, options, parameterTypes, parameterValues, true);
+
+            if (minKey.IsNull || maxKey.IsNull || minKey == maxKey)
+                throw new QueryExecutionException("Cannot partition query");
+
+            // Add the filter to the FetchXML to partition the results
+            var metadata = dataSources[fetchXmlNode.DataSource].Metadata[fetchXmlNode.Entity.name];
+            fetchXmlNode.Entity.AddItem(new filter
+            {
+                Items = new object[]
+                {
+                    new condition { attribute = metadata.PrimaryIdAttribute, @operator = @operator.ge, value = "@PartitionStart" },
+                    new condition { attribute = metadata.PrimaryIdAttribute, @operator = @operator.le, value = "@PartitionEnd" }
+                }
+            });
+
+            var partitionParameterTypes = new Dictionary<string, Type>
+            {
+                ["@PartitionStart"] = typeof(SqlGuid),
+                ["@PartitionEnd"] = typeof(SqlGuid)
+            };
+
+            var partitionParameterValues = new Dictionary<string, object>
+            {
+                ["@PartitionStart"] = minKey,
+                ["@PartitionEnd"] = maxKey
+            };
+
+            if (parameterTypes != null)
+            {
+                foreach (var kvp in parameterTypes)
+                    partitionParameterTypes[kvp.Key] = kvp.Value;
+            }
+
+            if (parameterValues != null)
+            {
+                foreach (var kvp in parameterValues)
+                    partitionParameterValues[kvp.Key] = kvp.Value;
+            }
+
+            var minValue = GuidToNumber(minKey.Value);
+            var maxValue = GuidToNumber(maxKey.Value);
+
+            while (minValue <= maxValue)
+            {
+                // Repeatedly split the primary key space until an aggregate query returns without error.
+                // Quit with an error after 10 iterations
+                const int maxSplits = 10;
+                var split = minValue + (maxValue - minValue) / 2;
+
+                for (var i = 0; i < maxSplits; i++)
+                {
+                    try
+                    {
+                        // Execute the query with the partition minValue -> split
+                        partitionParameterValues["@PartitionStart"] = (SqlGuid)NumberToGuid(minValue);
+                        partitionParameterValues["@PartitionEnd"] = (SqlGuid)NumberToGuid(split);
+                        var results = fetchXmlNode.Execute(dataSources, options, partitionParameterTypes, partitionParameterValues);
+
+                        foreach (var entity in results)
+                        {
+                            // Update aggregates
+                            var key = new GroupingKey(entity, groupByCols);
+
+                            if (!groups.TryGetValue(key, out var values))
+                            {
+                                values = new Dictionary<string, AggregateFunction>();
+
+                                foreach (var aggregate in Aggregates)
+                                {
+                                    Func<Entity, object> selector = null;
+
+                                    if (aggregate.Value.AggregateType != AggregateType.CountStar)
+                                        selector = e => aggregate.Value.Expression(e, parameterValues, options);
+
+                                    switch (aggregate.Value.AggregateType)
+                                    {
+                                        case AggregateType.Average:
+                                            throw new QueryExecutionException("Average aggregate not supported for partitions");
+
+                                        case AggregateType.Count:
+                                            values[aggregate.Key] = new CountColumn(selector);
+                                            break;
+
+                                        case AggregateType.CountStar:
+                                            values[aggregate.Key] = new Count(null);
+                                            break;
+
+                                        case AggregateType.Max:
+                                            values[aggregate.Key] = new Max(selector, aggregate.Value.ReturnType);
+                                            break;
+
+                                        case AggregateType.Min:
+                                            values[aggregate.Key] = new Min(selector, aggregate.Value.ReturnType);
+                                            break;
+
+                                        case AggregateType.Sum:
+                                            values[aggregate.Key] = new Sum(selector, aggregate.Value.ReturnType);
+                                            break;
+
+                                        default:
+                                            throw new QueryExecutionException("Unknown aggregate type");
+                                    }
+
+                                    if (aggregate.Value.Distinct)
+                                        throw new QueryExecutionException("Distinct aggregates not supported for partitions");
+
+                                    values[aggregate.Key].Reset();
+                                }
+
+                                groups[key] = values;
+                            }
+
+                            foreach (var func in values.Values)
+                                func.NextPartition(entity);
+                        }
+
+                        break;
+                    }
+                    catch (Exception ex)
+                    {
+                        if (i == maxSplits - 1)
+                            throw;
+
+                        if (!IsAggregateQueryRecordLimitExceeded(ex))
+                            throw;
+                    }
+
+                    // Partition was still too big - halve it and try again
+                    split = minValue + (split - minValue) / 2;
+                }
+
+                // Update minimum primary key and repeat to process next partition
+                minValue = split + 1;
+            }
+
+            foreach (var group in groups)
+            {
+                var result = new Entity();
+
+                for (var i = 0; i < GroupBy.Count; i++)
+                    result[groupByCols[i]] = group.Key.Values[i];
+
+                foreach (var aggregate in group.Value)
+                    result[aggregate.Key] = aggregate.Value.Value;
+
+                yield return result;
+            }
+        }
+
+        // The order of significance of the bytes within a uniqueidentifier value,
+        // used to map Guids onto an ordered number line for partitioning
+        private static readonly int[] x_rgiGuidOrder = new int[16]
+        {
+            5, 4, 3, 2, 1, 0, 7, 6, 9, 8, 11, 10, 15, 14, 13, 12
+        };
+
+        private BigInteger GuidToNumber(Guid guid)
+        {
+            var bytes = guid.ToByteArray();
+
+            // Shuffle the bytes into order of their significance. BigInteger uses a little-endian representation
+            var shuffled = new byte[16];
+            for (var i = 0; i < bytes.Length; i++)
+                shuffled[i] = bytes[x_rgiGuidOrder[i]];
+
+            var value = new BigInteger(shuffled);
+            return value;
+        }
+
+        private Guid NumberToGuid(BigInteger integer)
+        {
+            var bytes = integer.ToByteArray();
+
+            // Reverse the shuffle applied by GuidToNumber
+            var shuffled = new byte[16];
+            for (var i = 0; i < bytes.Length; i++)
+                shuffled[x_rgiGuidOrder[i]] = bytes[i];
+
+            var value = new Guid(shuffled);
+            return value;
+        }
+
+        private bool IsAggregateQueryRecordLimitExceeded(Exception ex)
+        {
+            if (ex is QueryExecutionException qee)
+                ex = qee.InnerException;
+
+            if (!(ex is FaultException<OrganizationServiceFault> faultEx))
+                return false;
+
+            var fault = faultEx.Detail;
+            while (fault.InnerFault != null)
+                fault = fault.InnerFault;
+
+            /*
+             * 0x8004E023 / -2147164125
+             * Name: AggregateQueryRecordLimitExceeded
+             * Message: The maximum record limit is exceeded. Reduce the number of records.
+ */ + if (fault.ErrorCode == -2147164125) + return true; + + return false; + } + + private SqlGuid GetMinMaxKey(FetchXmlScan fetchXmlNode, IDictionary dataSources, IQueryExecutionOptions options, IDictionary parameterTypes, IDictionary parameterValues, bool max) + { + // Create a new FetchXmlScan node with a copy of the original query + var minMaxNode = new FetchXmlScan + { + Alias = "minmax", + DataSource = fetchXmlNode.DataSource, + FetchXml = CloneFetchXml(fetchXmlNode.FetchXml) + }; + + // Remove the aggregate settings and all attributes from the query + minMaxNode.FetchXml.aggregate = false; + RemoveAttributes(minMaxNode.Entity); + + // Add the primary key attribute of the root entity + var metadata = dataSources[minMaxNode.DataSource].Metadata[minMaxNode.Entity.name]; + minMaxNode.Entity.AddItem(new FetchAttributeType { name = metadata.PrimaryIdAttribute }); + + // Sort by the primary key + minMaxNode.Entity.AddItem(new FetchOrderType { attribute = metadata.PrimaryIdAttribute, descending = max }); + + // Only need to retrieve the first item + minMaxNode.FetchXml.top = "1"; + + var result = minMaxNode.Execute(dataSources, options, parameterTypes, parameterValues).FirstOrDefault(); + + if (result == null) + return SqlGuid.Null; + + return (SqlEntityReference)result[$"minmax.{metadata.PrimaryIdAttribute}"]; + } + + private void RemoveAttributes(FetchXml.FetchEntityType entity) + { + if (entity.Items != null) + { + entity.Items = entity.Items.Where(o => !(o is FetchAttributeType) && !(o is allattributes)).ToArray(); + + foreach (var linkEntity in entity.Items.OfType()) + RemoveAttributes(linkEntity); + } + } + + private void RemoveAttributes(FetchXml.FetchLinkEntityType entity) + { + if (entity.Items != null) + { + entity.Items = entity.Items.Where(o => !(o is FetchAttributeType) && !(o is allattributes)).ToArray(); + + foreach (var linkEntity in entity.Items.OfType()) + RemoveAttributes(linkEntity); + } + } + + private FetchXml.FetchType CloneFetchXml(FetchXml.FetchType fetchXml) + { + var serializer = new XmlSerializer(typeof(FetchXml.FetchType)); + using (var writer = new StringWriter()) + { + serializer.Serialize(writer, fetchXml); + + using (var reader = new StringReader(writer.ToString())) + { + return (FetchXml.FetchType)serializer.Deserialize(reader); + } + } + } + } +} diff --git a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj index 9a73d004..bfab3496 100644 --- a/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj +++ b/MarkMpn.Sql4Cds.Engine/MarkMpn.Sql4Cds.Engine.csproj @@ -43,6 +43,7 @@ + @@ -98,6 +99,7 @@ +