Skip to content

Commit

Permalink
Avoid division by decimal 0 in SQLite (#33989)
Browse files Browse the repository at this point in the history
  • Loading branch information
ranma42 authored Jun 14, 2024
1 parent 17b1deb commit f0a733d
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
function,
new[] { sqlBinary.Left, sqlBinary.Right },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
argumentsPropagateNullability: new[] { false, false },
visitedExpression.Type,
visitedExpression.TypeMapping);
}
Expand Down Expand Up @@ -513,7 +513,7 @@ private Expression DoDecimalArithmetics(SqlExpression visitedExpression, Express
return op switch
{
ExpressionType.Add => DecimalArithmeticExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Divide => DecimalArithmeticExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Divide => DecimalDivisionExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Multiply => DecimalArithmeticExpressionFactoryMethod(ResolveFunctionNameFromExpressionType(op), left, right),
ExpressionType.Subtract => DecimalSubtractExpressionFactoryMethod(left, right),
_ => visitedExpression
Expand All @@ -537,6 +537,14 @@ Expression DecimalArithmeticExpressionFactoryMethod(string name, SqlExpression l
new[] { true, true },
visitedExpression.Type);

Expression DecimalDivisionExpressionFactoryMethod(string name, SqlExpression left, SqlExpression right)
=> Dependencies.SqlExpressionFactory.Function(
name,
new[] { left, right },
nullable: true,
new[] { false, false },
visitedExpression.Type);

Expression DecimalSubtractExpressionFactoryMethod(SqlExpression left, SqlExpression right)
{
var subtrahend = Dependencies.SqlExpressionFactory.Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ private void InitializeDbConnection(DbConnection connection)

sqliteConnection.CreateFunction(
"ef_mod",
(decimal? dividend, decimal? divisor) => dividend % divisor,
(decimal? dividend, decimal? divisor) => divisor == 0m ? null : dividend % divisor,
isDeterministic: true);

sqliteConnection.CreateFunction(
Expand All @@ -123,7 +123,7 @@ private void InitializeDbConnection(DbConnection connection)

sqliteConnection.CreateFunction(
name: "ef_divide",
(decimal? dividend, decimal? divisor) => dividend / divisor,
(decimal? dividend, decimal? divisor) => divisor == 0m ? null : dividend / divisor,
isDeterministic: true);

sqliteConnection.CreateFunction(
Expand Down
75 changes: 75 additions & 0 deletions test/EFCore.Specification.Tests/BuiltInDataTypesTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2343,6 +2343,44 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
});

eb.HasData(
new sbyte[] { -10, -7, -4, -3, -2, -1, 0, 1, 2, 3, 8, 15 }
.Select((x, i) =>
new BuiltInDataTypes
{
Id = 17 + i,
PartitionId = 2,
TestInt16 = x,
TestInt32 = x,
TestInt64 = x,
TestDouble = x * 0.25,
TestDecimal = x * 0.2M,
TestDateTime = DateTime.Parse("01/01/2000 12:34:56"),
TestDateTimeOffset = new DateTimeOffset(DateTime.Parse("01/01/2000 12:34:56"), TimeSpan.FromHours(-8.0)),
TestTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestDateOnly = new DateOnly(2020, 3, 1),
TestTimeOnly = new TimeOnly(12, 30, 45, 123),
TestSingle = x * 0.25F,
TestBoolean = x > 0,
TestByte = (byte)(10 + x),
TestUnsignedInt16 = (byte)(10 + x),
TestUnsignedInt32 = (byte)(10 + x),
TestUnsignedInt64 = (byte)(10 + x),
TestCharacter = 'a',
TestSignedByte = x,
Enum64 = Enum64.SomeValue,
Enum32 = Enum32.SomeValue,
Enum16 = Enum16.SomeValue,
Enum8 = Enum8.SomeValue,
EnumU64 = EnumU64.SomeValue,
EnumU32 = EnumU32.SomeValue,
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
}
)
);

eb.Property(e => e.Id).ValueGeneratedNever();
});
modelBuilder.Entity<BuiltInDataTypesShadow>().Property(e => e.Id).ValueGeneratedNever();
Expand Down Expand Up @@ -2380,6 +2418,43 @@ protected override void OnModelCreating(ModelBuilder modelBuilder, DbContext con
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
});

