Skip to content

Commit

Permalink
Implement sum aggregation for decimal in SQLite
Browse files Browse the repository at this point in the history
Contributes to dotnet#19635
  • Loading branch information
ranma42 committed May 15, 2024
1 parent 5ffdda0 commit 807740a
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
Expand Down Expand Up @@ -99,8 +99,14 @@ public SqliteQueryableAggregateMethodTranslator(ISqlExpressionFactory sqlExpress
var sumArgumentType = GetProviderType(sumSqlExpression);
if (sumArgumentType == typeof(decimal))
{
throw new NotSupportedException(
SqliteStrings.AggregateOperationNotSupported(nameof(Queryable.Sum), sumArgumentType.ShortDisplayName()));
sumSqlExpression = CombineTerms(source, sumSqlExpression);
return _sqlExpressionFactory.Function(
"ef_sum",
[sumSqlExpression],
nullable: true,
argumentsPropagateNullability: [false],
sumSqlExpression.Type,
sumSqlExpression.TypeMapping);
}

break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,30 @@ protected virtual SqlExpression VisitRegexp(
return regexpExpression.Update(match, pattern);
}

/// <inheritdoc/>
protected override SqlExpression VisitSqlFunction(
SqlFunctionExpression sqlFunctionExpression,
bool allowOptimizedExpansion,
out bool nullable)
{
var result = base.VisitSqlFunction(sqlFunctionExpression, allowOptimizedExpansion, out nullable);

if (result is SqlFunctionExpression resultFunctionExpression
&& resultFunctionExpression.IsBuiltIn
&& string.Equals(resultFunctionExpression.Name, "ef_sum", StringComparison.OrdinalIgnoreCase))
{
nullable = false;

var sqlExpressionFactory = Dependencies.SqlExpressionFactory;
return sqlExpressionFactory.Coalesce(
result,
sqlExpressionFactory.Constant(0, resultFunctionExpression.TypeMapping),
resultFunctionExpression.TypeMapping);
}

return result;
}

#pragma warning disable EF1001
/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ private void InitializeDbConnection(DbConnection connection)
name: "ef_negate",
(decimal? m) => -m,
isDeterministic: true);

sqliteConnection.CreateAggregate(
"ef_sum",
seed: null,
(decimal? sum, decimal? value) => value is null
? sum
: sum is null ? value : sum.Value + value.Value,
isDeterministic: true);
}
else
{
Expand Down

0 comments on commit 807740a

Please sign in to comment.