Skip to content

Commit

Permalink
Simplify NOT (#34142)
Browse files Browse the repository at this point in the history
When constructing `Not` expressions, perform some basic local simplifications.
These mostly match `OptimizeNonNullableNotExpression`, but they are written to
to be generally applicable, even for nullable expressions.
  • Loading branch information
ranma42 authored Jul 3, 2024
1 parent 0337960 commit 97d2365
Show file tree
Hide file tree
Showing 21 changed files with 275 additions and 457 deletions.
8 changes: 5 additions & 3 deletions src/EFCore.Relational/Query/ISqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ public interface ISqlExpressionFactory
/// <param name="operand">A <see cref="SqlExpression" /> to apply unary operator on.</param>
/// <param name="type">The type of the created expression.</param>
/// <param name="typeMapping">A type mapping to be assigned to the created expression.</param>
/// <param name="existingExpr">An optional expression that can be re-used if it matches the new expression.</param>
/// <returns>A <see cref="SqlExpression" /> with the given arguments.</returns>
SqlExpression? MakeUnary(
ExpressionType operatorType,
SqlExpression operand,
Type type,
RelationalTypeMapping? typeMapping = null);
RelationalTypeMapping? typeMapping = null,
SqlExpression? existingExpr = null);

/// <summary>
/// Creates a new <see cref="SqlExpression" /> with the given arguments.
Expand Down Expand Up @@ -275,11 +277,11 @@ SqlExpression Convert(
/// Creates a new <see cref="CaseExpression" /> which represent a CASE statement in a SQL tree.
/// </summary>
/// <param name="operand">An expression to compare with <see cref="CaseWhenClause.Test" /> in <paramref name="whenClauses" />.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare and get result from.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare or evaluate and get result from.</param>
/// <param name="elseResult">A value to return if no <paramref name="whenClauses" /> matches, if any.</param>
/// <returns>An expression representing a CASE statement in a SQL tree.</returns>
SqlExpression Case(
SqlExpression operand,
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult);

Expand Down
118 changes: 88 additions & 30 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -599,10 +599,15 @@ public virtual SqlExpression Coalesce(SqlExpression left, SqlExpression right, R
ExpressionType operatorType,
SqlExpression operand,
Type type,
RelationalTypeMapping? typeMapping = null)
=> SqlUnaryExpression.IsValidOperator(operatorType)
? ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping)
: null;
RelationalTypeMapping? typeMapping = null,
SqlExpression? existingExpr = null)
=> operatorType switch
{
ExpressionType.Not => ApplyTypeMapping(Not(operand, existingExpr), typeMapping),
_ when SqlUnaryExpression.IsValidOperator(operatorType)
=> ApplyTypeMapping(new SqlUnaryExpression(operatorType, operand, type, null), typeMapping),
_ => null,
};

/// <inheritdoc />
public virtual SqlExpression IsNull(SqlExpression operand)
Expand All @@ -620,33 +625,102 @@ public virtual SqlExpression Convert(SqlExpression operand, Type type, Relationa
public virtual SqlExpression Not(SqlExpression operand)
=> MakeUnary(ExpressionType.Not, operand, operand.Type, operand.TypeMapping)!;

private SqlExpression Not(SqlExpression operand, SqlExpression? existingExpr)
=> operand switch
{
// !(null) -> null
// ~(null) -> null (bitwise negation)
SqlConstantExpression { Value: null } => operand,

// !(true) -> false
// !(false) -> true
SqlConstantExpression { Value: bool boolValue } => Constant(!boolValue, operand.Type, operand.TypeMapping),

// !(!a) -> a
// ~(~a) -> a (bitwise negation)
SqlUnaryExpression { OperatorType: ExpressionType.Not } unary => unary.Operand,

// !(a IS NULL) -> a IS NOT NULL
SqlUnaryExpression { OperatorType: ExpressionType.Equal } unary => IsNotNull(unary.Operand),

// !(a IS NOT NULL) -> a IS NULL
SqlUnaryExpression { OperatorType: ExpressionType.NotEqual } unary => IsNull(unary.Operand),

// !(a AND b) -> !a OR !b (De Morgan)
SqlBinaryExpression { OperatorType: ExpressionType.AndAlso } binary
=> OrElse(Not(binary.Left), Not(binary.Right)),

// !(a OR b) -> !a AND !b (De Morgan)
SqlBinaryExpression { OperatorType: ExpressionType.OrElse } binary
=> AndAlso(Not(binary.Left), Not(binary.Right)),

// use equality where possible
// !(a == true) -> a == false
// !(a == false) -> a == true
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Right: SqlConstantExpression { Value: bool } } binary
=> Equal(binary.Left, Not(binary.Right)),

// !(true == a) -> false == a
// !(false == a) -> true == a
SqlBinaryExpression { OperatorType: ExpressionType.Equal, Left: SqlConstantExpression { Value: bool } } binary
=> Equal(Not(binary.Left), binary.Right),

// !(a == b) -> a != b
SqlBinaryExpression { OperatorType: ExpressionType.Equal } sqlBinaryOperand => NotEqual(sqlBinaryOperand.Left, sqlBinaryOperand.Right),
// !(a != b) -> a == b
SqlBinaryExpression { OperatorType: ExpressionType.NotEqual } sqlBinaryOperand => Equal(sqlBinaryOperand.Left, sqlBinaryOperand.Right),

// !(CASE x WHEN t1 THEN r1 ... ELSE rN) -> CASE x WHEN t1 THEN !r1 ... ELSE !rN
CaseExpression caseExpression
when caseExpression.Type == typeof(bool)
&& caseExpression.ElseResult is null or SqlConstantExpression
&& caseExpression.WhenClauses.All(clause => clause.Result is SqlConstantExpression)
=> Case(
caseExpression.Operand,
[.. caseExpression.WhenClauses.Select(clause => new CaseWhenClause(clause.Test, Not(clause.Result)))],
caseExpression.ElseResult is null ? null : Not(caseExpression.ElseResult)),

_ => existingExpr is SqlUnaryExpression { OperatorType: ExpressionType.Not } unaryExpr && unaryExpr.Operand == operand
? existingExpr
: new SqlUnaryExpression(ExpressionType.Not, operand, operand.Type, null),
};