eb.HasData(
new sbyte?[] { null, -10, -7, -4, -3, -2, -1, 0, 1, 2, 3, 8, 15 }
.Select((x, i) =>
new BuiltInNullableDataTypes
{
Id = 17 + i,
PartitionId = 2,
TestNullableInt16 = x,
TestNullableInt32 = x,
TestNullableInt64 = x,
TestNullableDouble = x * 0.25,
TestNullableDecimal = x * 0.2M,
TestNullableDateTimeOffset = new DateTimeOffset(new DateTime(), TimeSpan.FromHours(-8.0)),
TestNullableTimeSpan = new TimeSpan(0, 10, 9, 8, 7),
TestNullableDateOnly = new DateOnly(2020, 3, 1),
TestNullableTimeOnly = new TimeOnly(12, 30, 45, 123),
TestNullableSingle = x * 0.25F,
TestNullableBoolean = x == null ? null : x > 0,
TestNullableByte = (byte?)(10 + x),
TestNullableUnsignedInt16 = (byte?)(10 + x),
TestNullableUnsignedInt32 = (byte?)(10 + x),
TestNullableUnsignedInt64 = (byte?)(10 + x),
TestNullableCharacter = 'a',
TestNullableSignedByte = x,
Enum64 = Enum64.SomeValue,
Enum32 = Enum32.SomeValue,
Enum16 = Enum16.SomeValue,
Enum8 = Enum8.SomeValue,
EnumU64 = EnumU64.SomeValue,
EnumU32 = EnumU32.SomeValue,
EnumU16 = EnumU16.SomeValue,
EnumS8 = EnumS8.SomeValue
}
)
);

eb.Property(e => e.Id).ValueGeneratedNever();
});
modelBuilder.Entity<BuiltInNullableDataTypesShadow>().Property(e => e.Id).ValueGeneratedNever();
Expand Down
57 changes: 56 additions & 1 deletion test/EFCore.Sqlite.FunctionalTests/BuiltInDataTypesSqliteTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,7 @@ from dt2 in context.Set<BuiltInDataTypes>().ToList()
subtract = dt1.TestDecimal - dt2.TestDecimal,
multiply = dt1.TestDecimal * dt2.TestDecimal,
divide = dt1.TestDecimal / dt2.TestDecimal,
modulus = dt1.TestDecimal % dt2.TestDecimal,
negate = -dt1.TestDecimal
}).ToList();

Expand All @@ -1829,6 +1830,7 @@ from dt2 in context.Set<BuiltInDataTypes>()
subtract = dt1.TestDecimal - dt2.TestDecimal,
multiply = dt1.TestDecimal * dt2.TestDecimal,
divide = dt1.TestDecimal / dt2.TestDecimal,
modulus = dt1.TestDecimal % dt2.TestDecimal,
negate = -dt1.TestDecimal
}).ToList();

Expand All @@ -1839,19 +1841,72 @@ from dt2 in context.Set<BuiltInDataTypes>()
Assert.Equal(expected[i].subtract, actual[i].subtract);
Assert.Equal(expected[i].multiply, actual[i].multiply);
Assert.Equal(expected[i].divide, actual[i].divide);
Assert.Equal(expected[i].modulus, actual[i].modulus);
Assert.Equal(expected[i].negate, actual[i].negate);
}

