Skip to content

Commit

Permalink
Nullability-related fixes to LEAST/GREATEST
Browse files Browse the repository at this point in the history
Fixup to #32338
  • Loading branch information
roji committed Jul 22, 2024
1 parent d41ba67 commit f1181d9
Show file tree
Hide file tree
Showing 16 changed files with 432 additions and 392 deletions.
28 changes: 0 additions & 28 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Query;
Expand Down Expand Up @@ -439,30 +437,4 @@ SqlExpression NiladicFunction(
/// <param name="sql">A string token to print in SQL tree.</param>
/// <returns>An expression representing a SQL token.</returns>
SqlExpression Fragment(string sql);

/// <summary>
/// Attempts to creates a new expression that returns the smallest value from a list of expressions, e.g. an invocation of the
/// <c>LEAST</c> SQL function.
/// </summary>
/// <param name="expressions">An entity type to project.</param>
/// <param name="resultType">The result CLR type for the returned expression.</param>
/// <param name="leastExpression">The expression which computes the smallest value.</param>
/// <returns><see langword="true" /> if the expression could be created, <see langword="false" /> otherwise.</returns>
bool TryCreateLeast(
IReadOnlyList<SqlExpression> expressions,
Type resultType,
[NotNullWhen(true)] out SqlExpression? leastExpression);

/// <summary>
/// Attempts to creates a new expression that returns the greatest value from a list of expressions, e.g. an invocation of the
/// <c>GREATEST</c> SQL function.
/// </summary>
/// <param name="expressions">An entity type to project.</param>
/// <param name="resultType">The result CLR type for the returned expression.</param>
/// <param name="greatestExpression">The expression which computes the greatest value.</param>
/// <returns><see langword="true" /> if the expression could be created, <see langword="false" /> otherwise.</returns>
bool TryCreateGreatest(
IReadOnlyList<SqlExpression> expressions,
Type resultType,
[NotNullWhen(true)] out SqlExpression? greatestExpression);
}
Original file line number Diff line number Diff line change
Expand Up @@ -976,19 +976,39 @@ private SqlExpression CreateJoinPredicate(Expression outerKey, Expression innerK

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateMax(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlExpressionFactory.TryCreateGreatest(values, resultType, out var greatestExpression)
? source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression)
: TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
{
// For Max() over an inline array, translate to GREATEST() if possible; otherwise use the default translation of aggregate SQL
// MAX().
// Note that some providers propagate NULL arguments (SQLite, MySQL), while others only return NULL if all arguments evaluate to
// NULL (SQL Server, PostgreSQL). If the argument is a nullable value type, don't translate to GREATEST() if it propagates NULLs,
// to match the .NET behavior.
if (TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlTranslator.GenerateGreatest(values, resultType.UnwrapNullableType()) is SqlFunctionExpression greatestExpression
&& (Nullable.GetUnderlyingType(resultType) is null
|| greatestExpression.ArgumentsPropagateNullability?.All(a => a == false) == true))
{
return source.Update(new SelectExpression(greatestExpression, _sqlAliasManager), source.ShaperExpression);
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MaxWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
}

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateMin(ShapedQueryExpression source, LambdaExpression? selector, Type resultType)
=> TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlExpressionFactory.TryCreateLeast(values, resultType, out var leastExpression)
? source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression)
: TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
{
// See comments above in TranslateMax()
if (TryExtractBareInlineCollectionValues(source, out var values)
&& _sqlTranslator.GenerateLeast(values, resultType.UnwrapNullableType()) is SqlFunctionExpression leastExpression
&& (Nullable.GetUnderlyingType(resultType) is null
|| leastExpression.ArgumentsPropagateNullability?.All(a => a == false) == true))
{
return source.Update(new SelectExpression(leastExpression, _sqlAliasManager), source.ShaperExpression);
}

return TranslateAggregateWithSelector(
source, selector, t => QueryableMethods.MinWithoutSelector.MakeGenericMethod(t), throwWhenEmpty: true, resultType);
}

/// <inheritdoc />
protected override ShapedQueryExpression? TranslateOfType(ShapedQueryExpression source, Type resultType)
Expand Down
214 changes: 58 additions & 156 deletions src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,6 @@ private static readonly MethodInfo StringEqualsWithStringComparison
private static readonly MethodInfo StringEqualsWithStringComparisonStatic
= typeof(string).GetRuntimeMethod(nameof(string.Equals), [typeof(string), typeof(string), typeof(StringComparison)])!;

private static readonly MethodInfo LeastMethodInfo
= typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Least))!;

