Skip to content

Commit

Permalink
feat(datatypes): unbounded decimal type
Browse files (browse the repository at this point in the history)
  • Loading branch information
cpcloud committed Mar 28, 2022
1 parent 088169a commit f7e6f65
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 56 deletions.
11 changes: 0 additions & 11 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,6 @@ def sa_mysql_numeric(_, satype, nullable=True):
)


@dt.dtype.register(PGDialect, postgresql.NUMERIC)
def sa_postgres_numeric(_, satype, nullable=True):
# PostgreSQL allows any precision for numeric values if not specified,
# up to the implementation limit. Here, default to the maximum value that
# can be specified by the user. The scale defaults to zero.
# https://www.postgresql.org/docs/10/datatype-numeric.html
return dt.Decimal(
satype.precision or 1000, satype.scale or 0, nullable=nullable
)


@dt.dtype.register(Dialect, sa.types.Numeric)
@dt.dtype.register(SQLiteDialect, sqlite.NUMERIC)
def sa_numeric(_, satype, nullable=True):
Expand Down
6 changes: 2 additions & 4 deletions ibis/backends/duckdb/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def spaceless_string(*strings: str):
)


def parse_type(text: str) -> DataType:
def parse_type(text: str, default_decimal_parameters=(18, 3)) -> DataType:
precision = scale = p.digit.at_least(1).concat().map(int)

lparen = spaceless_string("(")
Expand Down Expand Up @@ -103,9 +103,7 @@ def decimal():
)
.skip(rparen)
.optional()
)
if prec_scale is None:
prec_scale = (18, 3)
) or default_decimal_parameters
return Decimal(*prec_scale)

