Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement nullability simplification for COLLATE and AT TIME ZONE #34263

Merged
merged 6 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,19 @@ private SqlExpression ProcessNullNotNull(SqlExpression sqlExpression, bool opera
sqlUnaryExpression.TypeMapping);
}

case CollateExpression collate:
{
// a COLLATE collation == null -> a == null
// a COLLATE collation != null -> a != null
return ProcessNullNotNull(
_sqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
collate.Operand,
typeof(bool),
sqlUnaryExpression.TypeMapping)!,
operandNullable);
}

case SqlUnaryExpression sqlUnaryOperand:
switch (sqlUnaryOperand.OperatorType)
{
Expand All @@ -2080,6 +2093,35 @@ private SqlExpression ProcessNullNotNull(SqlExpression sqlExpression, bool opera

break;

case AtTimeZoneExpression atTimeZone:
{
// a AT TIME ZONE b == null -> a == null || b == null
// a AT TIME ZONE b != null -> a != null && b != null
var left = ProcessNullNotNull(
_sqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
atTimeZone.Operand,
typeof(bool),
sqlUnaryExpression.TypeMapping)!,
operandNullable);

var right = ProcessNullNotNull(
_sqlExpressionFactory.MakeUnary(
sqlUnaryExpression.OperatorType,
atTimeZone.TimeZone,
typeof(bool),
sqlUnaryExpression.TypeMapping)!,
operandNullable);

return _sqlExpressionFactory.MakeBinary(
sqlUnaryExpression.OperatorType == ExpressionType.Equal
? ExpressionType.OrElse
: ExpressionType.AndAlso,
left,
right,
sqlUnaryExpression.TypeMapping)!;
}

case SqlBinaryExpression sqlBinaryOperand
when sqlBinaryOperand.OperatorType != ExpressionType.AndAlso
&& sqlBinaryOperand.OperatorType != ExpressionType.OrElse:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ protected OperatorsProceduralQueryTestBase()
{ typeof(bool), typeof(OperatorEntityBool) },
{ typeof(bool?), typeof(OperatorEntityNullableBool) },
{ typeof(DateTimeOffset), typeof(OperatorEntityDateTimeOffset) },
{ typeof(DateTimeOffset?), typeof(OperatorEntityNullableDateTimeOffset) },
};

ExpectedData = OperatorsData.Instance;
Expand All @@ -136,6 +137,7 @@ protected virtual async Task SeedAsync(OperatorsContext ctx)
ctx.Set<OperatorEntityBool>().AddRange(ExpectedData.OperatorEntitiesBool);
ctx.Set<OperatorEntityNullableBool>().AddRange(ExpectedData.OperatorEntitiesNullableBool);
ctx.Set<OperatorEntityDateTimeOffset>().AddRange(ExpectedData.OperatorEntitiesDateTimeOffset);
ctx.Set<OperatorEntityNullableDateTimeOffset>().AddRange(ExpectedData.OperatorEntitiesNullableDateTimeOffset);

await ctx.SaveChangesAsync();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ protected virtual Task Seed(OperatorsContext ctx)
ctx.Set<OperatorEntityBool>().AddRange(ExpectedData.OperatorEntitiesBool);
ctx.Set<OperatorEntityNullableBool>().AddRange(ExpectedData.OperatorEntitiesNullableBool);
ctx.Set<OperatorEntityDateTimeOffset>().AddRange(ExpectedData.OperatorEntitiesDateTimeOffset);
ctx.Set<OperatorEntityNullableDateTimeOffset>().AddRange(ExpectedData.OperatorEntitiesNullableDateTimeOffset);

return ctx.SaveChangesAsync();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ public virtual Task Collate_case_sensitive_constant(bool async)
c => c.ContactName == EF.Functions.Collate("maria anders", CaseSensitiveCollation),
c => c.ContactName.Equals("maria anders", StringComparison.Ordinal));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Collate_is_null(bool async)
=> AssertCount(
async,
ss => ss.Set<Customer>(),
ss => ss.Set<Customer>(),
c => EF.Functions.Collate(c.Region, CaseSensitiveCollation) == null,
c => c.Region == null);

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
public virtual Task Least(bool async)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Microsoft.EntityFrameworkCore.TestModels.Operators;

#nullable disable

