Skip to content

Commit

Permalink
refactor(datatypes): use typehints instead of rules
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Apr 13, 2023
1 parent 7c747ae commit 704542e
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 41 deletions.
4 changes: 1 addition & 3 deletions ibis/common/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,7 @@ def endswith_d(x, this):

def test_annotated_function_with_complex_type_annotations():
@annotated
def test(
a: Annotated[str, short_str, endswith_d], b: Union[int, float] # noqa: UP007
):
def test(a: Annotated[str, short_str, endswith_d], b: Union[int, float]):
return a, b

assert test("abcd", 1) == ("abcd", 1)
Expand Down
2 changes: 1 addition & 1 deletion ibis/examples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


class Example(Concrete):
descr: Optional[str] # noqa: UP007
descr: Optional[str]
key: str
reader: str

Expand Down
75 changes: 49 additions & 26 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from abc import abstractmethod
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Set as PySet
from enum import Enum
from numbers import Integral, Real
from typing import Any, Iterable, NamedTuple
from typing import Any, Iterable, Literal, NamedTuple, Optional

import numpy as np
import toolz
Expand All @@ -17,15 +18,11 @@
from typing_extensions import get_args, get_origin, get_type_hints

from ibis.common.annotations import attribute, optional
from ibis.common.collections import MapSet
from ibis.common.collections import FrozenDict, MapSet
from ibis.common.grounds import Concrete, Singleton
from ibis.common.validators import (
all_of,
frozendict_of,
instance_of,
isin,
Coercible,
map_to,
validator,
)

# TODO(kszucs): we don't support union types yet
Expand Down Expand Up @@ -82,19 +79,14 @@ def dtype_from_object(value, **kwargs) -> DataType:
raise TypeError(f'Value {value!r} is not a valid datatype')


@validator
def datatype(arg, **kwargs):
return dtype(arg)


@public
class DataType(Concrete):
class DataType(Concrete, Coercible):
"""Base class for all data types.
[`DataType`][ibis.expr.datatypes.DataType] instances are immutable.
"""

nullable = optional(instance_of(bool), default=True)
nullable: bool = True

# TODO(kszucs): remove it, prefer to use Annotable.__repr__ instead
@property
Expand All @@ -107,6 +99,10 @@ def name(self) -> str:
"""Return the name of the data type."""
return self.__class__.__name__

@classmethod
def __coerce__(cls, value):
return dtype(value)

def __call__(self, **kwargs):
return self.copy(**kwargs)

Expand Down Expand Up @@ -414,10 +410,11 @@ class Time(Temporal, Primitive):
class Timestamp(Temporal, Parametric):
"""Timestamp values."""

timezone = optional(instance_of(str))
timezone: Optional[str] = None
"""The timezone of values of this type."""

scale = optional(isin(range(10)))
# Literal[*range(10)] is only supported from 3.11
scale: Optional[Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]] = None
"""The scale of the timestamp if known."""

scalar = "TimestampScalar"
Expand Down Expand Up @@ -562,10 +559,10 @@ class Float64(Floating):
class Decimal(Numeric, Parametric):
"""Fixed-precision decimal values."""

precision = optional(instance_of(int))
precision: Optional[int] = None
"""The number of decimal places values of this type can hold."""

scale = optional(instance_of(int))
scale: Optional[int] = None
"""The number of values after the decimal point."""

scalar = "DecimalScalar"
Expand Down Expand Up @@ -622,6 +619,32 @@ def _pretty_piece(self) -> str:
return f"({', '.join(args)})"


class TemporalUnit(Enum):
YEAR = "Y"
QUARTER = "Q"
MONTH = "M"
WEEK = "W"
DAY = "D"
HOUR = "h"
MINUTE = "m"
SECOND = "s"
MILLISECOND = "ms"
MICROSECOND = "us"
NANOSECOND = "ns"

@property
def singular(self) -> str:
return self.name.lower()