@p.generate
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def guid2(con):
),
param(
lambda t: t.string_col.cast('decimal'),
lambda at: sa.cast(at.c.string_col, sa.NUMERIC(9, 0)),
lambda at: sa.cast(at.c.string_col, sa.NUMERIC()),
id='string_to_decimal_no_params',
),
param(
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,8 @@ def test_divide_by_zero(backend, alltypes, df, column, denominator):
[
(
{'postgres': postgresql, 'mysql': mysql},
{'postgres': 1000, 'mysql': 10},
{'postgres': 0, 'mysql': 0},
{'postgres': None, 'mysql': 10},
{'postgres': None, 'mysql': 0},
)
],
)
Expand Down
102 changes: 66 additions & 36 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import ast
import collections
import datetime
import decimal
import enum
import functools
import numbers
import re
import typing
import uuid as _uuid
from decimal import Decimal as PythonDecimal
from typing import (
AbstractSet,
Iterable,
Expand Down Expand Up @@ -409,43 +409,64 @@ class Float64(Floating):
class Decimal(DataType):
"""Fixed-precision decimal values."""

precision = instance_of(int)
"""The number of values after the decimal point."""

scale = instance_of(int)
precision = optional(instance_of(int))
"""The number of decimal places values of this type can hold."""

scale = optional(instance_of(int))
"""The number of values after the decimal point."""

scalar = ir.DecimalScalar
column = ir.DecimalColumn

def __init__(self, precision: int, scale: int, **kwargs) -> None:
if not isinstance(precision, numbers.Integral):
raise TypeError('Decimal type precision must be an integer')
if not isinstance(scale, numbers.Integral):
raise TypeError('Decimal type scale must be an integer')
if precision < 0:
raise ValueError('Decimal type precision cannot be negative')
if not precision:
raise ValueError('Decimal type precision cannot be zero')
if scale < 0:
raise ValueError('Decimal type scale cannot be negative')
if precision < scale:
raise ValueError(
'Decimal type precision must be greater than or equal to '
'scale. Got precision={:d} and scale={:d}'.format(
precision, scale
def __init__(
self,
precision: int | None = None,
scale: int | None = None,
**kwargs: Any,
) -> None:
if precision is not None:
if not isinstance(precision, numbers.Integral):
raise TypeError(
"Decimal type precision must be an integer; "
f"got {type(precision)}"
)
if precision < 0:
raise ValueError('Decimal type precision cannot be negative')
if not precision:
raise ValueError('Decimal type precision cannot be zero')
if scale is not None:
if not isinstance(scale, numbers.Integral):
raise TypeError('Decimal type scale must be an integer')
if scale < 0:
raise ValueError('Decimal type scale cannot be negative')
if precision is not None and precision < scale:
raise ValueError(
'Decimal type precision must be greater than or equal to '
'scale. Got precision={:d} and scale={:d}'.format(
precision, scale
)
)
)
super().__init__(precision=precision, scale=scale, **kwargs)

@property
def largest(self) -> Decimal:
"""Return the largest decimal type."""
return self.__class__(38, self.scale)
def largest(self):
"""Return the largest type of decimal."""
return self.__class__(precision=None, scale=None)

@property
def _pretty_piece(self) -> str:
return f"({self.precision:d}, {self.scale:d})"
args = []

if (precision := self.precision) is not None:
args.append(f"prec={precision:d}")

if (scale := self.scale) is not None:
args.append(f"scale={scale:d}")

if not args:
return ""

return f"({', '.join(args)})"


@public
Expand Down Expand Up @@ -860,8 +881,8 @@ class INET(String):
interval = Interval()
category = Category()
# geospatial data types
geometry = GeoSpatial()
geography = GeoSpatial()
geometry = Geometry()
geography = Geography()
point = Point()
linestring = LineString()
polygon = Polygon()
Expand All @@ -875,6 +896,7 @@ class INET(String):
uuid = UUID()
macaddr = MACADDR()
inet = INET()
decimal = Decimal()

public(
any=any,
Expand Down Expand Up @@ -915,6 +937,7 @@ class INET(String):
uuid=uuid,
macaddr=macaddr,
inet=inet,
decimal=decimal,
)

_STRING_REGEX = """('[^\n'\\\\]*(?:\\\\.[^\n'\\\\]*)*'|"[^\n"\\\\"]*(?:\\\\.[^\n"\\\\]*)*")""" # noqa: E501
Expand Down Expand Up @@ -1055,16 +1078,16 @@ def varchar_or_char():
@p.generate
def decimal():
yield spaceless_string("decimal")
prec_scale = (
prec, sc = (
yield lparen.then(
p.seq(precision.skip(comma), scale).combine(
lambda prec, scale: (prec, scale)
)
)
.skip(rparen)
.optional()
) or (9, 0)
return Decimal(*prec_scale)
) or (None, None)
return Decimal(precision=prec, scale=sc)

@p.generate
def parened_string():
Expand Down Expand Up @@ -1394,8 +1417,15 @@ def can_cast_floats(

@castable.register(Decimal, Decimal)
def can_cast_decimals(source: Decimal, target: Decimal, **kwargs) -> bool:
target_prec = target.precision
source_prec = source.precision
target_sc = target.scale
source_sc = source.scale
return (
target.precision >= source.precision and target.scale >= source.scale
target_prec is None
or (source_prec is not None and target_prec >= source_prec)
) and (
target_sc is None or (source_sc is not None and target_sc >= source_sc)
)


Expand Down Expand Up @@ -1616,8 +1646,8 @@ def _uuid_to_str(typ: String, value: _uuid.UUID) -> str:


@_normalize.register(Decimal, int)
def _int_to_decimal(typ: Decimal, value: int) -> decimal.Decimal:
return decimal.Decimal(value).scaleb(-typ.scale)
def _int_to_decimal(typ: Decimal, value: int) -> PythonDecimal:
return PythonDecimal(value).scaleb(-typ.scale)


@_normalize.register(Array, (tuple, list, np.ndarray))
Expand All @@ -1631,13 +1661,13 @@ def _set_to_frozenset(typ: Set, values: AbstractSet) -> frozenset:


@_normalize.register(Map, dict)
def _map_to_frozendict(typ: Map, values: Mapping) -> decimal.Decimal:
def _map_to_frozendict(typ: Map, values: Mapping) -> PythonDecimal:
values = {k: _normalize(typ.value_type, v) for k, v in values.items()}
return frozendict(values)


@_normalize.register(Struct, dict)
def _struct_to_frozendict(typ: Struct, values: Mapping) -> decimal.Decimal:
def _struct_to_frozendict(typ: Struct, values: Mapping) -> PythonDecimal:
value_types = typ.pairs
values = {
k: _normalize(typ[k], v) for k, v in values.items() if k in value_types
Expand Down
4 changes: 2 additions & 2 deletions ibis/tests/expr/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_decimal_sum_type(lineitem):
col = lineitem.l_extendedprice
result = col.sum()
assert isinstance(result, ir.DecimalScalar)
assert result.type() == dt.Decimal(38, col.type().scale)
assert result.type() == dt.decimal


@pytest.mark.parametrize('func', ['mean', 'max', 'min'])
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_invalid_precision_scale_type(precision, scale):
def test_decimal_str(lineitem):
col = lineitem.l_extendedprice
t = col.type()
assert str(t) == f'decimal({t.precision:d}, {t.scale:d})'
assert str(t) == f'decimal(prec={t.precision:d}, scale={t.scale:d})'


def test_decimal_repr(lineitem):
Expand Down

0 comments on commit f7e6f65

Please sign in to comment.