Skip to content

Commit

Permalink
refactor(datatypes): remove direct ir dependency from datatypes
Browse files Browse the repository at this point in the history
This change was required to prevent import cycles. The value expression
class corresponding to a datatype can now be retrieved using
`getattr(ibis.expr.types, DataType.scalar)` or `getattr(ibis.expr.types, DataType.column)`.

BREAKING CHANGE: the `DataType.scalar` and `DataType.column` class attributes are now strings naming the expression classes in `ibis.expr.types`, not the classes themselves.
  • Loading branch information
kszucs authored and cpcloud committed Apr 13, 2023
1 parent 755555f commit d7f0be0
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 61 deletions.
9 changes: 5 additions & 4 deletions ibis/backends/impala/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,10 @@ def test_udf_primitive_output_types(ty, value, column, table):
ibis_type = dt.validate_type(ty)

expr = func(value)
assert type(expr) == ibis_type.scalar
assert type(expr) == getattr(ir, ibis_type.scalar)

expr = func(table[column])
assert type(expr) == ibis_type.column
assert type(expr) == getattr(ir, ibis_type.column)


@pytest.mark.parametrize(
Expand All @@ -164,12 +164,13 @@ def test_uda_primitive_output_types(ty, value):
func = _register_uda([ty], ty, 'test')

ibis_type = dt.validate_type(ty)
scalar_type = getattr(ir, ibis_type.scalar)

expr1 = func(value)
assert isinstance(expr1, ibis_type.scalar)
assert isinstance(expr1, scalar_type)

expr2 = func(value)
assert isinstance(expr2, ibis_type.scalar)
assert isinstance(expr2, scalar_type)


def test_decimal(dec):
Expand Down
109 changes: 54 additions & 55 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from public import public
from typing_extensions import get_args, get_origin, get_type_hints

import ibis.expr.types as ir
from ibis.common.annotations import attribute, optional
from ibis.common.collections import MapSet
from ibis.common.grounds import Concrete, Singleton
Expand Down Expand Up @@ -294,8 +293,8 @@ def from_ibis_dtype(value: DataType) -> DataType:
class Unknown(DataType, Singleton):
"""An unknown type."""

scalar = ir.UnknownScalar
column = ir.UnknownColumn
scalar = "UnknownScalar"
column = "UnknownColumn"


@public
Expand All @@ -321,16 +320,16 @@ def __class_getitem__(cls, params):
class Null(Primitive):
"""Null values."""

scalar = ir.NullScalar
column = ir.NullColumn
scalar = "NullScalar"
column = "NullColumn"


@public
class Boolean(Primitive):
"""[`True`][True] or [`False`][False] values."""

scalar = ir.BooleanScalar
column = ir.BooleanColumn
scalar = "BooleanScalar"
column = "BooleanColumn"


@public
Expand All @@ -350,8 +349,8 @@ class Numeric(DataType):
class Integer(Primitive, Numeric):
"""Integer values."""

scalar = ir.IntegerScalar
column = ir.IntegerColumn
scalar = "IntegerScalar"
column = "IntegerColumn"

@property
@abstractmethod
Expand All @@ -369,8 +368,8 @@ class String(Variadic, Singleton):
cannot assume that strings are UTF-8 encoded.
"""

scalar = ir.StringScalar
column = ir.StringColumn
scalar = "StringScalar"
column = "StringColumn"


@public
Expand All @@ -386,8 +385,8 @@ class Binary(Variadic, Singleton):
distinct types that have different behavior.
"""

scalar = ir.BinaryScalar
column = ir.BinaryColumn
scalar = "BinaryScalar"
column = "BinaryColumn"


@public
Expand All @@ -399,16 +398,16 @@ class Temporal(DataType):
class Date(Temporal, Primitive):
"""Date values."""

scalar = ir.DateScalar
column = ir.DateColumn
scalar = "DateScalar"
column = "DateColumn"


@public
class Time(Temporal, Primitive):
"""Time values."""

scalar = ir.TimeScalar
column = ir.TimeColumn
scalar = "TimeScalar"
column = "TimeColumn"


@public
Expand All @@ -421,8 +420,8 @@ class Timestamp(Temporal, Parametric):
scale = optional(isin(range(10)))
"""The scale of the timestamp if known."""

scalar = ir.TimestampScalar
column = ir.TimestampColumn
scalar = "TimestampScalar"
column = "TimestampColumn"

@property
def _pretty_piece(self) -> str:
Expand Down Expand Up @@ -468,8 +467,8 @@ def bounds(self):
class Floating(Primitive, Numeric):
"""Floating point values."""

scalar = ir.FloatingScalar
column = ir.FloatingColumn
scalar = "FloatingScalar"
column = "FloatingColumn"

@property
def largest(self):
Expand Down Expand Up @@ -569,8 +568,8 @@ class Decimal(Numeric, Parametric):
scale = optional(instance_of(int))
"""The number of values after the decimal point."""

scalar = ir.DecimalScalar
column = ir.DecimalColumn
scalar = "DecimalScalar"
column = "DecimalColumn"

def __init__(
self,
Expand Down Expand Up @@ -654,8 +653,8 @@ class Interval(Parametric):
value_type = optional(all_of([datatype, instance_of(Integer)]), default=Int32())
"""The underlying type of the stored values."""

scalar = ir.IntervalScalar
column = ir.IntervalColumn
scalar = "IntervalScalar"
column = "IntervalColumn"

# based on numpy's units
_units = {
Expand Down Expand Up @@ -695,8 +694,8 @@ class Struct(Parametric, MapSet):

fields = frozendict_of(instance_of(str), datatype)

scalar = ir.StructScalar
column = ir.StructColumn
scalar = "StructScalar"
column = "StructColumn"

def __class_getitem__(cls, fields):
return cls({slice_.start: slice_.stop for slice_ in fields})
Expand Down Expand Up @@ -755,8 +754,8 @@ class Array(Variadic, Parametric):

value_type = datatype

scalar = ir.ArrayScalar
column = ir.ArrayColumn
scalar = "ArrayScalar"
column = "ArrayColumn"

@property
def _pretty_piece(self) -> str:
Expand All @@ -769,8 +768,8 @@ class Set(Variadic, Parametric):

value_type = datatype

scalar = ir.SetScalar
column = ir.SetColumn
scalar = "SetScalar"
column = "SetColumn"

@property
def _pretty_piece(self) -> str:
Expand All @@ -784,8 +783,8 @@ class Map(Variadic, Parametric):
key_type = datatype
value_type = datatype

scalar = ir.MapScalar
column = ir.MapColumn
scalar = "MapScalar"
column = "MapColumn"

@property
def _pretty_piece(self) -> str:
Expand All @@ -796,8 +795,8 @@ def _pretty_piece(self) -> str:
class JSON(Variadic):
"""JSON values."""

scalar = ir.JSONScalar
column = ir.JSONColumn
scalar = "JSONScalar"
column = "JSONColumn"


@public
Expand All @@ -810,8 +809,8 @@ class GeoSpatial(DataType):
srid = optional(instance_of(int))
"""The spatial reference identifier."""

column = ir.GeoSpatialColumn
scalar = ir.GeoSpatialScalar
column = "GeoSpatialColumn"
scalar = "GeoSpatialScalar"

@property
def _pretty_piece(self) -> str:
Expand All @@ -827,16 +826,16 @@ def _pretty_piece(self) -> str:
class Point(GeoSpatial):
"""A point described by two coordinates."""

scalar = ir.PointScalar
column = ir.PointColumn
scalar = "PointScalar"
column = "PointColumn"


@public
class LineString(GeoSpatial):
"""A sequence of 2 or more points."""

scalar = ir.LineStringScalar
column = ir.LineStringColumn
scalar = "LineStringScalar"
column = "LineStringColumn"


@public
Expand All @@ -847,56 +846,56 @@ class Polygon(GeoSpatial):
rest represent holes in that shape (internal rings).
"""

scalar = ir.PolygonScalar
column = ir.PolygonColumn
scalar = "PolygonScalar"
column = "PolygonColumn"


@public
class MultiLineString(GeoSpatial):
"""A set of one or more line strings."""

scalar = ir.MultiLineStringScalar
column = ir.MultiLineStringColumn
scalar = "MultiLineStringScalar"
column = "MultiLineStringColumn"


@public
class MultiPoint(GeoSpatial):
"""A set of one or more points."""

scalar = ir.MultiPointScalar
column = ir.MultiPointColumn
scalar = "MultiPointScalar"
column = "MultiPointColumn"


@public
class MultiPolygon(GeoSpatial):
"""A set of one or more polygons."""

scalar = ir.MultiPolygonScalar
column = ir.MultiPolygonColumn
scalar = "MultiPolygonScalar"
column = "MultiPolygonColumn"


@public
class UUID(DataType):
"""A 128-bit number used to identify information in computer systems."""

scalar = ir.UUIDScalar
column = ir.UUIDColumn
scalar = "UUIDScalar"
column = "UUIDColumn"


@public
class MACADDR(String):
"""Media Access Control (MAC) address of a network interface."""

scalar = ir.MACADDRScalar
column = ir.MACADDRColumn
scalar = "MACADDRScalar"
column = "MACADDRColumn"


@public
class INET(String):
"""IP addresses."""

scalar = ir.INETScalar
column = ir.INETColumn
scalar = "INETScalar"
column = "INETColumn"


# ---------------------------------------------------------------------
Expand Down
8 changes: 6 additions & 2 deletions ibis/expr/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,14 @@ def output_shape(self) -> rlz.Shape:
"""

def to_expr(self):
import ibis.expr.types as ir

if self.output_shape.is_columnar():
return self.output_dtype.column(self)
typename = self.output_dtype.column
else:
return self.output_dtype.scalar(self)
typename = self.output_dtype.scalar

return getattr(ir, typename)(self)


@public
Expand Down

0 comments on commit d7f0be0

Please sign in to comment.