@property
def plural(self) -> str:
return self.singular + "s"

@property
def short(self) -> str:
return self.value


@public
class Interval(Parametric):
"""Interval values."""
Expand Down Expand Up @@ -650,7 +673,7 @@ class Interval(Parametric):
unit = optional(map_to(__valid_units__), default='s')
"""The time unit of the interval."""

value_type = optional(all_of([datatype, instance_of(Integer)]), default=Int32())
value_type: Integer = Int32()
"""The underlying type of the stored values."""

scalar = "IntervalScalar"
Expand Down Expand Up @@ -692,7 +715,7 @@ def _pretty_piece(self) -> str:
class Struct(Parametric, MapSet):
"""Structured values."""

fields = frozendict_of(instance_of(str), datatype)
fields: FrozenDict[str, DataType]

scalar = "StructScalar"
column = "StructColumn"
Expand Down Expand Up @@ -752,7 +775,7 @@ def _pretty_piece(self) -> str:
class Array(Variadic, Parametric):
"""Array values."""

value_type = datatype
value_type: DataType

scalar = "ArrayScalar"
column = "ArrayColumn"
Expand All @@ -766,7 +789,7 @@ def _pretty_piece(self) -> str:
class Set(Variadic, Parametric):
"""Set values."""

value_type = datatype
value_type: DataType

scalar = "SetScalar"
column = "SetColumn"
Expand All @@ -780,8 +803,8 @@ def _pretty_piece(self) -> str:
class Map(Variadic, Parametric):
"""Associative array values."""

key_type = datatype
value_type = datatype
key_type: DataType
value_type: DataType

scalar = "MapScalar"
column = "MapColumn"
Expand All @@ -803,10 +826,10 @@ class JSON(Variadic):
class GeoSpatial(DataType):
"""Geospatial values."""

geotype = optional(isin({"geography", "geometry"}))
geotype: Optional[Literal["geography", "geometry"]] = None
"""The specific geospatial type."""

srid = optional(instance_of(int))
srid: Optional[int] = None
"""The spatial reference identifier."""

column = "GeoSpatialColumn"
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/datatypes/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def geotype_parser(typ: type[dt.DataType]) -> dt.DataType:

interval = spaceless_string("interval").then(
parsy.seq(
value_type=angle_type.optional(), unit=parened_string.optional("s")
value_type=angle_type.optional(dt.int32), unit=parened_string.optional("s")
).combine_dict(dt.Interval)
)

Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/datatypes/tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_parse_interval_with_invalid_value_type():
dt.dtype("interval<float>('s')")


@pytest.mark.parametrize('unit', ['H', 'unsupported'])
@pytest.mark.parametrize('unit', ['X', 'unsupported'])
def test_parse_interval_with_invalid_unit(unit):
definition = f"interval('{unit}')"
with pytest.raises(ValueError):
Expand Down
19 changes: 10 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -406,15 +406,16 @@ ignore = [
"RET506",
"RET507",
"RET508",
"RUF005", # splat instead of concat
"SIM102", # nested ifs
"SIM108", # convert everything to ternary operator
"SIM114", # combine `if` branches using logical `or` operator
"SIM116", # dictionary instead of `if` statements
"SIM117", # nested withs
"SIM118", # remove .keys() calls from dictionaries
"SIM300", # yoda conditions
"UP037", # remove quotes from type annotation
"RUF005", # splat instead of concat
"SIM102", # nested ifs
"SIM108", # convert everything to ternary operator
"SIM114", # combine `if` branches using logical `or` operator
"SIM116", # dictionary instead of `if` statements
"SIM117", # nested withs
"SIM118", # remove .keys() calls from dictionaries
"SIM300", # yoda conditions
"UP037", # remove quotes from type annotation
"UP007", # Optional[str] -> str | None
]
exclude = ["*_py310.py", "ibis/tests/*/snapshots/*"]
target-version = "py38"
Expand Down

0 comments on commit 704542e

Please sign in to comment.