Skip to content

Commit

Permalink
fix(decimal): add decimal type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
webmiche authored and cpcloud committed Aug 15, 2022
1 parent 3771196 commit 3fe3fd8
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
37 changes: 37 additions & 0 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
import functools
import operator
from itertools import product, starmap

import ibis.common.exceptions as com
Expand Down Expand Up @@ -294,12 +295,48 @@ def _promote_integral_binop(exprs, op):
return dt.highest_precedence(dtypes)


def _promote_decimal_dtype(args, op):

if len(args) != 2:
return highest_precedence_dtype(args)

# TODO: Add support for setting the maximum precision and maximum scale
lhs_prec = args[0].type().precision
lhs_scale = args[0].type().scale
rhs_prec = args[1].type().precision
rhs_scale = args[1].type().scale
max_prec = 31 if lhs_prec <= 31 and rhs_prec <= 31 else 63
max_scale = 31

if op is operator.mul:
return dt.Decimal(
min(max_prec, lhs_prec + rhs_prec),
min(max_scale, lhs_scale + rhs_scale),
)
if op is operator.add or op is operator.sub:
return dt.Decimal(
min(
max_prec,
max(
lhs_prec - lhs_scale,
rhs_prec - rhs_scale,
)
+ max(lhs_scale, rhs_scale)
+ 1,
),
max(lhs_scale, rhs_scale),
)
return highest_precedence_dtype(args)


def numeric_like(name, op):
@immutable_property
def output_dtype(self):
args = getattr(self, name)
if util.all_of(args, ir.IntegerValue):
result = _promote_integral_binop(args, op)
elif util.all_of(args, ir.DecimalValue):
result = _promote_decimal_dtype(args, op)
else:
result = highest_precedence_dtype(args)

Expand Down
31 changes: 31 additions & 0 deletions ibis/tests/expr/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,37 @@ def test_decimal_sum_type(lineitem):
assert result.type() == dt.Decimal(38, 2)


def test_promote_decimal_type_mul(lineitem):
col_1 = lineitem.l_extendedprice
col_2 = lineitem.l_discount
result = col_1 * col_2
assert result.type().precision == 24
assert result.type().scale == 4


def test_promote_decimal_type_add(lineitem):
col_1 = lineitem.l_extendedprice
col_2 = lineitem.l_discount
result = col_1 + col_2
assert result.type().precision == 13
assert result.type().scale == 2


def test_promote_decimal_type_mod(lineitem):
col_1 = lineitem.l_extendedprice
col_2 = lineitem.l_discount
result = col_1 % col_2
assert result.type().precision == 12
assert result.type().scale == 2


def test_promote_decimal_type_max():
t = ibis.table([("a", "decimal(31, 3)"), ("b", "decimal(31, 3)")], "t")
result = t.a * t.b
assert result.type().precision == 31
assert result.type().scale == 6


@pytest.mark.parametrize(
"precision, scale, expected",
[
Expand Down

0 comments on commit 3fe3fd8

Please sign in to comment.