public class OperatorEntityNullableDateTimeOffset : OperatorEntityBase
{
public DateTimeOffset? Value { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
modelBuilder.Entity<OperatorEntityBool>().Property(x => x.Id).ValueGeneratedNever();
modelBuilder.Entity<OperatorEntityNullableBool>().Property(x => x.Id).ValueGeneratedNever();
modelBuilder.Entity<OperatorEntityDateTimeOffset>().Property(x => x.Id).ValueGeneratedNever();
modelBuilder.Entity<OperatorEntityNullableDateTimeOffset>().Property(x => x.Id).ValueGeneratedNever();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,21 @@ public class OperatorsData : ISetSource
() => new DateTimeOffset(new DateTime(2000, 1, 1, 9, 0, 0), new TimeSpan(13, 0, 0))
];

private readonly List<Expression<Func<DateTimeOffset?>>> _nullableDateTimeOffsetValues =
[
() => null,
() => new DateTimeOffset(new DateTime(2000, 1, 1, 10, 0, 0), new TimeSpan(-8, 0, 0)),
() => new DateTimeOffset(new DateTime(2000, 1, 1, 9, 0, 0), new TimeSpan(13, 0, 0))
];

public IReadOnlyList<OperatorEntityString> OperatorEntitiesString { get; }
public IReadOnlyList<OperatorEntityInt> OperatorEntitiesInt { get; }
public IReadOnlyList<OperatorEntityNullableInt> OperatorEntitiesNullableInt { get; }
public IReadOnlyList<OperatorEntityLong> OperatorEntitiesLong { get; }
public IReadOnlyList<OperatorEntityBool> OperatorEntitiesBool { get; }
public IReadOnlyList<OperatorEntityNullableBool> OperatorEntitiesNullableBool { get; }
public IReadOnlyList<OperatorEntityDateTimeOffset> OperatorEntitiesDateTimeOffset { get; }
public IReadOnlyList<OperatorEntityNullableDateTimeOffset> OperatorEntitiesNullableDateTimeOffset { get; }
public IDictionary<Type, List<Expression>> ConstantExpressionsPerType { get; }

private OperatorsData()
Expand All @@ -74,6 +82,7 @@ private OperatorsData()
OperatorEntitiesBool = CreateBools();
OperatorEntitiesNullableBool = CreateNullableBools();
OperatorEntitiesDateTimeOffset = CreateDateTimeOffsets();
OperatorEntitiesNullableDateTimeOffset = CreateNullableDateTimeOffsets();

ConstantExpressionsPerType = new Dictionary<Type, List<Expression>>
{
Expand All @@ -84,6 +93,7 @@ private OperatorsData()
{ typeof(bool), _boolValues.Select(x => x.Body).ToList() },
{ typeof(bool?), _nullableBoolValues.Select(x => x.Body).ToList() },
{ typeof(DateTimeOffset), _dateTimeOffsetValues.Select(x => x.Body).ToList() },
{ typeof(DateTimeOffset?), _nullableDateTimeOffsetValues.Select(x => x.Body).ToList() },
};
}

Expand Down Expand Up @@ -125,6 +135,11 @@ public virtual IQueryable<TEntity> Set<TEntity>()
return (IQueryable<TEntity>)OperatorEntitiesDateTimeOffset.AsQueryable();
}

if (typeof(TEntity) == typeof(OperatorEntityNullableDateTimeOffset))
{
return (IQueryable<TEntity>)OperatorEntitiesNullableDateTimeOffset.AsQueryable();
}

throw new InvalidOperationException("Invalid entity type: " + typeof(TEntity));
}

Expand All @@ -151,4 +166,8 @@ public IReadOnlyList<OperatorEntityNullableBool> CreateNullableBools()
public IReadOnlyList<OperatorEntityDateTimeOffset> CreateDateTimeOffsets()
=> _dateTimeOffsetValues
.Select((x, i) => new OperatorEntityDateTimeOffset { Id = i + 1, Value = _dateTimeOffsetValues[i].Compile()() }).ToList();

public IReadOnlyList<OperatorEntityNullableDateTimeOffset> CreateNullableDateTimeOffsets()
=> _nullableDateTimeOffsetValues.Select((x, i) => new OperatorEntityNullableDateTimeOffset { Id = i + 1, Value = _nullableDateTimeOffsetValues[i].Compile()() })
.ToList();
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ FROM [Customers] AS [c]
""");
}

public override async Task Collate_is_null(bool async)
{
await base.Collate_is_null(async);

AssertSql(
"""
SELECT COUNT(*)
FROM [Customers] AS [c]
WHERE [c].[Region] IS NULL
""");
}

[SqlServerCondition(SqlServerCondition.SupportsFunctions2022)]
public override async Task Least(bool async)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,38 @@ where EF.Functions.AtTimeZone(e1.Value, "UTC") == e2.Value
FROM [OperatorEntityDateTimeOffset] AS [o]
CROSS JOIN [OperatorEntityDateTimeOffset] AS [o0]
WHERE [o].[Value] AT TIME ZONE 'UTC' = [o0].[Value]
""");
}

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
[SqlServerCondition(SqlServerCondition.SupportsSqlClr)]
public virtual async Task Where_AtTimeZone_is_null(bool async)
{
var contextFactory = await InitializeAsync<OperatorsContext>(seed: Seed);
using var context = contextFactory.CreateContext();

var expected = (from e in ExpectedData.OperatorEntitiesNullableDateTimeOffset
where e.Value == null
select e.Id).ToList();

var actual = (from e in context.Set<OperatorEntityNullableDateTimeOffset>()
#pragma warning disable CS8073 // The result of the expression is always the same since a value of this type is never equal to 'null'
where EF.Functions.AtTimeZone(e.Value.Value, "UTC") == null
#pragma warning restore CS8073 // The result of the expression is always the same since a value of this type is never equal to 'null'
select e.Id).ToList();

Assert.Equal(expected.Count, actual.Count);
for (var i = 0; i < expected.Count; i++)
{
Assert.Equal(expected[i], actual[i]);
}

AssertSql(
"""
SELECT [o].[Id]
FROM [OperatorEntityNullableDateTimeOffset] AS [o]
WHERE [o].[Value] IS NULL
""");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ SELECT COUNT(*)
""");
}

public override async Task Collate_is_null(bool async)
{
await base.Collate_is_null(async);

AssertSql(
"""
SELECT COUNT(*)
FROM "Customers" AS "c"
WHERE "c"."Region" IS NULL
""");
}

protected override string CaseInsensitiveCollation
=> "NOCASE";

Expand Down