diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 63fbc5537660..1406c227e3ce 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -19,8 +19,8 @@ rewrite_join, ) from ibis.backends.polars.compiler import translate -from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars from ibis.backends.sql.dialects import Polars +from ibis.formats.polars import PolarsSchema from ibis.util import gen_name, normalize_filename if TYPE_CHECKING: @@ -70,7 +70,7 @@ def list_tables(self, like=None, database=None): return self._filter_with_like(list(self._tables.keys()), like) def table(self, name: str, _schema: sch.Schema | None = None) -> ir.Table: - schema = schema_from_polars(self._tables[name].schema) + schema = PolarsSchema.to_ibis(self._tables[name].schema) return ops.DatabaseTable(name, schema, self).to_expr() def register( @@ -342,10 +342,7 @@ def create_table( overwrite: bool = False, ) -> ir.Table: if schema is not None and obj is None: - obj = pl.LazyFrame( - [], - schema={name: dtype_to_polars(dtype) for name, dtype in schema.items()}, - ) + obj = pl.LazyFrame([], schema=PolarsSchema.from_ibis(schema)) if database is not None: raise com.IbisError( @@ -413,7 +410,7 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema: raise NotImplementedError("table.sql() not yet supported in polars") def _get_schema_using_query(self, query: str) -> sch.Schema: - return schema_from_polars(self._context.execute(query).schema) + return PolarsSchema.to_ibis(self._context.execute(query).schema) def execute( self, diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 22a88d84cb58..3dfac85712bc 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -16,8 +16,8 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.pandas.rewrites import PandasAsofJoin, PandasJoin, PandasRename -from ibis.backends.polars.datatypes import dtype_to_polars, schema_from_polars from ibis.expr.operations.udf import InputType +from ibis.formats.polars import PolarsSchema, PolarsType from ibis.util import gen_name @@ -68,13 +68,13 @@ def dummy_table(op, **kw): @translate.register(ops.InMemoryTable) def pandas_in_memory_table(op, **_): lf = pl.from_pandas(op.data.to_frame()).lazy() - schema = schema_from_polars(lf.schema) + schema = PolarsSchema.to_ibis(lf.schema) columns = [] for name, current_dtype in schema.items(): desired_dtype = op.schema[name] if current_dtype != desired_dtype: - typ = dtype_to_polars(desired_dtype) + typ = PolarsType.from_ibis(desired_dtype) columns.append(pl.col(name).cast(typ)) if columns: @@ -101,12 +101,12 @@ def literal(op, **_): if dtype.is_array(): value = pl.Series("", value) - typ = dtype_to_polars(dtype) + typ = PolarsType.from_ibis(dtype) val = pl.lit(value, dtype=typ) return val.implode() elif dtype.is_struct(): values = [ - pl.lit(v, dtype=dtype_to_polars(dtype[k])).alias(k) + pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k) for k, v in value.items() ] return pl.struct(values) @@ -117,7 +117,7 @@ def literal(op, **_): elif dtype.is_binary(): return pl.lit(value) else: - typ = dtype_to_polars(dtype) + typ = PolarsType.from_ibis(dtype) return pl.lit(op.value, dtype=typ) @@ -179,7 +179,7 @@ def _cast(op, strict=True, **kw): return arg.dt.truncate("1s") return arg - typ = dtype_to_polars(to) + typ = PolarsType.from_ibis(to) return arg.cast(typ, strict=strict) @@ -509,14 +509,14 @@ def in_values(op, **kw): @translate.register(ops.StringLength) def string_length(op, **kw): arg = translate(op.arg, **kw) - typ = dtype_to_polars(op.dtype) + typ = PolarsType.from_ibis(op.dtype) return arg.str.len_bytes().cast(typ) @translate.register(ops.Capitalize) def capitalize(op, **kw): arg = translate(op.arg, **kw) - typ = dtype_to_polars(op.dtype) + typ = PolarsType.from_ibis(op.dtype) first = arg.str.slice(0, 1).str.to_uppercase() rest = arg.str.slice(1, None).str.to_lowercase() return (first + rest).cast(typ) @@ -652,7 +652,7 @@ def str_right(op, **kw): @translate.register(ops.Round) def round(op, **kw): arg = translate(op.arg, **kw) - typ = dtype_to_polars(op.dtype) + typ = PolarsType.from_ibis(op.dtype) digits = _literal_value(op.digits) return arg.round(digits or 0).cast(typ) @@ -705,7 +705,7 @@ def repeat(op, **kw): @translate.register(ops.Sign) def sign(op, **kw): arg = translate(op.arg, **kw) - typ = dtype_to_polars(op.dtype) + typ = PolarsType.from_ibis(op.dtype) return arg.sign().cast(typ) @@ -765,7 +765,7 @@ def reduction(op, **kw): first, *rest = args method = operator.methodcaller(agg, *rest) return method(first.filter(reduce(operator.and_, predicates))).cast( - dtype_to_polars(op.dtype) + PolarsType.from_ibis(op.dtype) ) @@ -815,7 +815,7 @@ def count_star(op, **kw): result = pl.len() except AttributeError: result = pl.count() - return result.cast(dtype_to_polars(op.dtype)) + return result.cast(PolarsType.from_ibis(op.dtype)) @translate.register(ops.TimestampNow) @@ -1109,7 +1109,7 @@ def bitwise_binops(op, **kw): else: result = pl.map_batches([left, right], lambda cols: ufunc(cols[0], cols[1])) - return result.cast(dtype_to_polars(op.dtype)) + return result.cast(PolarsType.from_ibis(op.dtype)) @translate.register(ops.BitwiseNot) @@ -1149,7 +1149,7 @@ def binop(op, **kw): @translate.register(ops.ElementWiseVectorizedUDF) def elementwise_udf(op, **kw): func_args = [translate(arg, **kw) for arg in op.func_args] - return_type = dtype_to_polars(op.return_type) + return_type = PolarsType.from_ibis(op.return_type) return pl.map_batches( func_args, lambda args: op.func(*args), return_dtype=return_type @@ -1252,7 +1252,7 @@ def execute_count_distinct_star(op, **kw): # -> convert back to a polars series InputType.PYTHON: lambda func, dtype, args: pl.Series( map(func, *(arg.to_list() for arg in args)), - dtype=dtype_to_polars(dtype), + dtype=PolarsType.from_ibis(dtype), ), # Convert polars series into a pyarrow array # -> invoke the function on the pyarrow array @@ -1272,7 +1272,7 @@ def execute_scalar_udf(op, **kw): return pl.map_batches( exprs=[translate(arg, **kw) for arg in op.args], function=partial(_UDF_INVOKERS[input_type], op.__func__, dtype), - return_dtype=dtype_to_polars(dtype), + return_dtype=PolarsType.from_ibis(dtype), ) elif input_type == InputType.BUILTIN: first, *rest = map(translate, op.args) @@ -1307,7 +1307,9 @@ def split(args): arg = translate(op.arg, **kw) pattern = translate(op.pattern, **kw) return pl.map_batches( - exprs=(arg, pattern), function=split, return_dtype=dtype_to_polars(op.dtype) + exprs=(arg, pattern), + function=split, + return_dtype=PolarsType.from_ibis(op.dtype), ) @@ -1319,7 +1321,7 @@ def execute_integer_range(op, **kw): ) step = op.step.value - dtype = dtype_to_polars(op.dtype) + dtype = PolarsType.from_ibis(op.dtype) empty = pl.int_ranges(0, 0, dtype=dtype) if step == 0: diff --git a/ibis/backends/polars/datatypes.py b/ibis/backends/polars/datatypes.py deleted file mode 100644 index 0c4b0ee15374..000000000000 --- a/ibis/backends/polars/datatypes.py +++ /dev/null @@ -1,123 +0,0 @@ -from __future__ import annotations - -import functools - -import polars as pl - -import ibis.expr.datatypes as dt -import ibis.expr.schema as sch - -_to_polars_types = { - dt.Boolean: pl.Boolean, - dt.Null: pl.Null, - dt.String: pl.Utf8, - dt.Binary: pl.Binary, - dt.Date: pl.Date, - dt.Time: pl.Time, - dt.Int8: pl.Int8, - dt.Int16: pl.Int16, - dt.Int32: pl.Int32, - dt.Int64: pl.Int64, - dt.UInt8: pl.UInt8, - dt.UInt16: pl.UInt16, - dt.UInt32: pl.UInt32, - dt.UInt64: pl.UInt64, - dt.Float32: pl.Float32, - dt.Float64: pl.Float64, -} - -_from_polars_types = {v: k for k, v in _to_polars_types.items()} -_from_polars_types[pl.Categorical] = dt.String - -# `physical` and `lexical` were introduced in polars 0.20, but the constructor -# always accepted any string here -_from_polars_types[pl.Categorical(ordering="physical")] = dt.String -_from_polars_types[pl.Categorical(ordering="lexical")] = dt.String - - -@functools.singledispatch -def dtype_to_polars(dtype): - """Convert ibis dtype to the polars counterpart.""" - try: - return _to_polars_types[dtype.__class__] # else return pl.Object? - except KeyError: - raise NotImplementedError(f"Unsupported type: {dtype!r}") - - -@dtype_to_polars.register(dt.Decimal) -def from_ibis_decimal(dtype): - return pl.Decimal(precision=dtype.precision, scale=dtype.scale) - - -@dtype_to_polars.register(dt.Timestamp) -def from_ibis_timestamp(dtype): - return pl.Datetime("ns", dtype.timezone) - - -@dtype_to_polars.register(dt.Interval) -def from_ibis_interval(dtype): - if dtype.unit.short in {"us", "ns", "ms"}: - return pl.Duration(dtype.unit.short) - else: - raise ValueError(f"Unsupported polars duration unit: {dtype.unit}") - - -@dtype_to_polars.register(dt.Struct) -def from_ibis_struct(dtype): - fields = [ - pl.Field(name=name, dtype=dtype_to_polars(dtype)) - for name, dtype in dtype.fields.items() - ] - return pl.Struct(fields) - - -@dtype_to_polars.register(dt.Array) -def from_ibis_array(dtype): - return pl.List(dtype_to_polars(dtype.value_type)) - - -@functools.singledispatch -def dtype_from_polars(typ): - """Convert polars dtype to the ibis counterpart.""" - klass = _from_polars_types[typ] - return klass() - - -@dtype_from_polars.register(pl.Datetime) -def from_polars_datetime(typ): - try: - timezone = typ.time_zone - except AttributeError: # pragma: no cover - timezone = typ.tz # pragma: no cover - return dt.Timestamp(timezone=timezone) - - -@dtype_from_polars.register(pl.Duration) -def from_polars_duration(typ): - try: - time_unit = typ.time_unit - except AttributeError: # pragma: no cover - time_unit = typ.tu # pragma: no cover - return dt.Interval(unit=time_unit) - - -@dtype_from_polars.register(pl.List) -def from_polars_list(typ): - return dt.Array(dtype_from_polars(typ.inner)) - - -@dtype_from_polars.register(pl.Struct) -def from_polars_struct(typ): - return dt.Struct.from_tuples( - [(field.name, dtype_from_polars(field.dtype)) for field in typ.fields] - ) - - -@dtype_from_polars.register(pl.Decimal) -def from_polars_decimal(typ: pl.Decimal): - return dt.Decimal(precision=typ.precision, scale=typ.scale) - - -def schema_from_polars(schema: pl.Schema) -> sch.Schema: - fields = [(name, dtype_from_polars(typ)) for name, typ in schema.items()] - return sch.Schema.from_tuples(fields) diff --git a/ibis/formats/polars.py b/ibis/formats/polars.py new file mode 100644 index 000000000000..840b1bf588ba --- /dev/null +++ b/ibis/formats/polars.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import polars as pl + +import ibis.expr.datatypes as dt +from ibis.expr.schema import Schema +from ibis.formats import SchemaMapper, TypeMapper + +_to_polars_types = { + dt.Boolean: pl.Boolean, + dt.Null: pl.Null, + dt.String: pl.Utf8, + dt.Binary: pl.Binary, + dt.Date: pl.Date, + dt.Time: pl.Time, + dt.Int8: pl.Int8, + dt.Int16: pl.Int16, + dt.Int32: pl.Int32, + dt.Int64: pl.Int64, + dt.UInt8: pl.UInt8, + dt.UInt16: pl.UInt16, + dt.UInt32: pl.UInt32, + dt.UInt64: pl.UInt64, + dt.Float32: pl.Float32, + dt.Float64: pl.Float64, +} + +_from_polars_types = {v: k for k, v in _to_polars_types.items()} + + +class PolarsType(TypeMapper): + @classmethod + def to_ibis(cls, typ: pl.DataType, nullable=True) -> dt.DataType: + """Convert a polars type to an ibis type.""" + + base_type = typ.base_type() + if base_type is pl.Categorical: + return dt.String(nullable=nullable) + elif base_type is pl.Decimal: + return dt.Decimal( + precision=typ.precision, scale=typ.scale, nullable=nullable + ) + elif base_type is pl.Datetime: + try: + timezone = typ.time_zone + except AttributeError: # pragma: no cover + timezone = typ.tz # pragma: no cover + return dt.Timestamp(timezone=timezone, nullable=nullable) + elif base_type is pl.Duration: + try: + time_unit = typ.time_unit + except AttributeError: # pragma: no cover + time_unit = typ.tu # pragma: no cover + return dt.Interval(unit=time_unit, nullable=nullable) + elif base_type is pl.List: + return dt.Array(cls.to_ibis(typ.inner), nullable=nullable) + elif base_type is pl.Struct: + return dt.Struct.from_tuples( + [(field.name, cls.to_ibis(field.dtype)) for field in typ.fields], + nullable=nullable, + ) + else: + return _from_polars_types[base_type](nullable=nullable) + + @classmethod + def from_ibis(cls, dtype: dt.DataType) -> pl.DataType: + """Convert an ibis type to a polars type.""" + if dtype.is_decimal(): + return pl.Decimal( + precision=dtype.precision, + scale=9 if dtype.scale is None else dtype.scale, + ) + elif dtype.is_timestamp(): + return pl.Datetime("ns", dtype.timezone) + elif dtype.is_interval(): + if dtype.unit.short in {"us", "ns", "ms"}: + return pl.Duration(dtype.unit.short) + else: + raise ValueError(f"Unsupported polars duration unit: {dtype.unit}") + elif dtype.is_struct(): + fields = [ + pl.Field(name=name, dtype=cls.from_ibis(dtype)) + for name, dtype in dtype.fields.items() + ] + return pl.Struct(fields) + elif dtype.is_array(): + return pl.List(cls.from_ibis(dtype.value_type)) + else: + try: + return _to_polars_types[type(dtype)] + except KeyError: + raise NotImplementedError( + f"Converting {dtype} to polars is not supported yet" + ) + + +class PolarsSchema(SchemaMapper): + @classmethod + def from_ibis(cls, schema: Schema) -> dict[str, pl.DataType]: + """Convert a schema to a polars schema.""" + return {name: PolarsType.from_ibis(typ) for name, typ in schema.items()} + + @classmethod + def to_ibis(cls, schema: dict[str, pl.DataType]) -> Schema: + """Convert a polars schema to a schema.""" + return Schema.from_tuples( + [(name, PolarsType.to_ibis(typ)) for name, typ in schema.items()] + ) diff --git a/ibis/backends/polars/tests/test_datatypes.py b/ibis/formats/tests/test_polars.py similarity index 57% rename from ibis/backends/polars/tests/test_datatypes.py rename to ibis/formats/tests/test_polars.py index 61a1e09ecd01..43638232b3bb 100644 --- a/ibis/backends/polars/tests/test_datatypes.py +++ b/ibis/formats/tests/test_polars.py @@ -4,8 +4,9 @@ import pytest from pytest import param +import ibis import ibis.expr.datatypes as dt -from ibis.backends.polars.datatypes import dtype_from_polars, dtype_to_polars +from ibis.formats.polars import PolarsSchema, PolarsType @pytest.mark.parametrize( @@ -52,9 +53,44 @@ ], ) def test_to_from_ibis_type(ibis_dtype, polars_type): - assert dtype_to_polars(ibis_dtype) == polars_type - assert dtype_from_polars(polars_type) == ibis_dtype + assert PolarsType.from_ibis(ibis_dtype) == polars_type + assert PolarsType.to_ibis(polars_type) == ibis_dtype + assert PolarsType.to_ibis(polars_type, nullable=False) == ibis_dtype(nullable=False) + + +def test_decimal(): + assert PolarsType.to_ibis(pl.Decimal()) == dt.Decimal(precision=None, scale=0) + assert PolarsType.to_ibis(pl.Decimal(precision=6, scale=3)) == dt.Decimal( + precision=6, scale=3 + ) + assert PolarsType.from_ibis(dt.Decimal()) == pl.Decimal(precision=None, scale=9) + assert PolarsType.from_ibis(dt.Decimal(precision=6, scale=3)) == pl.Decimal( + precision=6, scale=3 + ) def test_categorical(): - assert dtype_from_polars(pl.Categorical()) == dt.string + assert PolarsType.to_ibis(pl.Categorical()) == dt.string + + +def test_interval_unsupported_unit(): + typ = dt.Interval(unit="s") + with pytest.raises(ValueError, match="Unsupported polars duration unit"): + PolarsType.from_ibis(typ) + + +def test_map_unsupported(): + typ = dt.Map(dt.String(), dt.Int64()) + with pytest.raises(NotImplementedError, match="to polars is not supported"): + PolarsType.from_ibis(typ) + + +def test_schema_to_and_from_ibis(): + polars_schema = {"x": pl.Int64, "y": pl.List(pl.Utf8)} + ibis_schema = ibis.schema({"x": "int64", "y": "array"}) + + s1 = PolarsSchema.to_ibis(polars_schema) + assert s1.equals(ibis_schema) + + s2 = PolarsSchema.from_ibis(ibis_schema) + assert s2 == polars_schema