private static readonly MethodInfo GreatestMethodInfo
= typeof(RelationalDbFunctionsExtensions).GetMethod(nameof(RelationalDbFunctionsExtensions.Greatest))!;

private static readonly MethodInfo GetTypeMethodInfo = typeof(object).GetTypeInfo().GetDeclaredMethod(nameof(GetType))!;

private readonly QueryCompilationContext _queryCompilationContext;
Expand Down Expand Up @@ -183,138 +177,6 @@ protected virtual void AddTranslationErrorDetails(string details)
return result;
}

/// <summary>
/// Translates Average over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Average over.</param>
/// <returns>A SQL translation of Average over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateAverage(SqlExpression sqlExpression)
{
var inputType = sqlExpression.Type;
if (inputType == typeof(int)
|| inputType == typeof(long))
{
sqlExpression = sqlExpression is DistinctExpression distinctExpression
? new DistinctExpression(
_sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(distinctExpression.Operand, typeof(double))))
: _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Convert(sqlExpression, typeof(double)));
}

return inputType == typeof(float)
? _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
sqlExpression.Type,
sqlExpression.TypeMapping)
: _sqlExpressionFactory.Function(
"AVG",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping);
}

/// <summary>
/// Translates Count over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Count over.</param>
/// <returns>A SQL translation of Count over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateCount(SqlExpression sqlExpression)
=> _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(int)));

/// <summary>
/// Translates LongCount over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate LongCount over.</param>
/// <returns>A SQL translation of LongCount over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateLongCount(SqlExpression sqlExpression)
=> _sqlExpressionFactory.ApplyDefaultTypeMapping(
_sqlExpressionFactory.Function(
"COUNT",
new[] { sqlExpression },
nullable: false,
argumentsPropagateNullability: new[] { false },
typeof(long)));

/// <summary>
/// Translates Max over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Max over.</param>
/// <returns>A SQL translation of Max over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateMax(SqlExpression sqlExpression)
=> sqlExpression != null
? _sqlExpressionFactory.Function(
"MAX",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
: null;

/// <summary>
/// Translates Min over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Min over.</param>
/// <returns>A SQL translation of Min over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateMin(SqlExpression sqlExpression)
=> sqlExpression != null
? _sqlExpressionFactory.Function(
"MIN",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
sqlExpression.Type,
sqlExpression.TypeMapping)
: null;

/// <summary>
/// Translates Sum over an expression to an equivalent SQL representation.
/// </summary>
/// <param name="sqlExpression">An expression to translate Sum over.</param>
/// <returns>A SQL translation of Sum over the given expression.</returns>
[Obsolete("Use IAggregateMethodCallTranslatorProvider to add translation for aggregate methods")]
public virtual SqlExpression? TranslateSum(SqlExpression sqlExpression)
{
var inputType = sqlExpression.Type;

return inputType == typeof(float)
? _sqlExpressionFactory.Convert(
_sqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
typeof(double)),
inputType,
sqlExpression.TypeMapping)
: _sqlExpressionFactory.Function(
"SUM",
new[] { sqlExpression },
nullable: true,
argumentsPropagateNullability: new[] { false },
inputType,
sqlExpression.TypeMapping);
}

