Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid division by decimal 0 in SQLite #33989

Merged
merged 5 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
function,
new[] { sqlBinary.Left, sqlBinary.Right },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
argumentsPropagateNullability: new[] { false, false },
visitedExpression.Type,
visitedExpression.TypeMapping);
}
Expand Down Expand Up @@ -513,7 +513,7 @@ private Expression DoDecimalArithmetics(SqlExpression visitedExpression, Express
return op switch
{
ExpressionType.Add => DecimalArithmeticExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Divide => DecimalArithmeticExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Divide => DecimalDivisionExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Multiply => DecimalArithmeticExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Subtract => DecimalSubtractExpressionFactoryMethod(left, right),
_ => visitedExpression
Expand All @@ -537,6 +537,14 @@ Expression DecimalArithmeticExpressionFactoryMethod(string name, SqlExpression l
new[] { true, true },
visitedExpression.Type);

Expression DecimalDivisionExpressionFactoryMethod(string name, SqlExpression left, SqlExpression right)
=> Dependencies.SqlExpressionFactory.Function(
name,
new[] { left, right },
nullable: true,
new[] { false, false },
visitedExpression.Type);

Expression DecimalSubtractExpressionFactoryMethod(SqlExpression left, SqlExpression right)
{
var subtrahend = Dependencies.SqlExpressionFactory.Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ private void InitializeDbConnection(DbConnection connection)

sqliteConnection.CreateFunction(
"ef_mod",
(decimal? dividend, decimal? divisor) => dividend % divisor,
(decimal? dividend, decimal? divisor) => divisor == 0m ? null : dividend % divisor,
isDeterministic: true);

sqliteConnection.CreateFunction(
Expand All @@ -123,7 +123,7 @@ private void InitializeDbConnection(DbConnection connection)

sqliteConnection.CreateFunction(
name: "ef_divide",
(decimal? dividend, decimal? divisor) => dividend / divisor,
(decimal? dividend, decimal? divisor) => divisor == 0m ? null : dividend / divisor,
isDeterministic: true);

sqliteConnection.CreateFunction(
Expand Down
75 changes: 75 additions & 0 deletions test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2343,6 +2343,44 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
});

eb.HasData(
new sbyte[] { -10, -7, -4, -3, -2, -1, 0, 1, 2, 3, 8, 15 }
.Select((x, i) =>
new BuiltInDataTypes
{
Id = 17 + i,
PartitionId = 2,
TestInt16 = x,
TestInt32 = x,
TestInt64 = x,
TestDouble = x * 0.25,
TestDecimal = x * 0.2M,
TestDateTime = DateTime.Parse("01/01/2000 12:34:56"),
TestDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
TestTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestDateOnly = new DateOnly(2020, 3, 1),
TestTimeOnly = new TimeOnly(12, 30, 45, 123),
TestSingle = x * 0.25F,
TestBoolean = x > 0,
TestByte = (byte)(10 + x),
TestUnsignedInt16 = (byte)(10 + x),
TestUnsignedInt32 = (byte)(10 + x),
TestUnsignedInt64 = (byte)(10 + x),
TestCharacter = 'a',
TestSignedByte = x,
Enum64 = Enum64.SomeValue,
Enum32 = Enum32.SomeValue,
Enum16 = Enum16.SomeValue,
Enum8 = Enum8.SomeValue,
EnumU64 = EnumU64.SomeValue,
EnumU32 = EnumU32.SomeValue,
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
}
)
);

eb.Property(e => e.Id).ValueGeneratedNever();
});
modelBuilder.Entity<BuiltInDataTypesShadow>().Property(e => e.Id).ValueGeneratedNever();
Expand Down Expand Up @@ -2380,6 +2418,43 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
});

eb.HasData(
new sbyte?[] { null, -10, -7, -4, -3, -2, -1, 0, 1, 2, 3, 8, 15 }
.Select((x, i) =>
new BuiltInNullableDataTypes
{
Id = 17 + i,
PartitionId = 2,
TestNullableInt16 = x,
TestNullableInt32 = x,
TestNullableInt64 = x,
TestNullableDouble = x * 0.25,
TestNullableDecimal = x * 0.2M,
TestNullableDateTimeOffset = new DateTimeOffset(new DateTime(), TimeSpan.FromHours(-8.0)),
TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestNullableDateOnly = new DateOnly(2020, 3, 1),
TestNullableTimeOnly = new TimeOnly(12, 30, 45, 123),
TestNullableSingle = x * 0.25F,
TestNullableBoolean = x == null ? null : x > 0,
TestNullableByte = (byte?)(10 + x),
TestNullableUnsignedInt16 = (byte?)(10 + x),
TestNullableUnsignedInt32 = (byte?)(10 + x),
TestNullableUnsignedInt64 = (byte?)(10 + x),
TestNullableCharacter = 'a',
TestNullableSignedByte = x,
Enum64 = Enum64.SomeValue,
Enum32 = Enum32.SomeValue,
Enum16 = Enum16.SomeValue,
Enum8 = Enum8.SomeValue,
EnumU64 = EnumU64.SomeValue,
EnumU32 = EnumU32.SomeValue,
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
}
)
);