AssertSql(
"""
SELECT ef_add("b"."TestDecimal", "b0"."TestDecimal") AS "add", ef_add("b"."TestDecimal", ef_negate("b0"."TestDecimal")) AS "subtract", ef_multiply("b"."TestDecimal", "b0"."TestDecimal") AS "multiply", ef_divide("b"."TestDecimal", "b0"."TestDecimal") AS "divide", ef_negate("b"."TestDecimal") AS "negate"
SELECT ef_add("b"."TestDecimal", "b0"."TestDecimal") AS "add", ef_add("b"."TestDecimal", ef_negate("b0"."TestDecimal")) AS "subtract", ef_multiply("b"."TestDecimal", "b0"."TestDecimal") AS "multiply", ef_divide("b"."TestDecimal", "b0"."TestDecimal") AS "divide", ef_mod("b"."TestDecimal", "b0"."TestDecimal") AS "modulus", ef_negate("b"."TestDecimal") AS "negate"
FROM "BuiltInDataTypes" AS "b"
CROSS JOIN "BuiltInDataTypes" AS "b0"
WHERE "b0"."TestDecimal" <> '0.0'
ORDER BY "b"."Id", "b0"."Id"
""");
}

[ConditionalFact]
public virtual void Projecting_arithmetic_operations_on_nullable_decimals()
{
using var context = CreateContext();
var expected = (from dt1 in context.Set<BuiltInNullableDataTypes>().ToList()
from dt2 in context.Set<BuiltInNullableDataTypes>().ToList()
orderby dt1.Id, dt2.Id
select new
{
add = dt1.TestNullableDecimal + dt2.TestNullableDecimal,
subtract = dt1.TestNullableDecimal - dt2.TestNullableDecimal,
multiply = dt1.TestNullableDecimal * dt2.TestNullableDecimal,
divide = dt2.TestNullableDecimal == 0 ? null : dt1.TestNullableDecimal / dt2.TestNullableDecimal,
modulus = dt2.TestNullableDecimal == 0 ? null : dt1.TestNullableDecimal % dt2.TestNullableDecimal,
negate = -dt1.TestNullableDecimal
}).ToList();

Fixture.TestSqlLoggerFactory.Clear();

var actual = (from dt1 in context.Set<BuiltInNullableDataTypes>()
from dt2 in context.Set<BuiltInNullableDataTypes>()
orderby dt1.Id, dt2.Id
select new
{
add = dt1.TestNullableDecimal + dt2.TestNullableDecimal,
subtract = dt1.TestNullableDecimal - dt2.TestNullableDecimal,
multiply = dt1.TestNullableDecimal * dt2.TestNullableDecimal,
divide = dt1.TestNullableDecimal / dt2.TestNullableDecimal,
modulus = dt1.TestNullableDecimal % dt2.TestNullableDecimal,
negate = -dt1.TestNullableDecimal
}).ToList();

Assert.Equal(expected.Count, actual.Count);
for (var i = 0; i < expected.Count; i++)
{
Assert.Equal(expected[i].add, actual[i].add);
Assert.Equal(expected[i].subtract, actual[i].subtract);
Assert.Equal(expected[i].multiply, actual[i].multiply);
Assert.Equal(expected[i].divide, actual[i].divide);
Assert.Equal(expected[i].modulus, actual[i].modulus);
Assert.Equal(expected[i].negate, actual[i].negate);
}

AssertSql(
"""
SELECT ef_add("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "add", ef_add("b"."TestNullableDecimal", ef_negate("b0"."TestNullableDecimal")) AS "subtract", ef_multiply("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "multiply", ef_divide("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "divide", ef_mod("b"."TestNullableDecimal", "b0"."TestNullableDecimal") AS "modulus", ef_negate("b"."TestNullableDecimal") AS "negate"
FROM "BuiltInNullableDataTypes" AS "b"
CROSS JOIN "BuiltInNullableDataTypes" AS "b0"
ORDER BY "b"."Id", "b0"."Id"
""");
}

private void AssertTranslationFailed(Action testCode)
=> Assert.Contains(
CoreStrings.TranslationFailed("")[21..],
Expand Down

0 comments on commit f0a733d

Please sign in to comment.