Skip to content

Commit

Permalink
Optimize away Coalesce for trivial cases (#34002)
Browse files Browse the repository at this point in the history
  • Loading branch information
ranma42 authored Jun 24, 2024
1 parent ecd6104 commit baaf79e
Show file tree
Hide file tree
Showing 13 changed files with 174 additions and 171 deletions.
146 changes: 73 additions & 73 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs

Large diffs are not rendered by default.

126 changes: 66 additions & 60 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ private SqlExpression ApplyTypeMappingOnJsonScalar(
}

/// <inheritdoc />
public virtual SqlBinaryExpression? MakeBinary(
public virtual SqlExpression? MakeBinary(
ExpressionType operatorType,
SqlExpression left,
SqlExpression right,
Expand All @@ -416,125 +416,131 @@ private SqlExpression ApplyTypeMappingOnJsonScalar(
break;
}

return (SqlBinaryExpression)ApplyTypeMapping(
return ApplyTypeMapping(
new SqlBinaryExpression(operatorType, left, right, returnType, null), typeMapping);
}

/// <inheritdoc />
public virtual SqlBinaryExpression Equal(SqlExpression left, SqlExpression right)
public virtual SqlExpression Equal(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.Equal, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression NotEqual(SqlExpression left, SqlExpression right)
public virtual SqlExpression NotEqual(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.NotEqual, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression GreaterThan(SqlExpression left, SqlExpression right)
public virtual SqlExpression GreaterThan(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.GreaterThan, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression GreaterThanOrEqual(SqlExpression left, SqlExpression right)
public virtual SqlExpression GreaterThanOrEqual(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.GreaterThanOrEqual, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression LessThan(SqlExpression left, SqlExpression right)
public virtual SqlExpression LessThan(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.LessThan, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression LessThanOrEqual(SqlExpression left, SqlExpression right)
public virtual SqlExpression LessThanOrEqual(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.LessThanOrEqual, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression AndAlso(SqlExpression left, SqlExpression right)
public virtual SqlExpression AndAlso(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.AndAlso, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression OrElse(SqlExpression left, SqlExpression right)
public virtual SqlExpression OrElse(SqlExpression left, SqlExpression right)
=> MakeBinary(ExpressionType.OrElse, left, right, null)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Add(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Add(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Add, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Subtract(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Subtract(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Subtract, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Multiply(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Multiply(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Multiply, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Divide(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Divide(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Divide, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Modulo(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Modulo(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Modulo, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression And(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression And(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.And, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlBinaryExpression Or(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Or(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
=> MakeBinary(ExpressionType.Or, left, right, typeMapping)!;

/// <inheritdoc />
public virtual SqlFunctionExpression Coalesce(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Coalesce(SqlExpression left, SqlExpression right, RelationalTypeMapping? typeMapping = null)
{
var resultType = right.Type;
var inferredTypeMapping = typeMapping
?? ExpressionExtensions.InferTypeMapping(left, right)
?? _typeMappingSource.FindMapping(resultType, Dependencies.Model);

var typeMappedArguments = new List<SqlExpression>
left = ApplyTypeMapping(left, inferredTypeMapping);
right = ApplyTypeMapping(right, inferredTypeMapping);

return left switch
{
ApplyTypeMapping(left, inferredTypeMapping), ApplyTypeMapping(right, inferredTypeMapping)
SqlConstantExpression { Value: null } => right,

SqlConstantExpression { Value: not null } or
ColumnExpression { IsNullable: false } => left,

_ => new SqlFunctionExpression(
"COALESCE",
[left, right],
nullable: true,
// COALESCE is handled separately since it's only nullable if *all* arguments are null
argumentsPropagateNullability: [false, false],
resultType,
inferredTypeMapping)
};

return new SqlFunctionExpression(
"COALESCE",
typeMappedArguments,
nullable: true,
// COALESCE is handled separately since it's only nullable if *all* arguments are null
argumentsPropagateNullability: [false, false],
resultType,
inferredTypeMapping);
}

/// <inheritdoc />
public virtual SqlUnaryExpression? MakeUnary(
public virtual SqlExpression? MakeUnary(
ExpressionType operatorType,
SqlExpression operand,
Type type,
RelationalTypeMapping? typeMapping = null)
=> SqlUnaryExpression.IsValidOperator(operatorType)
? (SqlUnaryExpression)ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping)
? ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping)
: null;

/// <inheritdoc />
public virtual SqlUnaryExpression IsNull(SqlExpression operand)
public virtual SqlExpression IsNull(SqlExpression operand)
=> MakeUnary(ExpressionType.Equal, operand, typeof(bool))!;

/// <inheritdoc />
public virtual SqlUnaryExpression IsNotNull(SqlExpression operand)
public virtual SqlExpression IsNotNull(SqlExpression operand)
=> MakeUnary(ExpressionType.NotEqual, operand, typeof(bool))!;

/// <inheritdoc />
public virtual SqlUnaryExpression Convert(SqlExpression operand, Type type, RelationalTypeMapping? typeMapping = null)
public virtual SqlExpression Convert(SqlExpression operand, Type type, RelationalTypeMapping? typeMapping = null)
=> MakeUnary(ExpressionType.Convert, operand, type.UnwrapNullableType(), typeMapping)!;

/// <inheritdoc />
public virtual SqlUnaryExpression Not(SqlExpression operand)
public virtual SqlExpression Not(SqlExpression operand)
=> MakeUnary(ExpressionType.Not, operand, operand.Type, operand.TypeMapping)!;

/// <inheritdoc />
public virtual SqlUnaryExpression Negate(SqlExpression operand)
public virtual SqlExpression Negate(SqlExpression operand)
=> MakeUnary(ExpressionType.Negate, operand, operand.Type, operand.TypeMapping)!;

/// <inheritdoc />
public virtual CaseExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var operandTypeMapping = operand!.TypeMapping
?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
Expand Down Expand Up @@ -563,7 +569,7 @@ public virtual CaseExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhe
}

/// <inheritdoc />
public virtual CaseExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
public virtual SqlExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var resultTypeMapping = elseResult?.TypeMapping
?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);
Expand All @@ -583,7 +589,7 @@ public virtual CaseExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, Sq
}

/// <inheritdoc />
public virtual SqlFunctionExpression Function(
public virtual SqlExpression Function(
string name,
IEnumerable<SqlExpression> arguments,
bool nullable,
Expand All @@ -602,7 +608,7 @@ public virtual SqlFunctionExpression Function(
}

/// <inheritdoc />
public virtual SqlFunctionExpression Function(
public virtual SqlExpression Function(
string? schema,
string name,
IEnumerable<SqlExpression> arguments,
Expand All @@ -622,7 +628,7 @@ public virtual SqlFunctionExpression Function(
}

/// <inheritdoc />
public virtual SqlFunctionExpression Function(
public virtual SqlExpression Function(
SqlExpression instance,
string name,
IEnumerable<SqlExpression> arguments,
Expand All @@ -645,64 +651,64 @@ public virtual SqlFunctionExpression Function(
}

/// <inheritdoc />
public virtual SqlFunctionExpression NiladicFunction(
public virtual SqlExpression NiladicFunction(
string name,
bool nullable,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> new(name, nullable, returnType, typeMapping);
=> new SqlFunctionExpression(name, nullable, returnType, typeMapping);

/// <inheritdoc />
public virtual SqlFunctionExpression NiladicFunction(
public virtual SqlExpression NiladicFunction(
string schema,
string name,
bool nullable,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> new(schema, name, nullable, returnType, typeMapping);
=> new SqlFunctionExpression(schema, name, nullable, returnType, typeMapping);

/// <inheritdoc />
public virtual SqlFunctionExpression NiladicFunction(
public virtual SqlExpression NiladicFunction(
SqlExpression instance,
string name,
bool nullable,
bool instancePropagatesNullability,
Type returnType,
RelationalTypeMapping? typeMapping = null)
=> new(
=> new SqlFunctionExpression(
ApplyDefaultTypeMapping(instance), name, nullable, instancePropagatesNullability, returnType, typeMapping);

/// <inheritdoc />
public virtual ExistsExpression Exists(SelectExpression subquery)
=> new(subquery, _boolTypeMapping);
public virtual SqlExpression Exists(SelectExpression subquery)
=> new ExistsExpression(subquery, _boolTypeMapping);

/// <inheritdoc />
public virtual InExpression In(SqlExpression item, SelectExpression subquery)
public virtual SqlExpression In(SqlExpression item, SelectExpression subquery)
=> ApplyTypeMappingOnIn(new InExpression(item, subquery, _boolTypeMapping));

/// <inheritdoc />
public virtual InExpression In(SqlExpression item, IReadOnlyList<SqlExpression> values)
public virtual SqlExpression In(SqlExpression item, IReadOnlyList<SqlExpression> values)
=> ApplyTypeMappingOnIn(new InExpression(item, values, _boolTypeMapping));

/// <inheritdoc />
public virtual InExpression In(SqlExpression item, SqlParameterExpression valuesParameter)
public virtual SqlExpression In(SqlExpression item, SqlParameterExpression valuesParameter)
=> ApplyTypeMappingOnIn(new InExpression(item, valuesParameter, _boolTypeMapping));

/// <inheritdoc />
public virtual LikeExpression Like(SqlExpression match, SqlExpression pattern, SqlExpression? escapeChar = null)
=> (LikeExpression)ApplyDefaultTypeMapping(new LikeExpression(match, pattern, escapeChar, null));
public virtual SqlExpression Like(SqlExpression match, SqlExpression pattern, SqlExpression? escapeChar = null)
=> ApplyDefaultTypeMapping(new LikeExpression(match, pattern, escapeChar, null));

/// <inheritdoc />
public virtual SqlFragmentExpression Fragment(string sql)
=> new(sql);
public virtual SqlExpression Fragment(string sql)
=> new SqlFragmentExpression(sql);

/// <inheritdoc />
public virtual SqlConstantExpression Constant(object value, RelationalTypeMapping? typeMapping = null)
=> new(value, typeMapping);
public virtual SqlExpression Constant(object value, RelationalTypeMapping? typeMapping = null)
=> new SqlConstantExpression(value, typeMapping);

/// <inheritdoc />
public virtual SqlConstantExpression Constant(object? value, Type type, RelationalTypeMapping? typeMapping = null)
=> new(value, type, typeMapping);
public virtual SqlExpression Constant(object? value, Type type, RelationalTypeMapping? typeMapping = null)
=> new SqlConstantExpression(value, type, typeMapping);

/// <inheritdoc />
public virtual bool TryCreateLeast(
Expand Down
27 changes: 19 additions & 8 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
subquery.Offset,
subquery.Limit);

var predicate = VisitSqlBinary(
var predicate = Visit(
_sqlExpressionFactory.Equal(subqueryProjection, item), allowOptimizedExpansion: true, out _);
subquery.ApplyPredicate(predicate);
subquery.ClearOrdering();
Expand Down Expand Up @@ -908,7 +908,7 @@ protected virtual SqlExpression VisitIn(InExpression inExpression, bool allowOpt
result,
(expr, nullableValue) => _sqlExpressionFactory.OrElse(
expr,
VisitSqlBinary(_sqlExpressionFactory.Equal(item, nullableValue), allowOptimizedExpansion, out _)));
Visit(_sqlExpressionFactory.Equal(item, nullableValue), allowOptimizedExpansion, out _)));

InExpression ProcessInExpressionValues(
InExpression inExpression,
Expand Down Expand Up @@ -1873,8 +1873,13 @@ private SqlExpression RewriteNullSemantics(
return sqlBinaryExpression.Update(left, right);
}

private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlBinaryExpression sqlBinaryExpression)
private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlExpression expression)
{
if (expression is not SqlBinaryExpression sqlBinaryExpression)
{
return expression;
}

if (sqlBinaryExpression is
{
Left: SqlUnaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } leftUnary,
Expand Down Expand Up @@ -1930,13 +1935,14 @@ private SqlExpression SimplifyLogicalSqlBinaryExpression(SqlBinaryExpression sql
/// <summary>
/// Attempts to simplify a unary not operation on a non-nullable operand.
/// </summary>
/// <param name="sqlUnaryExpression">The expression to simplify.</param>
/// <param name="expression">The expression to simplify.</param>
/// <returns>The simplified expression, or the original expression if it cannot be simplified.</returns>
protected virtual SqlExpression OptimizeNonNullableNotExpression(SqlUnaryExpression sqlUnaryExpression)
protected virtual SqlExpression OptimizeNonNullableNotExpression(SqlExpression expression)
{
if (sqlUnaryExpression.OperatorType != ExpressionType.Not)
if (expression is not SqlUnaryExpression sqlUnaryExpression
|| sqlUnaryExpression.OperatorType != ExpressionType.Not)
{
return sqlUnaryExpression;
return expression;
}

switch (sqlUnaryExpression.Operand)
Expand Down Expand Up @@ -2207,8 +2213,13 @@ protected virtual TableExpressionBase UpdateParameterCollection(
SqlParameterExpression newCollectionParameter)
=> throw new InvalidOperationException();

private SqlExpression ProcessNullNotNull(SqlUnaryExpression sqlUnaryExpression, bool operandNullable)
private SqlExpression ProcessNullNotNull(SqlExpression sqlExpression, bool operandNullable)
{
if (sqlExpression is not SqlUnaryExpression sqlUnaryExpression)
{
return sqlExpression;
}

if (!operandNullable)
{
// when we know that operand is non-nullable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp

// CONCAT_WS filters out nulls, but string.Join treats them as empty strings; so coalesce (which is a no-op for non-nullable
// arguments).
arguments[i + 1] = sqlArgument switch
{
ColumnExpression { IsNullable: false } => sqlArgument,
SqlConstantExpression constantExpression => constantExpression.Value is null
? _sqlExpressionFactory.Constant(string.Empty)
: constantExpression,
_ => Dependencies.SqlExpressionFactory.Coalesce(sqlArgument, _sqlExpressionFactory.Constant(string.Empty))
};
arguments[i + 1] = Dependencies.SqlExpressionFactory.Coalesce(sqlArgument, _sqlExpressionFactory.Constant(string.Empty));
}

// CONCAT_WS never returns null; a null delimiter is interpreted as an empty string, and null arguments are skipped
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public class SqlServerDateTimeMemberTranslator(
_ => null
};

SqlFunctionExpression DatePart(string part)
SqlExpression DatePart(string part)
=> sqlExpressionFactory.Function(
"DATEPART",
arguments: [sqlExpressionFactory.Fragment(part), instance!],
Expand Down
Loading

0 comments on commit baaf79e

Please sign in to comment.