Skip to content

Commit

Permalink
fix(ir): coerce integers passed to Value[dt.Floating] annotated values as `dt.float64`
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Oct 26, 2023
1 parent e654a15 commit b8a924a
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 42 deletions.
30 changes: 1 addition & 29 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,11 +499,6 @@ def __contains__(self, value: int) -> bool:
class Numeric(DataType):
"""Numeric types."""

@property
@abstractmethod
def largest(self) -> DataType:
"""Return the largest type in this family."""


@public
class Integer(Primitive, Numeric):
Expand Down Expand Up @@ -630,11 +625,6 @@ def _pretty_piece(self) -> str:
class SignedInteger(Integer):
"""Signed integer values."""

@property
def largest(self):
"""Return the largest type of signed integer."""
return int64

@property
def bounds(self):
exp = self.nbytes * 8 - 1
Expand All @@ -646,11 +636,6 @@ def bounds(self):
class UnsignedInteger(Integer):
"""Unsigned integer values."""

@property
def largest(self):
"""Return the largest type of unsigned integer."""
return uint64

@property
def bounds(self):
exp = self.nbytes * 8
Expand All @@ -665,15 +650,10 @@ class Floating(Primitive, Numeric):
scalar = "FloatingScalar"
column = "FloatingColumn"

@property
def largest(self):
"""Return the largest type of floating point values."""
return float64

@property
@abstractmethod
def nbytes(self) -> int: # pragma: no cover
...
"""Return the number of bytes used to store values of this type."""


@public
Expand Down Expand Up @@ -794,14 +774,6 @@ def __init__(
)
super().__init__(precision=precision, scale=scale, **kwargs)

@property
def largest(self):
"""Return the largest type of decimal."""
return self.__class__(
precision=max(self.precision, 38) if self.precision is not None else None,
scale=max(self.scale, 2) if self.scale is not None else None,
)

@property
def _pretty_piece(self) -> str:
precision = self.precision
Expand Down
8 changes: 7 additions & 1 deletion ibis/expr/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,17 @@ def __coerce__(
if isinstance(value, Value):
return value

try:
if T is dt.Integer:
dtype = dt.infer(int(value))
elif T is dt.Floating:
dtype = dt.infer(float(value))
else:
try:
dtype = dt.DataType.from_typehint(T)
except TypeError:
dtype = dt.infer(value)

try:
return Literal(value, dtype=dtype)
except TypeError:
raise CoercionError(f"Unable to coerce {value!r} to Value[{T!r}]")
Expand Down
7 changes: 5 additions & 2 deletions ibis/expr/operations/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,17 @@ class MathUnary(Unary):

@attribute
def dtype(self):
return dt.higher_precedence(self.arg.dtype, dt.double)
return dt.higher_precedence(self.arg.dtype, dt.float64)


@public
class ExpandingMathUnary(MathUnary):
@attribute
def dtype(self):
return dt.higher_precedence(self.arg.dtype.largest, dt.double)
if self.arg.dtype.is_decimal():
return self.arg.dtype
else:
return dt.float64


@public
Expand Down
23 changes: 19 additions & 4 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,25 @@ class Sum(Filterable, Reduction):

@attribute
def dtype(self):
if self.arg.dtype.is_boolean():
dtype = self.arg.dtype
if dtype.is_boolean():
return dt.int64
elif dtype.is_integer():
return dt.int64
elif dtype.is_unsigned_integer():
return dt.uint64
elif dtype.is_floating():
return dt.float64
elif dtype.is_decimal():
return dt.Decimal(
precision=max(dtype.precision, 38)
if dtype.precision is not None
else None,
scale=max(dtype.scale, 2) if dtype.scale is not None else None,
)

else:
return self.arg.dtype.largest
raise TypeError(f"Cannot compute sum of {dtype} values")


@public
Expand Down Expand Up @@ -204,8 +219,8 @@ class VarianceBase(Filterable, Reduction):

@attribute
def dtype(self):
if (dtype := self.arg.dtype).is_decimal():
return dtype.largest
if self.arg.dtype.is_decimal():
return self.arg.dtype
else:
return dt.float64

Expand Down
29 changes: 23 additions & 6 deletions ibis/expr/operations/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
@pytest.mark.parametrize(
("value", "dtype"),
[
(1, dt.int8),
(1.0, dt.double),
(True, dt.boolean),
("foo", dt.string),
(b"foo", dt.binary),
((1, 2), dt.Array(dt.int8)),
(1, dt.Int8),
(1.0, dt.Float64),
(True, dt.Boolean),
("foo", dt.String),
(b"foo", dt.Binary),
((1, 2), dt.Array[dt.Int8]),
],
)
def test_literal_coercion_type_inference(value, dtype):
Expand All @@ -39,6 +39,9 @@ def test_literal_coercion_type_inference(value, dtype):
(ops.Literal[dt.Int8], 1, one),
(ops.Literal[dt.Int16], 1, ops.Literal(1, dt.int16)),
(ops.Literal[dt.Int8], ops.Literal(1, dt.int16), NoMatch),
(ops.Literal[dt.Integer], 1, ops.Literal(1, dt.int8)),
(ops.Literal[dt.Floating], 1, ops.Literal(1, dt.float64)),
(ops.Literal[dt.Float32], 1.0, ops.Literal(1.0, dt.float32)),
],
)
def test_coerced_to_literal(typehint, value, expected):
Expand All @@ -60,15 +63,29 @@ def test_coerced_to_literal(typehint, value, expected):
# same applies here, the coercion itself will use only the inferred datatype
# but then the result is checked against the given typehint
(ops.Value[dt.Int8 | dt.Int16], 1, one),
(Union[ops.Value[dt.Int8], ops.Value[dt.Int16]], 1, one),
(ops.Value[dt.Int8 | dt.Int16], 128, ops.Literal(128, dt.int16)),
(
Union[ops.Value[dt.Int8], ops.Value[dt.Int16]],
128,
ops.Literal(128, dt.int16),
),
(ops.Value[dt.Int8 | dt.Int16], 128, ops.Literal(128, dt.int16)),
(
Union[ops.Value[dt.Int8], ops.Value[dt.Int16]],
128,
ops.Literal(128, dt.int16),
),
(ops.Value[dt.Int8], 128, NoMatch),
# this is actually supported by creating an explicit dtype
# in Value.__coerce__ based on the `T` keyword argument
(ops.Value[dt.Int16, ds.Scalar], 1, ops.Literal(1, dt.int16)),
(ops.Value[dt.Int16, ds.Scalar], 128, ops.Literal(128, dt.int16)),
# equivalent with ops.Value[dt.Int8 | dt.Int16]
(Union[ops.Value[dt.Int8], ops.Value[dt.Int16]], 1, one),
# when expecting floating point values given an integer value it will
# be coerced to float64
(ops.Value[dt.Floating], 1, ops.Literal(1, dt.float64)),
],
)
def test_coerced_to_value(typehint, value, expected):
Expand Down

0 comments on commit b8a924a

Please sign in to comment.