/// <inheritdoc />
public virtual SqlExpression Negate(SqlExpression operand)
=> MakeUnary(ExpressionType.Negate, operand, operand.Type, operand.TypeMapping)!;

/// <inheritdoc />
public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var operandTypeMapping = operand!.TypeMapping
?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
// Since we never look at type of Operand/Test after this place,
// we need to find actual typeMapping based on non-object type.
?? new[] { operand.Type }.Concat(whenClauses.Select(wc => wc.Test.Type))
.Where(t => t != typeof(object)).Select(t => _typeMappingSource.FindMapping(t, Dependencies.Model))
.FirstOrDefault();
RelationalTypeMapping? testTypeMapping;
if (operand == null)
{
testTypeMapping = _boolTypeMapping;
}
else
{
testTypeMapping = operand.TypeMapping
?? whenClauses.Select(wc => wc.Test.TypeMapping).FirstOrDefault(t => t != null)
// Since we never look at type of Operand/Test after this place,
// we need to find actual typeMapping based on non-object type.
?? new[] { operand.Type }.Concat(whenClauses.Select(wc => wc.Test.Type))
.Where(t => t != typeof(object)).Select(t => _typeMappingSource.FindMapping(t, Dependencies.Model))
.FirstOrDefault();

operand = ApplyTypeMapping(operand, testTypeMapping);
}

var resultTypeMapping = elseResult?.TypeMapping
?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);

operand = ApplyTypeMapping(operand, operandTypeMapping);
elseResult = ApplyTypeMapping(elseResult, resultTypeMapping);

var typeMappedWhenClauses = new List<CaseWhenClause>();
foreach (var caseWhenClause in whenClauses)
{
typeMappedWhenClauses.Add(
new CaseWhenClause(
ApplyTypeMapping(caseWhenClause.Test, operandTypeMapping),
ApplyTypeMapping(caseWhenClause.Test, testTypeMapping),
ApplyTypeMapping(caseWhenClause.Result, resultTypeMapping)));
}

Expand All @@ -655,23 +729,7 @@ public virtual SqlExpression Case(SqlExpression? operand, IReadOnlyList<CaseWhen

/// <inheritdoc />
public virtual SqlExpression Case(IReadOnlyList<CaseWhenClause> whenClauses, SqlExpression? elseResult)
{
var resultTypeMapping = elseResult?.TypeMapping
?? whenClauses.Select(wc => wc.Result.TypeMapping).FirstOrDefault(t => t != null);

var typeMappedWhenClauses = new List<CaseWhenClause>();
foreach (var caseWhenClause in whenClauses)
{
typeMappedWhenClauses.Add(
new CaseWhenClause(
ApplyTypeMapping(caseWhenClause.Test, _boolTypeMapping),
ApplyTypeMapping(caseWhenClause.Result, resultTypeMapping)));
}

elseResult = ApplyTypeMapping(elseResult, resultTypeMapping);

return new CaseExpression(typeMappedWhenClauses, elseResult);
}
=> Case(operand: null, whenClauses, elseResult);

/// <inheritdoc />
public virtual SqlExpression Function(
Expand Down
16 changes: 5 additions & 11 deletions src/EFCore.Relational/Query/SqlExpressions/CaseExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ public class CaseExpression : SqlExpression
/// Creates a new instance of the <see cref="CaseExpression" /> class which represents a simple CASE expression.
/// </summary>
/// <param name="operand">An expression to compare with <see cref="CaseWhenClause.Test" /> in <see cref="WhenClauses" />.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare and get result from.</param>
/// <param name="whenClauses">A list of <see cref="CaseWhenClause" /> to compare or evaluate and get result from.</param>
/// <param name="elseResult">A value to return if no <see cref="WhenClauses" /> matches, if any.</param>
public CaseExpression(
SqlExpression operand,
SqlExpression? operand,
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult = null)
: base(whenClauses[0].Result.Type, whenClauses[0].Result.TypeMapping)
Expand All @@ -45,10 +45,8 @@ public CaseExpression(
public CaseExpression(
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult = null)
: base(whenClauses[0].Result.Type, whenClauses[0].Result.TypeMapping)
: this(null, whenClauses, elseResult)
{
_whenClauses.AddRange(whenClauses);
ElseResult = elseResult;
}

/// <summary>
Expand Down Expand Up @@ -94,9 +92,7 @@ protected override Expression VisitChildren(ExpressionVisitor visitor)
changed |= elseResult != ElseResult;

return changed
? operand == null
? new CaseExpression(whenClauses, elseResult)
: new CaseExpression(operand, whenClauses, elseResult)
? new CaseExpression(operand, whenClauses, elseResult)
: this;
}

Expand All @@ -113,9 +109,7 @@ public virtual CaseExpression Update(
IReadOnlyList<CaseWhenClause> whenClauses,
SqlExpression? elseResult)
=> operand != Operand || !whenClauses.SequenceEqual(WhenClauses) || elseResult != ElseResult
? (operand == null
? new CaseExpression(whenClauses, elseResult)
: new CaseExpression(operand, whenClauses, elseResult))
? new CaseExpression(operand, whenClauses, elseResult)
: this;

/// <inheritdoc />
Expand Down
Loading

0 comments on commit 97d2365

Please sign in to comment.