Skip to content

Commit

Permalink
refactor(ir): remove the decimal precision promotion logic
Browse files Browse the repository at this point in the history
  • Loading branch information
chelsea-lin authored and kszucs committed Feb 5, 2024
1 parent e286b69 commit 0db3ec7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 40 deletions.
36 changes: 0 additions & 36 deletions ibis/expr/rules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import operator
from itertools import product, starmap
from typing import Optional

Expand Down Expand Up @@ -108,39 +107,6 @@ def _promote_integral_binop(exprs, op):
return dt.highest_precedence(dtypes)


def _promote_decimal_binop(args, op):
if len(args) != 2:
return highest_precedence_dtype(args)

# TODO: Add support for setting the maximum precision and maximum scale
left = args[0].dtype
right = args[1].dtype

max_prec = 31 if left.precision <= 31 and right.precision <= 31 else 63
max_scale = 31

if op is operator.mul:
return dt.Decimal(
min(max_prec, left.precision + right.precision),
min(max_scale, left.scale + right.scale),
)
elif op is operator.add or op is operator.sub:
return dt.Decimal(
min(
max_prec,
max(
left.precision - left.scale,
right.precision - right.scale,
)
+ max(left.scale, right.scale)
+ 1,
),
max(left.scale, right.scale),
)
else:
return highest_precedence_dtype(args)


@public
def numeric_like(name, op):
@attribute
Expand All @@ -149,8 +115,6 @@ def dtype(self):
dtypes = [arg.dtype for arg in args]
if util.all_of(dtypes, dt.Integer):
result = _promote_integral_binop(args, op)
elif util.all_of(dtypes, dt.Decimal):
result = _promote_decimal_binop(args, op)
else:
result = highest_precedence_dtype(args)

Expand Down
8 changes: 4 additions & 4 deletions ibis/tests/expr/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def test_promote_decimal_type_mul(lineitem):
col_1 = lineitem.l_extendedprice
col_2 = lineitem.l_discount
result = col_1 * col_2
assert result.type().precision == 24
assert result.type().scale == 4
assert result.type().precision == 12
assert result.type().scale == 2


def test_promote_decimal_type_add(lineitem):
col_1 = lineitem.l_extendedprice
col_2 = lineitem.l_discount
result = col_1 + col_2
assert result.type().precision == 13
assert result.type().precision == 12
assert result.type().scale == 2


Expand All @@ -60,7 +60,7 @@ def test_promote_decimal_type_max():
t = ibis.table([("a", "decimal(31, 3)"), ("b", "decimal(31, 3)")], "t")
result = t.a * t.b
assert result.type().precision == 31
assert result.type().scale == 6
assert result.type().scale == 3


@pytest.mark.parametrize(
Expand Down

0 comments on commit 0db3ec7

Please sign in to comment.