Skip to content

Commit

Permalink
feat(datatype/schema): support datatype and schema declaration using …
Browse files Browse the repository at this point in the history
…type annotated classes
  • Loading branch information
kszucs committed Jan 31, 2023
1 parent d22ae7b commit 6722c31
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 15 deletions.
2 changes: 1 addition & 1 deletion ibis/expr/datatypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def can_cast_struct(source, target, **kwargs):

# Shared castability rule for Array -> Array and Set -> Set: the result is
# whatever `castable` reports for the two containers' `value_type`s.
# NOTE(review): **kwargs is accepted but not forwarded to the nested
# `castable` call -- confirm that this is intended.
@castable.register(dt.Array, dt.Array)
@castable.register(dt.Set, dt.Set)
def can_cast_variadic(
def can_cast_array_or_set(
source: dt.Array | dt.Set, target: dt.Array | dt.Set, **kwargs
) -> bool:
return castable(source.value_type, target.value_type)
Expand Down
98 changes: 86 additions & 12 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from __future__ import annotations

import datetime as pydatetime
import decimal as pydecimal
import numbers
import uuid as pyuuid
from abc import abstractmethod
from collections.abc import Iterator, Mapping
from collections.abc import Iterator, Mapping, Sequence
from collections.abc import Set as PySet
from numbers import Integral, Real
from typing import Any, Iterable, NamedTuple

import numpy as np
import toolz
from multipledispatch import Dispatcher
from public import public
from typing_extensions import get_args, get_origin, get_type_hints

import ibis.expr.types as ir
from ibis.common.annotations import attribute, optional
from ibis.common.exceptions import IbisTypeError
from ibis.common.grounds import Concrete, Singleton
from ibis.common.validators import (
all_of,
Expand All @@ -22,12 +28,53 @@
validator,
)

# TODO(kszucs): we don't support union types yet

# Dispatcher used to build ibis datatypes from other kinds of values; the
# overloads are attached below with `@dtype.register(...)`.
dtype = Dispatcher('dtype')


@dtype.register(object)
def dtype_from_object(value, **kwargs) -> DataType:
    """Construct an ibis datatype from a python type or type annotation.

    This is the overload registered for ``object``. Non-generic values
    (``get_origin(value) is None``) are resolved in this order:

    1. a ``DataType`` subclass is instantiated without arguments
    2. an exact match in the ``_python_dtypes`` lookup table
    3. a class with type hints becomes a ``Struct`` of its annotated fields
    4. subclasses of ``bytes``, ``str``, ``Integral`` and ``Real`` map to
       ``binary``, ``string``, ``int64`` and ``float64`` respectively
    5. ``type(None)`` maps to ``null``

    Parametrized generics are mapped by their origin: ``Sequence`` to
    ``Array``, ``Mapping`` to ``Map`` and ``Set`` to ``Set``, with the type
    arguments converted recursively through ``dtype``.

    Parameters
    ----------
    value
        A python class or a parametrized generic such as ``list[int]``.
    **kwargs
        Accepted for signature compatibility; not used here.

    Returns
    -------
    DataType
        The ibis datatype corresponding to ``value``.

    Raises
    ------
    TypeError
        If no ibis datatype can be constructed from ``value``.
    """
    # TODO(kszucs): implement this in a @dtype.register(type) overload once
    # dtype is turned into a singledispatched function, because that overload
    # doesn't work with multipledispatch

    # TODO(kszucs): support Tuple[int, str] and Tuple[int, ...] typehints;
    # in order to support more kinds of typehints follow the implementation
    # of Validator.from_annotation
    origin_type = get_origin(value)
    if origin_type is None:
        if issubclass(value, DataType):
            return value()
        elif result := _python_dtypes.get(value):
            return result
        elif annots := get_type_hints(value):
            return Struct(toolz.valmap(dtype, annots))
        elif issubclass(value, bytes):
            # Return the ibis datatype, not the python builtin `bytes`; this
            # matches the `bytes: binary` entry of `_python_dtypes`.
            return binary
        elif issubclass(value, str):
            return string
        elif issubclass(value, Integral):
            return int64
        elif issubclass(value, Real):
            return float64
        elif value is type(None):
            return null
        else:
            raise TypeError(
                f"Cannot construct an ibis datatype from python type {value!r}"
            )
    elif issubclass(origin_type, Sequence):
        (value_type,) = map(dtype, get_args(value))
        return Array(value_type)
    elif issubclass(origin_type, Mapping):
        key_type, value_type = map(dtype, get_args(value))
        return Map(key_type, value_type)
    elif issubclass(origin_type, PySet):
        (value_type,) = map(dtype, get_args(value))
        return Set(value_type)
    else:
        raise TypeError(f'Value {value!r} is not a valid datatype')


@validator
Expand Down Expand Up @@ -239,11 +286,20 @@ class Primitive(DataType, Singleton):
"""Values with known size."""


# TODO(kszucs): consider to remove since we don't actually use this information
@public
class Variadic(DataType):
"""Base class for datatypes whose values have unknown size.

In this file it is the base of the Array, Set and Map datatypes.
"""


@public
class Parametric(DataType):
    """Base class for datatypes that accept parameters.

    Subscripting the class constructs an instance: ``cls[a, b]`` is
    equivalent to ``cls(a, b)`` and ``cls[a]`` is equivalent to ``cls(a)``.
    """

    def __class_getitem__(cls, params):
        # Several subscript items arrive packed in a tuple; spread them out
        # as positional arguments. Anything else is a single argument.
        if isinstance(params, tuple):
            return cls(*params)
        return cls(params)


@public
class Null(Primitive):
"""Null values."""
Expand Down Expand Up @@ -339,7 +395,7 @@ class Time(Temporal, Primitive):


@public
class Timestamp(Temporal):
class Timestamp(Temporal, Parametric):
"""Timestamp values."""

timezone = optional(instance_of(str))
Expand Down Expand Up @@ -487,7 +543,7 @@ class Float64(Floating):


@public
class Decimal(Numeric):
class Decimal(Numeric, Parametric):
"""Fixed-precision decimal values."""

precision = optional(instance_of(int))
Expand Down Expand Up @@ -551,7 +607,7 @@ def _pretty_piece(self) -> str:


@public
class Interval(DataType):
class Interval(Parametric):
"""Interval values."""

__valid_units__ = {
Expand Down Expand Up @@ -617,7 +673,7 @@ def _pretty_piece(self) -> str:


@public
class Category(DataType):
class Category(Parametric):
cardinality = optional(instance_of(int))

scalar = ir.CategoryScalar
Expand All @@ -640,14 +696,17 @@ def to_integer_type(self):


@public
class Struct(DataType, Mapping):
class Struct(Parametric, Mapping):
"""Structured values."""

fields = frozendict_of(instance_of(str), datatype)

scalar = ir.StructScalar
column = ir.StructColumn

def __class_getitem__(cls, fields):
    """Build a struct type from ``Struct["name": type, ...]`` subscripts.

    Each subscript item is a slice whose ``start`` is the field name and
    whose ``stop`` is the field type.
    """
    # A single item arrives as a bare slice rather than a 1-tuple, and a
    # slice is not iterable; wrap it so one-field structs work as well.
    if isinstance(fields, slice):
        fields = (fields,)
    return cls({slice_.start: slice_.stop for slice_ in fields})

@classmethod
def from_tuples(
cls, pairs: Iterable[tuple[str, str | DataType]], nullable: bool = True
Expand Down Expand Up @@ -697,7 +756,7 @@ def _pretty_piece(self) -> str:


@public
class Array(Variadic):
class Array(Variadic, Parametric):
"""Array values."""

value_type = datatype
Expand All @@ -711,7 +770,7 @@ def _pretty_piece(self) -> str:


@public
class Set(Variadic):
class Set(Variadic, Parametric):
"""Set values."""

value_type = datatype
Expand All @@ -725,7 +784,7 @@ def _pretty_piece(self) -> str:


@public
class Map(Variadic):
class Map(Variadic, Parametric):
"""Associative array values."""

key_type = datatype
Expand Down Expand Up @@ -887,6 +946,21 @@ class INET(String):

Enum = String


# Lookup table from python builtin / standard library types to ibis
# datatypes. `dtype_from_object` consults it with the exact type as key
# before falling back to its `issubclass` checks.
_python_dtypes = {
bool: boolean,
int: int64,
float: float64,
str: string,
bytes: binary,
pydatetime.date: date,
pydatetime.time: time,
pydatetime.datetime: timestamp,
pydatetime.timedelta: interval,
pydecimal.Decimal: decimal,
pyuuid.UUID: uuid,
}

_numpy_dtypes = {
np.dtype("bool"): boolean,
np.dtype("int8"): int8,
Expand Down Expand Up @@ -929,7 +1003,7 @@ class INET(String):


@dtype.register(np.dtype)
def _(value):
def from_numpy_dtype(value):
try:
return _numpy_dtypes[value]
except KeyError:
Expand Down
5 changes: 5 additions & 0 deletions ibis/expr/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def schema_from_pairs(lst):
return Schema.from_tuples(lst)


@schema.register(type)
def schema_from_class(cls):
    """Build a ``Schema`` from a class.

    The class is first converted with ``dt.dtype`` and the resulting
    datatype is passed to the ``Schema`` constructor.
    """
    inferred = dt.dtype(cls)
    return Schema(inferred)


@schema.register(Iterable, Iterable)
def schema_from_names_types(names, types):
# validate lengths of names and types are the same
Expand Down
Loading

0 comments on commit 6722c31

Please sign in to comment.