Skip to content

Commit

Permalink
refactor(datatype): use a mapping to store StructType fields rather…
Browse files Browse the repository at this point in the history
… than `names` and `types` tuples

also schedule `Struct.from_dict()`, `Struct.pairs`, and the `Struct(names, types)` constructor for removal
  • Loading branch information
kszucs committed Jan 24, 2023
1 parent c162750 commit ff34c7b
Show file tree
Hide file tree
Showing 23 changed files with 123 additions and 80 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _pg_map(dialect, itype):
@to_sqla_type.register(Dialect, dt.Struct)
def _struct(dialect, itype):
return StructType(
[(name, to_sqla_type(dialect, type)) for name, type in itype.pairs.items()]
[(name, to_sqla_type(dialect, type)) for name, type in itype.fields.items()]
)


Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _(ty: dt.Map) -> str:
@serialize_raw.register(dt.Struct)
def _(ty: dt.Struct) -> str:
fields = ", ".join(
f"{name} {serialize(field_ty)}" for name, field_ty in ty.pairs.items()
f"{name} {serialize(field_ty)}" for name, field_ty in ty.fields.items()
)
return f"Tuple({fields})"

Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/clickhouse/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_columns_types_with_additional_argument(con):
param("Decimal(10, 3)", dt.Decimal(10, 3, nullable=False), id="decimal"),
param(
"Tuple(a String, b Array(Nullable(Float64)))",
dt.Struct.from_dict(
dt.Struct(
dict(
a=dt.String(nullable=False),
b=dt.Array(dt.float64, nullable=False),
Expand All @@ -172,7 +172,7 @@ def test_columns_types_with_additional_argument(con):
),
param(
"Tuple(String, Array(Nullable(Float64)))",
dt.Struct.from_dict(
dt.Struct(
dict(
f0=dt.String(nullable=False),
f1=dt.Array(dt.float64, nullable=False),
Expand All @@ -183,7 +183,7 @@ def test_columns_types_with_additional_argument(con):
),
param(
"Tuple(a String, Array(Nullable(Float64)))",
dt.Struct.from_dict(
dt.Struct(
dict(
a=dt.String(nullable=False),
f1=dt.Array(dt.float64, nullable=False),
Expand All @@ -194,7 +194,7 @@ def test_columns_types_with_additional_argument(con):
),
param(
"Nested(a String, b Array(Nullable(Float64)))",
dt.Struct.from_dict(
dt.Struct(
dict(
a=dt.Array(dt.String(nullable=False), nullable=False),
b=dt.Array(dt.Array(dt.float64, nullable=False), nullable=False),
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
["column_name", "column_type", "null"], rows.mappings()
):
ibis_type = parse(type)
yield name, ibis_type(nullable=null.lower() == "yes")
yield name, ibis_type.copy(nullable=null.lower() == "yes")

def _register_in_memory_table(self, table_op):
df = table_op.data.to_frame()
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/duckdb/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
P=dt.string,
Q=dt.Array(dt.int32),
R=dt.Map(dt.string, dt.int64),
S=dt.Struct.from_dict(
S=dt.Struct(
dict(
a=dt.int32,
b=dt.string,
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def from_ibis_interval(dtype):
def from_ibis_struct(dtype):
fields = [
pl.Field(name=name, dtype=to_polars_type(dtype))
for name, dtype in dtype.pairs.items()
for name, dtype in dtype.fields.items()
]
return pl.Struct(fields)

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/pyarrow/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def from_ibis_interval(dtype: dt.Interval):
@to_pyarrow_type.register
def from_ibis_struct(dtype: dt.Struct):
return pa.struct(
pa.field(name, to_pyarrow_type(typ)) for name, typ in dtype.pairs.items()
pa.field(name, to_pyarrow_type(typ)) for name, typ in dtype.fields.items()
)


Expand Down
22 changes: 15 additions & 7 deletions ibis/backends/pyspark/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.base.sql.registry import sql_type_names
from ibis.expr.schema import Schema

_sql_type_names = dict(sql_type_names, date='date')

Expand Down Expand Up @@ -72,10 +72,11 @@ def _spark_map(spark_dtype_obj, nullable=True):

@dt.dtype.register(pt.StructType)
def _spark_struct(spark_dtype_obj, nullable=True):
names = spark_dtype_obj.names
fields = spark_dtype_obj.fields
ibis_types = [dt.dtype(f.dataType, nullable=f.nullable) for f in fields]
return dt.Struct(names, ibis_types, nullable=nullable)
fields = {
n: dt.dtype(f.dataType, nullable=f.nullable)
for n, f in zip(spark_dtype_obj.names, spark_dtype_obj.fields)
}
return dt.Struct(fields, nullable=nullable)


_IBIS_DTYPE_TO_SPARK_DTYPE = {v: k for k, v in _SPARK_DTYPE_TO_IBIS_DTYPE.items()}
Expand Down Expand Up @@ -122,10 +123,17 @@ def _map(ibis_dtype_obj):


@spark_dtype.register(dt.Struct)
@spark_dtype.register(Schema)
def _struct(ibis_dtype_obj):
fields = [
pt.StructField(n, spark_dtype(t), t.nullable)
for n, t in zip(ibis_dtype_obj.names, ibis_dtype_obj.types)
for n, t in ibis_dtype_obj.fields.items()
]
return pt.StructType(fields)


@spark_dtype.register(sch.Schema)
def _schema(ibis_schema_obj):
    fields = [
        pt.StructField(n, spark_dtype(t), t.nullable) for n, t in ibis_schema_obj.items()
    ]
    return pt.StructType(fields)
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_aggregate_multikey_group_reduction_udf(backend, alltypes, df):

@reduction(
input_type=[dt.double],
output_type=dt.Struct(['mean', 'std'], [dt.double, dt.double]),
output_type=dt.Struct({'mean': dt.double, 'std': dt.double}),
)
def mean_and_std(v):
return v.mean(), v.std()
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_null_literal(con, field):
def test_struct_column(alltypes, df):
t = alltypes
expr = ibis.struct(dict(a=t.string_col, b=1, c=t.bigint_col)).name("s")
assert expr.type() == dt.Struct.from_dict(dict(a=dt.string, b=dt.int8, c=dt.int64))
assert expr.type() == dt.Struct(dict(a=dt.string, b=dt.int8, c=dt.int64))
result = expr.execute()
expected = pd.Series(
(dict(a=a, b=1, c=c) for a, c in zip(df.string_col, df.bigint_col)),
Expand Down
16 changes: 8 additions & 8 deletions ibis/backends/tests/test_vectorized_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def add_one_struct(v):
def create_add_one_struct_udf(result_formatter):
return elementwise(
input_type=[dt.double],
output_type=dt.Struct(['col1', 'col2'], [dt.double, dt.double]),
output_type=dt.Struct({'col1': dt.double, 'col2': dt.double}),
)(_format_struct_udf_return_type(add_one_struct, result_formatter))


Expand Down Expand Up @@ -127,7 +127,7 @@ def create_add_one_struct_udf(result_formatter):

@elementwise(
input_type=[dt.double],
output_type=dt.Struct(['double_col', 'col2'], [dt.double, dt.double]),
output_type=dt.Struct({'double_col': dt.double, 'col2': dt.double}),
)
def overwrite_struct_elementwise(v):
assert isinstance(v, pd.Series)
Expand All @@ -137,7 +137,7 @@ def overwrite_struct_elementwise(v):
@elementwise(
input_type=[dt.double],
output_type=dt.Struct(
['double_col', 'col2', 'float_col'], [dt.double, dt.double, dt.double]
{'double_col': dt.double, 'col2': dt.double, 'float_col': dt.double}
),
)
def multiple_overwrite_struct_elementwise(v):
Expand All @@ -147,7 +147,7 @@ def multiple_overwrite_struct_elementwise(v):

@analytic(
input_type=[dt.double, dt.double],
output_type=dt.Struct(['double_col', 'demean_weight'], [dt.double, dt.double]),
output_type=dt.Struct({'double_col': dt.double, 'demean_weight': dt.double}),
)
def overwrite_struct_analytic(v, w):
assert isinstance(v, pd.Series)
Expand All @@ -165,7 +165,7 @@ def demean_struct(v, w):
def create_demean_struct_udf(result_formatter):
return analytic(
input_type=[dt.double, dt.double],
output_type=dt.Struct(['demean', 'demean_weight'], [dt.double, dt.double]),
output_type=dt.Struct({'demean': dt.double, 'demean_weight': dt.double}),
)(_format_struct_udf_return_type(demean_struct, result_formatter))


Expand Down Expand Up @@ -203,7 +203,7 @@ def mean_struct(v, w):
def create_mean_struct_udf(result_formatter):
return reduction(
input_type=[dt.double, dt.int64],
output_type=dt.Struct(['mean', 'mean_weight'], [dt.double, dt.double]),
output_type=dt.Struct({'mean': dt.double, 'mean_weight': dt.double}),
)(_format_struct_udf_return_type(mean_struct, result_formatter))


Expand All @@ -220,7 +220,7 @@ def create_mean_struct_udf(result_formatter):

@reduction(
input_type=[dt.double, dt.int64],
output_type=dt.Struct(['double_col', 'mean_weight'], [dt.double, dt.double]),
output_type=dt.Struct({'double_col': dt.double, 'mean_weight': dt.double}),
)
def overwrite_struct_reduction(v, w):
assert isinstance(v, (np.ndarray, pd.Series))
Expand Down Expand Up @@ -495,7 +495,7 @@ def test_elementwise_udf_destructure_exact_once(
):
@elementwise(
input_type=[dt.double],
output_type=dt.Struct(['col1', 'col2'], [dt.double, dt.double]),
output_type=dt.Struct({'col1': dt.double, 'col2': dt.double}),
)
def add_one_struct_exact_once(v):
key = v.iloc[0]
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/trino/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _string(_, itype):
@to_sqla_type.register(TrinoDialect, dt.Struct)
def _struct(dialect, itype):
return ROW(
[(name, to_sqla_type(dialect, typ)) for name, typ in itype.pairs.items()]
[(name, to_sqla_type(dialect, typ)) for name, typ in itype.fields.items()]
)


Expand Down
78 changes: 57 additions & 21 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
from ibis.common.grounds import Concrete, Singleton
from ibis.common.validators import (
all_of,
frozendict_of,
instance_of,
isin,
map_to,
tuple_of,
validator,
)
from ibis.util import deprecated, warn_deprecated

dtype = Dispatcher('dtype')

Expand Down Expand Up @@ -642,18 +643,42 @@ def to_integer_type(self):
class Struct(DataType):
"""Structured values."""

names = tuple_of(instance_of(str))
types = tuple_of(datatype)
fields = frozendict_of(instance_of(str), datatype)

scalar = ir.StructScalar
column = ir.StructColumn

def __init__(self, names, types, **kwargs):
if len(names) != len(types):
raise IbisTypeError(
'Struct datatype names and types must have the same length'
@classmethod
def __create__(cls, names, types=None, nullable=True):
if types is None:
fields = names
else:
warn_deprecated(
"Struct(names, types)",
as_of="4.1",
removed_in="5.0",
instead=(
"construct a Struct type using a mapping of names to types instead: "
"Struct(dict(zip(names, types)))"
),
)
super().__init__(names=names, types=types, **kwargs)
if len(names) != len(types):
raise IbisTypeError(
'Struct datatype names and types must have the same length'
)
fields = dict(zip(names, types))

return super().__create__(fields=fields, nullable=nullable)

def __reduce__(self):
return (self.__class__, (self.fields, None, self.nullable))

def copy(self, fields=None, nullable=None):
if fields is None:
fields = self.fields
if nullable is None:
nullable = self.nullable
return type(self)(fields, nullable=nullable)

@classmethod
def from_tuples(
Expand All @@ -673,10 +698,14 @@ def from_tuples(
Struct
Struct data type instance
"""
names, types = zip(*pairs)
return cls(names, types, nullable=nullable)
return cls(dict(pairs), nullable=nullable)

@classmethod
@deprecated(
as_of="4.1",
removed_in="5.0",
instead="directly construct a Struct type instead",
)
def from_dict(
cls, pairs: Mapping[str, str | DataType], nullable: bool = True
) -> Struct:
Expand All @@ -694,26 +723,33 @@ def from_dict(
Struct
Struct data type instance
"""
names, types = pairs.keys(), pairs.values()
return cls(names, types, nullable=nullable)
return cls(pairs, nullable=nullable)

@property
@deprecated(
as_of="4.1",
removed_in="5.0",
instead="use struct_type.fields attribute instead",
)
def pairs(self) -> Mapping[str, DataType]:
"""Return a mapping from names to data type instances.
return self.fields

Returns
-------
Mapping[str, DataType]
Mapping of field name to data type
"""
return dict(zip(self.names, self.types))
@property
def names(self) -> tuple[str, ...]:
"""Return the names of the struct's fields."""
return tuple(self.fields.keys())

@property
def types(self) -> tuple[DataType, ...]:
"""Return the types of the struct's fields."""
return tuple(self.fields.values())

def __getitem__(self, key: str) -> DataType:
return self.pairs[key]
return self.fields[key]

def __repr__(self) -> str:
return '{}({}, nullable={})'.format(
self.name, list(self.pairs.items()), self.nullable
self.name, list(self.fields.items()), self.nullable
)

@property
Expand Down
9 changes: 6 additions & 3 deletions ibis/expr/datatypes/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ def infer(value: Any) -> dt.DataType:
raise InputTypeError(value)


# TODO(kszucs): support NamedTuples and dataclasses instead of OrderedDict
# which should trigger infer_map instead
@infer.register(collections.OrderedDict)
def infer_struct(value: Mapping[str, Any]) -> dt.Struct:
"""Infer the [`Struct`][ibis.expr.datatypes.Struct] type of `value`."""
if not value:
raise TypeError('Empty struct type not supported')
return dt.Struct(list(value.keys()), list(map(infer, value.values())))
fields = {name: infer(val) for name, val in value.items()}
return dt.Struct(fields)


@infer.register(collections.abc.Mapping)
Expand All @@ -51,7 +54,7 @@ def infer_map(value: Mapping[Any, Any]) -> dt.Map:
highest_precedence(map(infer, value.values())),
)
except IbisTypeError:
return dt.Struct.from_dict(toolz.valmap(infer, value, factory=type(value)))
return dt.Struct(toolz.valmap(infer, value, factory=type(value)))


@infer.register((list, tuple))
Expand Down Expand Up @@ -303,7 +306,7 @@ def normalize(typ, value):
return frozendict({k: normalize(typ.value_type, v) for k, v in value.items()})
elif typ.is_struct():
return frozendict(
{k: normalize(typ[k], v) for k, v in value.items() if k in typ.pairs}
{k: normalize(typ[k], v) for k, v in value.items() if k in typ.fields}
)
elif typ.is_geospatial():
if isinstance(value, (tuple, list)):
Expand Down
Loading

0 comments on commit ff34c7b

Please sign in to comment.