/// <inheritdoc />
protected override Expression VisitBinary(BinaryExpression binaryExpression)
{
Expand Down Expand Up @@ -937,14 +799,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
// translation.
case
{
Method:
{
Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest),
IsGenericMethod: true
},
Method.Name: nameof(RelationalDbFunctionsExtensions.Least) or nameof(RelationalDbFunctionsExtensions.Greatest),
Arguments: [_, NewArrayExpression newArray]
} when method.GetGenericMethodDefinition() is var genericMethodDefinition
&& (genericMethodDefinition == LeastMethodInfo || genericMethodDefinition == GreatestMethodInfo):
} when method.DeclaringType == typeof(RelationalDbFunctionsExtensions):
{
var values = newArray.Expressions;
var translatedValues = new SqlExpression[values.Count];
Expand All @@ -962,21 +819,54 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp
translatedValues[i] = translatedValue!;
}

var elementClrType = newArray.Type.GetElementType()!;
var elementClrType = newArray.Type.GetElementType()!.UnwrapNullableType();

if (genericMethodDefinition == LeastMethodInfo
&& _sqlExpressionFactory.TryCreateLeast(translatedValues, elementClrType, out var leastExpression))
{
return leastExpression;
}
return method.Name switch
{
nameof(RelationalDbFunctionsExtensions.Greatest) => GenerateGreatest(translatedValues, elementClrType),
nameof(RelationalDbFunctionsExtensions.Least) => GenerateLeast(translatedValues, elementClrType),
_ => throw new UnreachableException()
}
?? QueryCompilationContext.NotTranslatedExpression;
}

if (genericMethodDefinition == GreatestMethodInfo
&& _sqlExpressionFactory.TryCreateGreatest(translatedValues, elementClrType, out var greatestExpression))
// Translate Math.Max/Min.
// These are here rather than in a MethodTranslator since we use TranslateGreatest/Least, and are very similar to the
// EF.Functions.Greatest/Least translation just above.
case
{
Method.Name: nameof(Math.Max) or nameof(Math.Min),
Arguments: [Expression argument1, Expression argument2]
} when method.DeclaringType == typeof(Math):
{
var translatedArguments = new List<SqlExpression>();

return TryFlattenVisit(argument1)
&& TryFlattenVisit(argument2)
&& method.Name switch
{
nameof(Math.Max) => GenerateGreatest(translatedArguments, argument1.Type),
nameof(Math.Min) => GenerateLeast(translatedArguments, argument1.Type),
_ => throw new UnreachableException()
} is SqlExpression translatedFunctionCall
? translatedFunctionCall
: QueryCompilationContext.NotTranslatedExpression;

bool TryFlattenVisit(Expression argument)
{
return greatestExpression;
}
if (argument is MethodCallExpression nestedCall && nestedCall.Method == method)
{
return TryFlattenVisit(nestedCall.Arguments[0]) && TryFlattenVisit(nestedCall.Arguments[1]);
}

throw new UnreachableException();
if (TranslationFailed(argument, Visit(argument), out var translatedArgument))
{
return false;
}

translatedArguments.Add(translatedArgument!);
return true;
}
}

// For queryable methods, either we translate the whole aggregate or we go to subquery mode
Expand Down Expand Up @@ -1531,6 +1421,18 @@ when QueryableMethods.IsSumWithSelector(genericMethod):
return false;
}

/// <summary>
/// Generates a SQL GREATEST expression over the given expressions.
/// </summary>
public virtual SqlExpression? GenerateGreatest(IReadOnlyList<SqlExpression> expressions, Type resultType)
=> null;

/// <summary>
/// Generates a SQL GREATEST expression over the given expressions.
/// </summary>
public virtual SqlExpression? GenerateLeast(IReadOnlyList<SqlExpression> expressions, Type resultType)
=> null;

private bool TryTranslateAsEnumerableExpression(
Expression? expression,
[NotNullWhen(true)] out EnumerableExpression? enumerableExpression)
Expand Down
Loading

0 comments on commit f1181d9

Please sign in to comment.