eb.Property(e => e.Id).ValueGeneratedNever();
});
modelBuilder.Entity<BuiltInNullableDataTypesShadow>().Property(e => e.Id).ValueGeneratedNever();
Expand Down
57 changes: 56 additions & 1 deletion test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,7 @@ from dt2 in context.Set<BuiltInDataTypes>().ToList()
subtract = dt1.TestDecimal - dt2.TestDecimal,
multiply = dt1.TestDecimal * dt2.TestDecimal,
divide = dt1.TestDecimal / dt2.TestDecimal,
modulus = dt1.TestDecimal % dt2.TestDecimal,
negate = -dt1.TestDecimal
}).ToList();

Expand All @@ -1829,6 +1830,7 @@ from dt2 in context.Set<BuiltInDataTypes>()
subtract = dt1.TestDecimal - dt2.TestDecimal,
multiply = dt1.TestDecimal * dt2.TestDecimal,
divide = dt1.TestDecimal / dt2.TestDecimal,
modulus = dt1.TestDecimal % dt2.TestDecimal,
negate = -dt1.TestDecimal
}).ToList();

Expand All @@ -1839,19 +1841,72 @@ from dt2 in context.Set<BuiltInDataTypes>()
Assert.Equal(expected[i].subtract, actual[i].subtract);
Assert.Equal(expected[i].multiply, actual[i].multiply);
Assert.Equal(expected[i].divide, actual[i].divide);
Assert.Equal(expected[i].modulus, actual[i].modulus);
Assert.Equal(expected[i].negate, actual[i].negate);
}

AssertSql(
"""
SELECT ef_add("b"."TestDecimal", "b0"."TestDecimal") AS "add", ef_add("b"."TestDecimal", ef_negate("b0"."TestDecimal")) AS "subtract", ef_multiply("b"."TestDecimal", "b0"."TestDecimal") AS "multiply", ef_divide("b"."TestDecimal", "b0"."TestDecimal") AS "divide", ef_negate("b"."TestDecimal") AS "negate"
SELECT ef_add("b"."TestDecimal", "b0"."TestDecimal") AS "add", ef_add("b"."TestDecimal", ef_negate("b0"."TestDecimal")) AS "subtract", ef_multiply("b"."TestDecimal", "b0"."TestDecimal") AS "multiply", ef_divide("b"."TestDecimal", "b0"."TestDecimal") AS "divide", ef_mod("b"."TestDecimal", "b0"."TestDecimal") AS "modulus", ef_negate("b"."TestDecimal") AS "negate"
FROM "BuiltInDataTypes" AS "b"
CROSS JOIN "BuiltInDataTypes" AS "b0"
WHERE "b0"."TestDecimal" <> '0.0'
ORDER BY "b"."Id", "b0"."Id"
""");
}

[ConditionalFact]
public virtual void Projecting_arithmetic_operations_on_nullable_decimals()
{
using var context = CreateContext();
var expected = (from dt1 in context.Set<BuiltInNullableDataTypes>().ToList()
from dt2 in context.Set<BuiltInNullableDataTypes>().ToList()
orderby dt1.Id, dt2.Id
select new
{
add = dt1.TestNullableDecimal + dt2.TestNullableDecimal,
subtract = dt1.TestNullableDecimal - dt2.TestNullableDecimal,
multiply = dt1.TestNullableDecimal * dt2.TestNullableDecimal,
divide = dt2.TestNullableDecimal == 0 ? null : dt1.TestNullableDecimal / dt2.TestNullableDecimal,
modulus = dt2.TestNullableDecimal == 0 ? null : dt1.TestNullableDecimal % dt2.TestNullableDecimal,
negate = -dt1.TestNullableDecimal
}).ToList();

Fixture.TestSqlLoggerFactory.Clear();

var actual = (from dt1 in context.Set<BuiltInNullableDataTypes>()
from dt2 in context.Set<BuiltInNullableDataTypes>()
orderby dt1.Id, dt2.Id
select new
{
add = dt1.TestNullableDecimal + dt2.TestNullableDecimal,
subtract = dt1.TestNullableDecimal - dt2.TestNullableDecimal,
multiply = dt1.TestNullableDecimal * dt2.TestNullableDecimal,
divide = dt1.TestNullableDecimal / dt2.TestNullableDecimal,
modulus = dt1.TestNullableDecimal % dt2.TestNullableDecimal,
negate = -dt1.TestNullableDecimal
}).ToList();

Assert.Equal(expected.Count, actual.Count);
for (var i = 0; i < expected.Count; i++)
{
Assert.Equal(expected[i].add, actual[i].add);
Assert.Equal(expected[i].subtract, actual[i].subtract);
Assert.Equal(expected[i].multiply, actual[i].multiply);
Assert.Equal(expected[i].divide, actual[i].divide);
Assert.Equal(expected[i].modulus, actual[i].modulus);
Assert.Equal(expected[i].negate, actual[i].negate);
}

AssertSql(
"""
SELECT ef_add("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "add", ef_add("b"."TestNullableDecimal", ef_negate("b0"."TestNullableDecimal")) AS "subtract", ef_multiply("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "multiply", ef_divide("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "divide", ef_mod("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "modulus", ef_negate("b"."TestNullableDecimal") AS "negate"
FROM "BuiltInNullableDataTypes" AS "b"
CROSS JOIN "BuiltInNullableDataTypes" AS "b0"
ORDER BY "b"."Id", "b0"."Id"
""");
}

private void AssertTranslationFailed(Action testCode)
=> Assert.Contains(
CoreStrings.TranslationFailed("")[21..],
Expand Down