From 87617320375d749c8df0f1741052861af431e047 Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Tue, 31 Jan 2023 20:20:34 -0500 Subject: [PATCH] feat(datatype): enable inference of `Decimal` type --- ibis/expr/datatypes/cast.py | 14 ++++++++++++-- ibis/expr/datatypes/tests/test_value.py | 2 ++ ibis/expr/datatypes/value.py | 6 ++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/ibis/expr/datatypes/cast.py b/ibis/expr/datatypes/cast.py index baabfd00e082..de0d1d73a1f9 100644 --- a/ibis/expr/datatypes/cast.py +++ b/ibis/expr/datatypes/cast.py @@ -113,8 +113,18 @@ def can_cast_decimals(source: dt.Decimal, target: dt.Decimal, **kwargs) -> bool: target_sc = target.scale source_sc = source.scale return ( - target_prec is None or (source_prec is not None and target_prec >= source_prec) - ) and (target_sc is None or (source_sc is not None and target_sc >= source_sc)) + # If either sides precision and scale are both `None`, return `True`. + target_prec is None + and target_sc is None + or source_prec is None + and source_sc is None + # Otherwise, return `True` unless we are downcasting precision or scale. + or ( + target_prec is None + or (source_prec is not None and target_prec >= source_prec) + ) + and (target_sc is None or (source_sc is not None and target_sc >= source_sc)) + ) @castable.register(dt.Interval, dt.Interval) diff --git a/ibis/expr/datatypes/tests/test_value.py b/ibis/expr/datatypes/tests/test_value.py index 0db51a7dc24c..938d582f4501 100644 --- a/ibis/expr/datatypes/tests/test_value.py +++ b/ibis/expr/datatypes/tests/test_value.py @@ -1,4 +1,5 @@ import datetime +import decimal import enum from collections import OrderedDict @@ -46,6 +47,7 @@ class Foo(enum.Enum): (-32769, dt.int32), (-2147483649, dt.int64), (1.5, dt.double), + (decimal.Decimal(1.5), dt.decimal), # parametric types (list('abc'), dt.Array(dt.string)), (set('abc'), dt.Set(dt.string)), diff --git a/ibis/expr/datatypes/value.py b/ibis/expr/datatypes/value.py index 475ffa466620..cd1fdcbd6d31 100644 --- a/ibis/expr/datatypes/value.py +++ b/ibis/expr/datatypes/value.py @@ -132,6 +132,12 @@ def infer_enum(_: enum.Enum) -> dt.String: return dt.string +@infer.register(decimal.Decimal) +def infer_decimal(value: decimal.Decimal) -> dt.Decimal: + """Infer the [`Decimal`][ibis.expr.datatypes.Decimal] type of `value`.""" + return dt.decimal + + @infer.register(bool) def infer_boolean(value: bool) -> dt.Boolean: return dt.boolean