Skip to content

Commit

Permalink
fix(datatypes): decimal normalization failed for integers
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Aug 30, 2023
1 parent dcd9772 commit 5213958
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 7 deletions.
2 changes: 2 additions & 0 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def _literal(_, op):

if dtype.is_array():
value = list(value)
elif dtype.is_decimal():
value = value.normalize()

return sa.literal(value)

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ def compile_literal(t, op, *, raw=False, **kwargs):
return F.struct(*(F.lit(val).alias(name) for name, val in value.items()))
elif dtype.is_timestamp():
return F.from_utc_timestamp(F.lit(str(value)), tz="UTC")
elif dtype.is_decimal():
return F.lit(value.normalize())
else:
return F.lit(value)

Expand Down
43 changes: 43 additions & 0 deletions ibis/common/numeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

from decimal import Context, Decimal, InvalidOperation


def normalize_decimal(value, precision: int | None = None, scale: int | None = None):
context = Context(prec=38 if precision is None else precision)

try:
if isinstance(value, float):
out = Decimal(str(value))
else:
out = Decimal(value)
except InvalidOperation:
raise TypeError(f"Unable to construct decimal from {value!r}")

out = out.normalize(context=context)
components = out.as_tuple()
n_digits = len(components.digits)
exponent = components.exponent

if precision is not None and precision < n_digits:
raise TypeError(
f"Decimal value {value} has too many digits for precision: {precision}"
)

if scale is not None:
if exponent < -scale:
raise TypeError(
f"Normalizing {value} with scale {exponent} to scale -{scale} "
"would loose precision"
)

other = Decimal(10) ** -scale
try:
out = out.quantize(other, context=context)
except InvalidOperation:
raise TypeError(
f"Unable to normalize {value!r} as decimal with precision {precision} "
f"and scale {scale}"
)

return out
78 changes: 78 additions & 0 deletions ibis/common/tests/test_numeric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

from decimal import Context, localcontext
from decimal import Decimal as D

import pytest

from ibis.common.numeric import normalize_decimal


@pytest.mark.parametrize(
("value", "precision", "scale", "expected"),
[
(1, None, None, D("1")),
(1.0, None, None, D("1.0")),
(1.0, 2, None, D("1.0")),
(1.0, 2, 1, D("1.0")),
(1.0, 3, 2, D("1.0")),
(1.0, 3, 1, D("1.0")),
(1.0, 3, 0, D("1")),
(1.0, 2, 0, D("1")),
(1.0, 1, 0, D("1")),
(3.14, 3, 2, D("3.14")),
(3.14, 10, 2, D("3.14")),
(3.14, 10, 3, D("3.14")),
(3.14, 10, 4, D("3.14")),
(1234.567, 10, 4, D("1234.567")),
(1234.567, 10, 3, D("1234.567")),
],
)
def test_normalize_decimal(value, precision, scale, expected):
assert normalize_decimal(value, precision, scale) == expected


@pytest.mark.parametrize(
("value", "precision", "scale"),
[
(1.0, 2, 2),
(1.0, 1, 1),
(D("1.1234"), 5, 3),
(D("1.1234"), 4, 2),
(D("23145"), 4, 2),
(1234.567, 10, 2),
(1234.567, 10, 1),
(3.14, 10, 0),
(3.14, 3, 0),
(3.14, 3, 1),
(3.14, 10, 1),
],
)
def test_normalize_failing(value, precision, scale):
with pytest.raises(TypeError):
normalize_decimal(value, precision, scale)


def test_normalize_decimal_dont_truncate_precision():
# test that the decimal context is ignored, 38 is the default precision
for prec in [10, 30, 38]:
with localcontext(Context(prec=prec)):
v = "1.123456789"
assert str(normalize_decimal(v + "0000")) == "1.123456789"

v = v + "1" * 28
assert len(v) == 39
assert str(normalize_decimal(v)) == v

# if no precision is specified, we use precision 38 for dec.normalize()
v = v + "1"
assert len(v) == 40
assert str(normalize_decimal(v)) == v[:-1]

# pass the precision explicitly
assert str(normalize_decimal(v, precision=39)) == v

v = v + "1" * 11
assert len(v) == 51
assert str(normalize_decimal(v, precision=50)) == v
assert str(normalize_decimal(v, precision=45)) == v[:-5]
29 changes: 29 additions & 0 deletions ibis/expr/datatypes/tests/test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,32 @@ def test_normalize_non_convertible_float(typename):
typ = getattr(dt, typename)
with pytest.raises(TypeError, match="Unable to normalize .+ to Float"):
dt.normalize(typ, "not convertible")


