From f1181d9b6a1afa8bfd94e45c53144ff12a5e30d4 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Thu, 30 Nov 2023 01:16:41 +0100 Subject: [PATCH] Nullability-related fixes to LEAST/GREATEST Fixup to #32338 --- .../Query/ISqlExpressionFactory.cs | 28 --- ...yableMethodTranslatingExpressionVisitor.cs | 40 +++- ...lationalSqlTranslatingExpressionVisitor.cs | 214 +++++------------- .../Query/SqlExpressionFactory.cs | 61 ----- .../Internal/SqlServerSqlExpressionFactory.cs | 30 --- ...qlServerSqlTranslatingExpressionVisitor.cs | 44 ++++ .../Translators/SqlServerMathTranslator.cs | 31 +-- .../Internal/SqliteSqlExpressionFactory.cs | 61 ----- .../SqliteSqlTranslatingExpressionVisitor.cs | 32 +++ .../Translators/SqliteMathTranslator.cs | 21 +- ...tionalNorthwindDbFunctionsQueryTestBase.cs | 16 ++ .../PrimitiveCollectionsQueryTestBase.cs | 44 ++++ .../NorthwindDbFunctionsQuerySqlServerTest.cs | 26 +++ ...imitiveCollectionsQueryOldSqlServerTest.cs | 60 +++++ .../PrimitiveCollectionsQuerySqlServerTest.cs | 56 +++++ .../PrimitiveCollectionsQuerySqliteTest.cs | 60 +++++ 16 files changed, 432 insertions(+), 392 deletions(-) diff --git a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs index 5cf4f57d573..cc7958eb0c4 100644 --- a/src/EFCore.Relational/Query/ISqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/ISqlExpressionFactory.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics.CodeAnalysis; -using System.Runtime.CompilerServices; -using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Query.SqlExpressions; namespace Microsoft.EntityFrameworkCore.Query; @@ -439,30 +437,4 @@ SqlExpression NiladicFunction( /// A string token to print in SQL tree. /// An expression representing a SQL token. SqlExpression Fragment(string sql); - - /// - /// Attempts to creates a new expression that returns the smallest value from a list of expressions, e.g. an invocation of the - /// LEAST SQL function. - /// - /// An entity type to project. - /// The result CLR type for the returned expression. - /// The expression which computes the smallest value. - /// if the expression could be created, otherwise. - bool TryCreateLeast( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? leastExpression); - - /// - /// Attempts to creates a new expression that returns the greatest value from a list of expressions, e.g. an invocation of the - /// GREATEST SQL function. - /// - /// An entity type to project. - /// The result CLR type for the returned expression. - /// The expression which computes the greatest value. - /// if the expression could be created, otherwise. - bool TryCreateGreatest( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? greatestExpression); } diff --git a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs index 20115136354..650abbd4892 100644 --- a/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalQueryableMethodTranslatingExpressionVisitor.cs @@ -976,19 +976,39 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK /// protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TryExtractBareInlineCollectionValues(source, out var values) - && _sqlExpressionFactory.TryCreateGreatest(values, resultType, out var greatestExpression) - ? source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression) - : TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + { + // For Max() over an inline array, translate to GREATEST() if possible; otherwise use the default translation of aggregate SQL + // MAX(). + // Note that some providers propagate NULL arguments (SQLite, MySQL), while others only return NULL if all arguments evaluate to + // NULL (SQL Server, PostgreSQL). If the argument is a nullable value type, don't translate to GREATEST() if it propagates NULLs, + // to match the .NET behavior. + if (TryExtractBareInlineCollectionValues(source, out var values) + && _sqlTranslator.GenerateGreatest(values, resultType.UnwrapNullableType()) is SqlFunctionExpression greatestExpression + && (Nullable.GetUnderlyingType(resultType) is null + || greatestExpression.ArgumentsPropagateNullability?.All(a => a == false) == true)) + { + return source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression); + } + + return TranslateAggregateWithSelector( + source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + } /// protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType) - => TryExtractBareInlineCollectionValues(source, out var values) - && _sqlExpressionFactory.TryCreateLeast(values, resultType, out var leastExpression) - ? source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression) - : TranslateAggregateWithSelector( - source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + { + // See comments above in TranslateMax() + if (TryExtractBareInlineCollectionValues(source, out var values) + && _sqlTranslator.GenerateLeast(values, resultType.UnwrapNullableType()) is SqlFunctionExpression leastExpression + && (Nullable.GetUnderlyingType(resultType) is null + || leastExpression.ArgumentsPropagateNullability?.All(a => a == false) == true)) + { + return source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression); + } + + return TranslateAggregateWithSelector( + source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType); + } /// protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType) diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 61abb3e145a..60f2f3825dd 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -52,12 +52,6 @@ private static readonly MethodInfo StringEqualsWithStringComparison private static readonly MethodInfo StringEqualsWithStringComparisonStatic = typeof(string).GetRuntimeMethod(nameof(string.Equals), [typeof(string), typeof(string), typeof(StringComparison)])!; - private static readonly MethodInfo LeastMethodInfo - = typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Least))!; - - private static readonly MethodInfo GreatestMethodInfo - = typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Greatest))!; - private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(GetType))!; private readonly QueryCompilationContext _queryCompilationContext; @@ -183,138 +177,6 @@ protected virtual void AddTranslationErrorDetails(string details) return result; } - /// - /// Translates Average over an expression to an equivalent SQL representation. - /// - /// An expression to translate Average over. - /// A SQL translation of Average over the given expression. - [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] - public virtual SqlExpression? TranslateAverage(SqlExpression sqlExpression) - { - var inputType = sqlExpression.Type; - if (inputType == typeof(int) - || inputType == typeof(long)) - { - sqlExpression = sqlExpression is DistinctExpression distinctExpression - ? new DistinctExpression( - _sqlExpressionFactory.ApplyDefaultTypeMapping( - _sqlExpressionFactory.Convert(distinctExpression.Operand, typeof(double)))) - : _sqlExpressionFactory.ApplyDefaultTypeMapping( - _sqlExpressionFactory.Convert(sqlExpression, typeof(double))); - } - - return inputType == typeof(float) - ? _sqlExpressionFactory.Convert( - _sqlExpressionFactory.Function( - "AVG", - new[] { sqlExpression }, - nullable: true, - argumentsPropagateNullability: new[] { false }, - typeof(double)), - sqlExpression.Type, - sqlExpression.TypeMapping) - : _sqlExpressionFactory.Function( - "AVG", - new[] { sqlExpression }, - nullable: true, - argumentsPropagateNullability: new[] { false }, - sqlExpression.Type, - sqlExpression.TypeMapping); - } - - /// - /// Translates Count over an expression to an equivalent SQL representation. - /// - /// An expression to translate Count over. - /// A SQL translation of Count over the given expression. - [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] - public virtual SqlExpression? TranslateCount(SqlExpression sqlExpression) - => _sqlExpressionFactory.ApplyDefaultTypeMapping( - _sqlExpressionFactory.Function( - "COUNT", - new[] { sqlExpression }, - nullable: false, - argumentsPropagateNullability: new[] { false }, - typeof(int))); - - /// - /// Translates LongCount over an expression to an equivalent SQL representation. - /// - /// An expression to translate LongCount over. - /// A SQL translation of LongCount over the given expression. - [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] - public virtual SqlExpression? TranslateLongCount(SqlExpression sqlExpression) - => _sqlExpressionFactory.ApplyDefaultTypeMapping( - _sqlExpressionFactory.Function( - "COUNT", - new[] { sqlExpression }, - nullable: false, - argumentsPropagateNullability: new[] { false }, - typeof(long))); - - /// - /// Translates Max over an expression to an equivalent SQL representation. - /// - /// An expression to translate Max over. - /// A SQL translation of Max over the given expression. - [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] - public virtual SqlExpression? TranslateMax(SqlExpression sqlExpression) - => sqlExpression != null - ? _sqlExpressionFactory.Function( - "MAX", - new[] { sqlExpression }, - nullable: true, - argumentsPropagateNullability: new[] { false }, - sqlExpression.Type, - sqlExpression.TypeMapping) - : null; - - /// - /// Translates Min over an expression to an equivalent SQL representation. - /// - /// An expression to translate Min over. - /// A SQL translation of Min over the given expression. - [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] - public virtual SqlExpression? TranslateMin(SqlExpression sqlExpression) - => sqlExpression != null - ? _sqlExpressionFactory.Function( - "MIN", - new[] { sqlExpression }, - nullable: true, - argumentsPropagateNullability: new[] { false }, - sqlExpression.Type, - sqlExpression.TypeMapping) - : null; - - /// - /// Translates Sum over an expression to an equivalent SQL representation. - /// - /// An expression to translate Sum over. - /// A SQL translation of Sum over the given expression. - [Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")] - public virtual SqlExpression? TranslateSum(SqlExpression sqlExpression) - { - var inputType = sqlExpression.Type; - - return inputType == typeof(float) - ? _sqlExpressionFactory.Convert( - _sqlExpressionFactory.Function( - "SUM", - new[] { sqlExpression }, - nullable: true, - argumentsPropagateNullability: new[] { false }, - typeof(double)), - inputType, - sqlExpression.TypeMapping) - : _sqlExpressionFactory.Function( - "SUM", - new[] { sqlExpression }, - nullable: true, - argumentsPropagateNullability: new[] { false }, - inputType, - sqlExpression.TypeMapping); - } - /// protected override Expression VisitBinary(BinaryExpression binaryExpression) { @@ -937,14 +799,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp // translation. case { - Method: - { - Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest), - IsGenericMethod: true - }, + Method.Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest), Arguments: [_, NewArrayExpression newArray] - } when method.GetGenericMethodDefinition() is var genericMethodDefinition - && (genericMethodDefinition == LeastMethodInfo || genericMethodDefinition == GreatestMethodInfo): + } when method.DeclaringType == typeof(RelationalDbFunctionsExtensions): { var values = newArray.Expressions; var translatedValues = new SqlExpression[values.Count]; @@ -962,21 +819,54 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp translatedValues[i] = translatedValue!; } - var elementClrType = newArray.Type.GetElementType()!; + var elementClrType = newArray.Type.GetElementType()!.UnwrapNullableType(); - if (genericMethodDefinition == LeastMethodInfo - && _sqlExpressionFactory.TryCreateLeast(translatedValues, elementClrType, out var leastExpression)) - { - return leastExpression; - } + return method.Name switch + { + nameof(RelationalDbFunctionsExtensions.Greatest) => GenerateGreatest(translatedValues, elementClrType), + nameof(RelationalDbFunctionsExtensions.Least) => GenerateLeast(translatedValues, elementClrType), + _ => throw new UnreachableException() + } + ?? QueryCompilationContext.NotTranslatedExpression; + } - if (genericMethodDefinition == GreatestMethodInfo - && _sqlExpressionFactory.TryCreateGreatest(translatedValues, elementClrType, out var greatestExpression)) + // Translate Math.Max/Min. + // These are here rather than in a MethodTranslator since we use TranslateGreatest/Least, and are very similar to the + // EF.Functions.Greatest/Least translation just above. + case + { + Method.Name: nameof(Math.Max) or nameof(Math.Min), + Arguments: [Expression argument1, Expression argument2] + } when method.DeclaringType == typeof(Math): + { + var translatedArguments = new List(); + + return TryFlattenVisit(argument1) + && TryFlattenVisit(argument2) + && method.Name switch + { + nameof(Math.Max) => GenerateGreatest(translatedArguments, argument1.Type), + nameof(Math.Min) => GenerateLeast(translatedArguments, argument1.Type), + _ => throw new UnreachableException() + } is SqlExpression translatedFunctionCall + ? translatedFunctionCall + : QueryCompilationContext.NotTranslatedExpression; + + bool TryFlattenVisit(Expression argument) { - return greatestExpression; - } + if (argument is MethodCallExpression nestedCall && nestedCall.Method == method) + { + return TryFlattenVisit(nestedCall.Arguments[0]) && TryFlattenVisit(nestedCall.Arguments[1]); + } - throw new UnreachableException(); + if (TranslationFailed(argument, Visit(argument), out var translatedArgument)) + { + return false; + } + + translatedArguments.Add(translatedArgument!); + return true; + } } // For queryable methods, either we translate the whole aggregate or we go to subquery mode @@ -1531,6 +1421,18 @@ when QueryableMethods.IsSumWithSelector(genericMethod): return false; } + /// + /// Generates a SQL GREATEST expression over the given expressions. + /// + public virtual SqlExpression? GenerateGreatest(IReadOnlyList expressions, Type resultType) + => null; + + /// + /// Generates a SQL GREATEST expression over the given expressions. + /// + public virtual SqlExpression? GenerateLeast(IReadOnlyList expressions, Type resultType) + => null; + private bool TryTranslateAsEnumerableExpression( Expression? expression, [NotNullWhen(true)] out EnumerableExpression? enumerableExpression) diff --git a/src/EFCore.Relational/Query/SqlExpressionFactory.cs b/src/EFCore.Relational/Query/SqlExpressionFactory.cs index 2630ef75bfd..89b8ba5ab3d 100644 --- a/src/EFCore.Relational/Query/SqlExpressionFactory.cs +++ b/src/EFCore.Relational/Query/SqlExpressionFactory.cs @@ -954,65 +954,4 @@ public virtual SqlExpression Constant(object value, RelationalTypeMapping? typeM /// public virtual SqlExpression Constant(object? value, Type type, RelationalTypeMapping? typeMapping = null) => new SqlConstantExpression(value, type, typeMapping); - - /// - public virtual bool TryCreateLeast( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? leastExpression) - { - var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); - - expressions = FlattenLeastGreatest("LEAST", expressions); - - leastExpression = Function( - "LEAST", expressions, nullable: true, Enumerable.Repeat(true, expressions.Count), resultType, resultTypeMapping); - return true; - } - - /// - public virtual bool TryCreateGreatest( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? greatestExpression) - { - var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); - - expressions = FlattenLeastGreatest("GREATEST", expressions); - - greatestExpression = Function( - "GREATEST", expressions, nullable: true, Enumerable.Repeat(true, expressions.Count), resultType, resultTypeMapping); - return true; - } - - private IReadOnlyList FlattenLeastGreatest(string functionName, IReadOnlyList expressions) - { - List? flattenedExpressions = null; - - for (var i = 0; i < expressions.Count; i++) - { - var expression = expressions[i]; - if (expression is SqlFunctionExpression { IsBuiltIn: true } nestedFunction - && nestedFunction.Name == functionName) - { - if (flattenedExpressions is null) - { - flattenedExpressions = []; - for (var j = 0; j < i; j++) - { - flattenedExpressions.Add(expressions[j]); - } - } - - Check.DebugAssert(nestedFunction.Arguments is not null, "Null arguments to " + functionName); - flattenedExpressions.AddRange(nestedFunction.Arguments); - } - else - { - flattenedExpressions?.Add(expressions[i]); - } - } - - return flattenedExpressions ?? expressions; - } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlExpressionFactory.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlExpressionFactory.cs index 7673a0ce3f1..df768474719 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlExpressionFactory.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlExpressionFactory.cs @@ -67,34 +67,4 @@ private SqlExpression ApplyTypeMappingOnAtTimeZone(AtTimeZoneExpression atTimeZo atTimeZoneExpression.Type, typeMapping); } - - /// - public override bool TryCreateLeast( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? leastExpression) - { - if (_sqlServerCompatibilityLevel >= 160) - { - return base.TryCreateLeast(expressions, resultType, out leastExpression); - } - - leastExpression = null; - return false; - } - - /// - public override bool TryCreateGreatest( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? greatestExpression) - { - if (_sqlServerCompatibilityLevel >= 160) - { - return base.TryCreateGreatest(expressions, resultType, out greatestExpression); - } - - greatestExpression = null; - return false; - } } diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs index 873ca89fa1d..b09f699f206 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlTranslatingExpressionVisitor.cs @@ -524,6 +524,50 @@ private static string EscapeLikePattern(string pattern) return builder.ToString(); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override SqlExpression? GenerateGreatest(IReadOnlyList expressions, Type resultType) + { + // Docs: https://learn.microsoft.com/sql/t-sql/functions/logical-functions-greatest-transact-sql + if (_sqlServerCompatibilityLevel < 160) + { + return null; + } + + var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); + + // If one or more arguments aren't NULL, then NULL arguments are ignored during comparison. + // If all arguments are NULL, then GREATEST returns NULL. + return _sqlExpressionFactory.Function( + "GREATEST", expressions, nullable: true, Enumerable.Repeat(false, expressions.Count), resultType, resultTypeMapping); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override SqlExpression? GenerateLeast(IReadOnlyList expressions, Type resultType) + { + // Docs: https://learn.microsoft.com/sql/t-sql/functions/logical-functions-least-transact-sql + if (_sqlServerCompatibilityLevel < 160) + { + return null; + } + + var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); + + // If one or more arguments aren't NULL, then NULL arguments are ignored during comparison. + // If all arguments are NULL, then LEAST returns NULL. + return _sqlExpressionFactory.Function( + "LEAST", expressions, nullable: true, Enumerable.Repeat(false, expressions.Count), resultType, resultTypeMapping); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs index 4ffec17dbda..a1d84af1098 100644 --- a/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs +++ b/src/EFCore.SqlServer/Query/Internal/Translators/SqlServerMathTranslator.cs @@ -71,6 +71,8 @@ public class SqlServerMathTranslator : IMethodCallTranslator { typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "DEGREES" } }; + // Note: Math.Max/Min are handled in RelationalSqlTranslatingExpressionVisitor + private static readonly IEnumerable TruncateMethodInfos = new[] { typeof(Math).GetRuntimeMethod(nameof(Math.Truncate), [typeof(decimal)])!, @@ -147,7 +149,7 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) resultType = typeof(double); } - var result = (SqlExpression)_sqlExpressionFactory.Function( + var result = _sqlExpressionFactory.Function( "ROUND", new[] { argument, _sqlExpressionFactory.Constant(0), _sqlExpressionFactory.Constant(1) }, nullable: true, @@ -174,7 +176,7 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) resultType = typeof(double); } - var result = (SqlExpression)_sqlExpressionFactory.Function( + var result = _sqlExpressionFactory.Function( "ROUND", new[] { argument, digits }, nullable: true, @@ -189,31 +191,6 @@ public SqlServerMathTranslator(ISqlExpressionFactory sqlExpressionFactory) return _sqlExpressionFactory.ApplyTypeMapping(result, argument.TypeMapping); } - if (method.DeclaringType == typeof(Math)) - { - if (method.Name == nameof(Math.Min)) - { - if (_sqlExpressionFactory.TryCreateLeast( - new[] { arguments[0], arguments[1] }, method.ReturnType, out var leastExpression)) - { - return leastExpression; - } - - throw new InvalidOperationException(SqlServerStrings.LeastGreatestCompatibilityLevelTooLow); - } - - if (method.Name == nameof(Math.Max)) - { - if (_sqlExpressionFactory.TryCreateGreatest( - new[] { arguments[0], arguments[1] }, method.ReturnType, out var leastExpression)) - { - return leastExpression; - } - - throw new InvalidOperationException(SqlServerStrings.LeastGreatestCompatibilityLevelTooLow); - } - } - return null; } } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlExpressionFactory.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlExpressionFactory.cs index a80c1647f5a..e70c4289e0e 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlExpressionFactory.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlExpressionFactory.cs @@ -183,65 +183,4 @@ private SqlExpression ApplyTypeMappingOnGlob(GlobExpression globExpression) ? new RegexpExpression(match, pattern, _boolTypeMapping) : regexpExpression; } - - /// - public override bool TryCreateLeast( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? leastExpression) - { - var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); - - expressions = FlattenLeastGreatest("min", expressions); - - leastExpression = Function( - "min", expressions, nullable: true, Enumerable.Repeat(true, expressions.Count), resultType, resultTypeMapping); - return true; - } - - /// - public override bool TryCreateGreatest( - IReadOnlyList expressions, - Type resultType, - [NotNullWhen(true)] out SqlExpression? greatestExpression) - { - var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); - - expressions = FlattenLeastGreatest("max", expressions); - - greatestExpression = Function( - "max", expressions, nullable: true, Enumerable.Repeat(true, expressions.Count), resultType, resultTypeMapping); - return true; - } - - private IReadOnlyList FlattenLeastGreatest(string functionName, IReadOnlyList expressions) - { - List? flattenedExpressions = null; - - for (var i = 0; i < expressions.Count; i++) - { - var expression = expressions[i]; - if (expression is SqlFunctionExpression { IsBuiltIn: true } nestedFunction - && nestedFunction.Name == functionName) - { - if (flattenedExpressions is null) - { - flattenedExpressions = []; - for (var j = 0; j < i; j++) - { - flattenedExpressions.Add(expressions[j]); - } - } - - Check.DebugAssert(nestedFunction.Arguments is not null, "Null arguments to " + functionName); - flattenedExpressions.AddRange(nestedFunction.Arguments); - } - else - { - flattenedExpressions?.Add(expressions[i]); - } - } - - return flattenedExpressions ?? expressions; - } } diff --git a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs index 62b3a4db5f8..15593a83e31 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlTranslatingExpressionVisitor.cs @@ -469,6 +469,38 @@ private static string EscapeLikePattern(string pattern) return builder.ToString(); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override SqlExpression GenerateGreatest(IReadOnlyList expressions, Type resultType) + { + // Docs: https://sqlite.org/lang_corefunc.html#max_scalar + var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); + + // The multi-argument max() function returns the argument with the maximum value, or return NULL if any argument is NULL. + return _sqlExpressionFactory.Function( + "max", expressions, nullable: true, Enumerable.Repeat(true, expressions.Count), resultType, resultTypeMapping); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override SqlExpression GenerateLeast(IReadOnlyList expressions, Type resultType) + { + // Docs: https://sqlite.org/lang_corefunc.html#min_scalar + var resultTypeMapping = ExpressionExtensions.InferTypeMapping(expressions); + + // The multi-argument min() function returns the argument with the minimum value, or return NULL if any argument is NULL. + return _sqlExpressionFactory.Function( + "min", expressions, nullable: true, Enumerable.Repeat(true, expressions.Count), resultType, resultTypeMapping); + } + [return: NotNullIfNotNull(nameof(expression))] private static Type? GetProviderType(SqlExpression? expression) => expression == null diff --git a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs index 3f305152f19..953b46c2abd 100644 --- a/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs +++ b/src/EFCore.Sqlite.Core/Query/Internal/Translators/SqliteMathTranslator.cs @@ -80,6 +80,8 @@ public class SqliteMathTranslator : IMethodCallTranslator { typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), [typeof(float)])!, "degrees" } }; + // Note: Math.Max/Min are handled in RelationalSqlTranslatingExpressionVisitor + private static readonly List _roundWithDecimalMethods = [ typeof(Math).GetMethod(nameof(Math.Round), [typeof(double), typeof(int)])!, @@ -162,25 +164,6 @@ public SqliteMathTranslator(ISqlExpressionFactory sqlExpressionFactory) typeMapping); } - if (method.DeclaringType == typeof(Math)) - { - if (method.Name == nameof(Math.Min)) - { - var success = _sqlExpressionFactory.TryCreateLeast( - new[] { arguments[0], arguments[1] }, method.ReturnType, out var leastExpression); - Check.DebugAssert(success, "Couldn't generate min"); - return leastExpression; - } - - if (method.Name == nameof(Math.Max)) - { - var success = _sqlExpressionFactory.TryCreateGreatest( - new[] { arguments[0], arguments[1] }, method.ReturnType, out var leastExpression); - Check.DebugAssert(success, "Couldn't generate max"); - return leastExpression; - } - } - return null; } } diff --git a/test/EFCore.Relational.Specification.Tests/Query/RelationalNorthwindDbFunctionsQueryTestBase.cs b/test/EFCore.Relational.Specification.Tests/Query/RelationalNorthwindDbFunctionsQueryTestBase.cs index 682df514e6a..7b3ad56021b 100644 --- a/test/EFCore.Relational.Specification.Tests/Query/RelationalNorthwindDbFunctionsQueryTestBase.cs +++ b/test/EFCore.Relational.Specification.Tests/Query/RelationalNorthwindDbFunctionsQueryTestBase.cs @@ -65,6 +65,22 @@ public virtual Task Greatest(bool async) ss => ss.Set().Where(od => EF.Functions.Greatest(od.OrderID, 10251) == 10251), ss => ss.Set().Where(od => Math.Max(od.OrderID, 10251) == 10251)); + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Least_with_nullable_value_type(bool async) + => AssertQuery( + async, + ss => ss.Set().Where(od => EF.Functions.Least(od.OrderID, (int?)10251) == 10251), + ss => ss.Set().Where(od => Math.Min(od.OrderID, 10251) == 10251)); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual Task Greatest_with_nullable_value_type(bool async) + => AssertQuery( + async, + ss => ss.Set().Where(od => EF.Functions.Greatest(od.OrderID, (int?)10251) == 10251), + ss => ss.Set().Where(od => Math.Max(od.OrderID, 10251) == 10251)); + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual async Task Least_with_parameter_array_is_not_supported(bool async) diff --git a/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs index 6a4d5cb42bd..7563a04436a 100644 --- a/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/PrimitiveCollectionsQueryTestBase.cs @@ -236,6 +236,50 @@ await AssertQuery( ss => ss.Set().Where(c => new List { 30, c.Int, i }.Max() == 35)); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Inline_collection_of_nullable_value_type_Min(bool async) + { + int? i = 25; + + await AssertQuery( + async, + ss => ss.Set().Where(c => new[] { 30, c.NullableInt, i }.Min() == 25)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Inline_collection_of_nullable_value_type_Max(bool async) + { + int? i = 35; + + await AssertQuery( + async, + ss => ss.Set().Where(c => new[] { 30, c.NullableInt, i }.Max() == 35)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Inline_collection_of_nullable_value_type_with_null_Min(bool async) + { + int? i = null; + + await AssertQuery( + async, + ss => ss.Set().Where(c => new[] { 30, c.NullableInt, i }.Min() == 30)); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Inline_collection_of_nullable_value_type_with_null_Max(bool async) + { + int? i = null; + + await AssertQuery( + async, + ss => ss.Set().Where(c => new[] { 30, c.NullableInt, i }.Max() == 30)); + } + [ConditionalTheory] [MemberData(nameof(IsAsyncData))] public virtual Task Parameter_collection_Count(bool async) diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs index 37d00a9cd21..2233f50a1ce 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindDbFunctionsQuerySqlServerTest.cs @@ -147,6 +147,32 @@ WHERE GREATEST([o].[OrderID], 10251) = 10251 """); } + [SqlServerCondition(SqlServerCondition.SupportsFunctions2022)] + public override async Task Least_with_nullable_value_type(bool async) + { + await base.Least_with_nullable_value_type(async); + + AssertSql( + """ +SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice] +FROM [Order Details] AS [o] +WHERE LEAST([o].[OrderID], 10251) = 10251 +"""); + } + + [SqlServerCondition(SqlServerCondition.SupportsFunctions2022)] + public override async Task Greatest_with_nullable_value_type(bool async) + { + await base.Greatest_with_nullable_value_type(async); + + AssertSql( + """ +SELECT [o].[OrderID], [o].[ProductID], [o].[Discount], [o].[Quantity], [o].[UnitPrice] +FROM [Order Details] AS [o] +WHERE GREATEST([o].[OrderID], 10251) = 10251 +"""); + } + public override async Task Least_with_parameter_array_is_not_supported(bool async) { await base.Least_with_parameter_array_is_not_supported(async); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs index 655aef626db..81cd1de0b90 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQueryOldSqlServerTest.cs @@ -374,6 +374,66 @@ SELECT MAX([v].[Value]) """); } + public override async Task Inline_collection_of_nullable_value_type_Min(bool async) + { + await base.Inline_collection_of_nullable_value_type_Min(async); + + AssertSql( + """ +@__i_0='25' (Nullable = true) + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE ( + SELECT MIN([v].[Value]) + FROM (VALUES (CAST(30 AS int)), ([p].[NullableInt]), (@__i_0)) AS [v]([Value])) = 25 +"""); + } + + public override async Task Inline_collection_of_nullable_value_type_Max(bool async) + { + await base.Inline_collection_of_nullable_value_type_Max(async); + + AssertSql( + """ +@__i_0='35' (Nullable = true) + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE ( + SELECT MAX([v].[Value]) + FROM (VALUES (CAST(30 AS int)), ([p].[NullableInt]), (@__i_0)) AS [v]([Value])) = 35 +"""); + } + + public override async Task Inline_collection_of_nullable_value_type_with_null_Min(bool async) + { + await base.Inline_collection_of_nullable_value_type_with_null_Min(async); + + AssertSql( + """ +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE ( + SELECT MIN([v].[Value]) + FROM (VALUES (CAST(30 AS int)), ([p].[NullableInt]), (NULL)) AS [v]([Value])) = 30 +"""); + } + + public override async Task Inline_collection_of_nullable_value_type_with_null_Max(bool async) + { + await base.Inline_collection_of_nullable_value_type_with_null_Max(async); + + AssertSql( + """ +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE ( + SELECT MAX([v].[Value]) + FROM (VALUES (CAST(30 AS int)), ([p].[NullableInt]), (NULL)) AS [v]([Value])) = 30 +"""); + } + public override Task Parameter_collection_Count(bool async) => AssertCompatibilityLevelTooLow(() => base.Parameter_collection_Count(async)); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs index 37725332129..20d05f5dda6 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs @@ -354,6 +354,62 @@ WHERE GREATEST(30, [p].[Int], @__i_0) = 35 """); } + [SqlServerCondition(SqlServerCondition.SupportsFunctions2022)] + public override async Task Inline_collection_of_nullable_value_type_Min(bool async) + { + await base.Inline_collection_of_nullable_value_type_Min(async); + + AssertSql( + """ +@__i_0='25' (Nullable = true) + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE LEAST(30, [p].[NullableInt], @__i_0) = 25 +"""); + } + + [SqlServerCondition(SqlServerCondition.SupportsFunctions2022)] + public override async Task Inline_collection_of_nullable_value_type_Max(bool async) + { + await base.Inline_collection_of_nullable_value_type_Max(async); + + AssertSql( + """ +@__i_0='35' (Nullable = true) + +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE GREATEST(30, [p].[NullableInt], @__i_0) = 35 +"""); + } + + [SqlServerCondition(SqlServerCondition.SupportsFunctions2022)] + public override async Task Inline_collection_of_nullable_value_type_with_null_Min(bool async) + { + await base.Inline_collection_of_nullable_value_type_with_null_Min(async); + + AssertSql( + """ +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE LEAST(30, [p].[NullableInt], NULL) = 30 +"""); + } + + [SqlServerCondition(SqlServerCondition.SupportsFunctions2022)] + public override async Task Inline_collection_of_nullable_value_type_with_null_Max(bool async) + { + await base.Inline_collection_of_nullable_value_type_with_null_Max(async); + + AssertSql( + """ +SELECT [p].[Id], [p].[Bool], [p].[Bools], [p].[DateTime], [p].[DateTimes], [p].[Enum], [p].[Enums], [p].[Int], [p].[Ints], [p].[NullableInt], [p].[NullableInts], [p].[NullableString], [p].[NullableStrings], [p].[String], [p].[Strings] +FROM [PrimitiveCollectionsEntity] AS [p] +WHERE GREATEST(30, [p].[NullableInt], NULL) = 30 +"""); + } + public override async Task Parameter_collection_Count(bool async) { await base.Parameter_collection_Count(async); diff --git a/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs b/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs index a0641afdf81..a04c6399952 100644 --- a/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs +++ b/test/EFCore.Sqlite.FunctionalTests/Query/PrimitiveCollectionsQuerySqliteTest.cs @@ -352,6 +352,66 @@ WHERE max(30, "p"."Int", @__i_0) = 35 """); } + public override async Task Inline_collection_of_nullable_value_type_Min(bool async) + { + await base.Inline_collection_of_nullable_value_type_Min(async); + + AssertSql( + """ +@__i_0='25' (Nullable = true) + +SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings" +FROM "PrimitiveCollectionsEntity" AS "p" +WHERE ( + SELECT MIN("v"."Value") + FROM (SELECT CAST(30 AS INTEGER) AS "Value" UNION ALL VALUES ("p"."NullableInt"), (@__i_0)) AS "v") = 25 +"""); + } + + public override async Task Inline_collection_of_nullable_value_type_Max(bool async) + { + await base.Inline_collection_of_nullable_value_type_Max(async); + + AssertSql( + """ +@__i_0='35' (Nullable = true) + +SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings" +FROM "PrimitiveCollectionsEntity" AS "p" +WHERE ( + SELECT MAX("v"."Value") + FROM (SELECT CAST(30 AS INTEGER) AS "Value" UNION ALL VALUES ("p"."NullableInt"), (@__i_0)) AS "v") = 35 +"""); + } + + public override async Task Inline_collection_of_nullable_value_type_with_null_Min(bool async) + { + await base.Inline_collection_of_nullable_value_type_with_null_Min(async); + + AssertSql( + """ +SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings" +FROM "PrimitiveCollectionsEntity" AS "p" +WHERE ( + SELECT MIN("v"."Value") + FROM (SELECT CAST(30 AS INTEGER) AS "Value" UNION ALL VALUES ("p"."NullableInt"), (NULL)) AS "v") = 30 +"""); + } + + public override async Task Inline_collection_of_nullable_value_type_with_null_Max(bool async) + { + await base.Inline_collection_of_nullable_value_type_with_null_Max(async); + + AssertSql( + """ +SELECT "p"."Id", "p"."Bool", "p"."Bools", "p"."DateTime", "p"."DateTimes", "p"."Enum", "p"."Enums", "p"."Int", "p"."Ints", "p"."NullableInt", "p"."NullableInts", "p"."NullableString", "p"."NullableStrings", "p"."String", "p"."Strings" +FROM "PrimitiveCollectionsEntity" AS "p" +WHERE ( + SELECT MAX("v"."Value") + FROM (SELECT CAST(30 AS INTEGER) AS "Value" UNION ALL VALUES ("p"."NullableInt"), (NULL)) AS "v") = 30 +"""); + } + public override async Task Parameter_collection_Count(bool async) { await base.Parameter_collection_Count(async);