@pytest.mark.parametrize(
("value", "dtype", "expected"),
[
(1, dt.Decimal(), "1"),
(1.0, dt.Decimal(), "1"),
(1.0, dt.Decimal(2, 1), "1.0"),
(1.0, dt.Decimal(2, 0), "1"),
(1.0, dt.Decimal(4, 3), "1.000"),
(12, dt.Decimal(6, 3), "12.000"),
(12.1234, dt.Decimal(7, 5), "12.12340"),
(True, dt.Decimal(4, 0), "1"),
(True, dt.Decimal(4, 3), "1.000"),
(False, dt.Decimal(4, 0), "0"),
(decimal.Decimal("1.1"), dt.Decimal(76, 38), "1.1" + "0" * 37),
],
)
def test_normalize_decimal(value, dtype, expected):
assert str(dt.normalize(dtype, value)) == expected


def test_normalize_decimal_invalid():
with pytest.raises(TypeError):
dt.normalize(dt.Decimal(4, 2), "invalid")
with pytest.raises(TypeError):
dt.normalize(dt.Decimal(4, 2), 1234)
with pytest.raises(TypeError):
dt.normalize(12.1234, dt.Decimal(6, 2))
7 changes: 3 additions & 4 deletions ibis/expr/datatypes/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ibis.common.collections import frozendict
from ibis.common.dispatch import lazy_singledispatch
from ibis.common.exceptions import IbisTypeError, InputTypeError
from ibis.common.numeric import normalize_decimal
from ibis.common.temporal import (
IntervalUnit,
normalize_datetime,
Expand Down Expand Up @@ -242,6 +243,7 @@ def __repr__(self):
return self.text


# TODO(kszucs): should raise ValueError instead of TypeError
def normalize(typ, value):
"""Ensure that the Python type underlying a literal resolves to a single type."""

Expand Down Expand Up @@ -288,10 +290,7 @@ def normalize(typ, value):
elif dtype.is_string():
return str(value)
elif dtype.is_decimal():
out = decimal.Decimal(value)
if isinstance(value, int):
return out.scaleb(-dtype.scale)
return out
return normalize_decimal(value, precision=dtype.precision, scale=dtype.scale)
elif dtype.is_uuid():
return value if isinstance(value, uuid.UUID) else uuid.UUID(value)
elif dtype.is_array():
Expand Down
6 changes: 6 additions & 0 deletions ibis/tests/expr/test_literal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import datetime
import decimal
import uuid

import pytest
Expand Down Expand Up @@ -168,3 +169,8 @@ def test_timestamp_literal_without_tz():
now_raw = datetime.datetime.utcnow()
assert now_raw.tzinfo is None
assert ibis.literal(now_raw).type().timezone is None


def test_integer_as_decimal():
lit = ibis.literal(12, type="decimal")
assert lit.op().value == decimal.Decimal(12)
6 changes: 3 additions & 3 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_listeral_with_unhashable_values(value, expected_type, expected_value):
param(uuid.uuid4(), "uuid", id="uuid"),
param(str(uuid.uuid4()), "uuid", id="uuid_str"),
param(Decimal("234.234"), "decimal(6, 3)", id="decimal_native"),
param(234234, "decimal(6, 3)", id="decimal_int"),
param(234234, "decimal(9, 3)", id="decimal_int"),
],
)
def test_literal_with_explicit_type(value, expected_type):
Expand All @@ -143,13 +143,13 @@ def test_literal_with_explicit_type(value, expected_type):
[
# precision > scale
(Decimal("234.234"), Decimal("234.234"), "decimal(6, 3)"),
(234234, Decimal("234.234"), "decimal(6, 3)"),
(234234, Decimal("234234.000"), "decimal(9, 3)"),
# scale == 0
(Decimal("234"), Decimal("234"), "decimal(6, 0)"),
(234, Decimal("234"), "decimal(6, 0)"),
# precision == scale
(Decimal(".234"), Decimal(".234"), "decimal(3, 3)"),
(234, Decimal(".234"), "decimal(3, 3)"),
(234, Decimal("234.000"), "decimal(6, 3)"),
],
)
def test_normalize_decimal_literal(value, expected, dtype):
Expand Down

0 comments on commit 5213958

Please sign in to comment.