Skip to content

Commit

Permalink
feat(polars): support version 1.0 and later (#9516)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Jul 13, 2024
1 parent 92eda02 commit 62a1864
Show file tree
Hide file tree
Showing 16 changed files with 81 additions and 73 deletions.
2 changes: 1 addition & 1 deletion conda/environment-arm64-flink.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies:
- pins >=0.8.2
- poetry-core >=1.0.0
- poetry-dynamic-versioning >=0.18.0
- polars >=0.20.17
- polars >=1,<2
- psycopg2 >=2.8.4
- pyarrow =11.0.0
- pyarrow-tests
Expand Down
2 changes: 1 addition & 1 deletion conda/environment-arm64.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ dependencies:
- pins >=0.8.2
- poetry-core >=1.0.0
- poetry-dynamic-versioning >=0.18.0
- polars >=0.20.17
- polars >=1,<2
- psycopg2 >=2.8.4
- pyarrow >=10.0.1
- pyarrow-tests
Expand Down
2 changes: 1 addition & 1 deletion conda/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:
- pip
- poetry-core >=1.0.0
- poetry-dynamic-versioning >=0.18.0
- polars >=0.20.17
- polars >=1,<2
- psycopg2 >=2.8.4
- pyarrow >=10.0.1
- pyarrow-hotfix >=0.4
Expand Down
15 changes: 7 additions & 8 deletions ibis/backends/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def version(self) -> str:
def list_tables(self, like=None, database=None):
return self._filter_with_like(list(self._tables.keys()), like)

def table(self, name: str, _schema: sch.Schema | None = None) -> ir.Table:
schema = PolarsSchema.to_ibis(self._tables[name].schema)
def table(self, name: str) -> ir.Table:
schema = sch.infer(self._tables[name])
return ops.DatabaseTable(name, schema, self).to_expr()

@deprecated(
Expand Down Expand Up @@ -198,7 +198,7 @@ def read_csv(
table = pl.scan_csv(source_list, **kwargs)
# triggers a schema computation to handle compressed csv inference
# and raise a compute error
table.schema # noqa: B018
table.collect_schema()
except pl.exceptions.ComputeError:
# handles compressed csvs
table = pl.read_csv(source_list, **kwargs)
Expand Down Expand Up @@ -463,7 +463,8 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
return self._get_schema_using_query(sql)

def _get_schema_using_query(self, query: str) -> sch.Schema:
return PolarsSchema.to_ibis(self._context.execute(query, eager=False).schema)
lazy_frame = self._context.execute(query, eager=False)
return sch.infer(lazy_frame)

def _to_dataframe(
self,
Expand All @@ -477,10 +478,8 @@ def _to_dataframe(
if limit == "default":
limit = ibis.options.sql.default_limit
if limit is not None:
df = lf.fetch(limit, streaming=streaming)
else:
df = lf.collect(streaming=streaming)
return df
lf = lf.limit(limit)
return lf.collect(streaming=streaming)

def execute(
self,
Expand Down
20 changes: 18 additions & 2 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,6 @@ def struct_column(op, **kw):
ops.Mean: "mean",
ops.Median: "median",
ops.Min: "min",
ops.Mode: "mode",
ops.StandardDev: "std",
ops.Sum: "sum",
ops.Variance: "var",
Expand Down Expand Up @@ -768,6 +767,23 @@ def reduction(op, **kw):
)


@translate.register(ops.Mode)
def execute_mode(op, **kw):
arg = translate(op.arg, **kw)

predicate = arg.is_not_null()
if (where := op.where) is not None:
predicate &= translate(where, **kw)

dtype = PolarsType.from_ibis(op.dtype)
# `mode` can return more than one value so the additional `get(0)` call is
# necessary to enforce aggregation behavior of a scalar value per group
#
# eventually we may want to support an Ibis API like `modes` that returns a
# list of all the modes per group.
return arg.filter(predicate).mode().get(0).cast(dtype)


@translate.register(ops.Quantile)
def execute_quantile(op, **kw):
arg = translate(op.arg, **kw)
Expand Down Expand Up @@ -1228,7 +1244,7 @@ def _arg_min_max(op, func, **kw):
translate_key = translate(key, **kw)

not_null_mask = translate_arg.is_not_null() & translate_key.is_not_null()
return translate_arg.filter(not_null_mask).gather(
return translate_arg.filter(not_null_mask).get(
func(translate_key.filter(not_null_mask))
)

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@
GoogleBadRequest = None

try:
from polars import ComputeError as PolarsComputeError
from polars import PanicException as PolarsPanicException
from polars.exceptions import ColumnNotFoundError as PolarsColumnNotFoundError
from polars.exceptions import ComputeError as PolarsComputeError
from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError
from polars.exceptions import PanicException as PolarsPanicException
from polars.exceptions import SchemaError as PolarsSchemaError
except ImportError:
PolarsComputeError = PolarsPanicException = PolarsInvalidOperationError = (
Expand Down
15 changes: 4 additions & 11 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,19 +1180,12 @@ def test_string_quantile(alltypes, func):
reason="doesn't support median of dates",
)
@pytest.mark.notimpl(["dask"], raises=(AssertionError, NotImplementedError, TypeError))
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
@pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream")
@pytest.mark.parametrize(
"func",
[
param(
methodcaller("quantile", 0.5),
id="quantile",
),
],
@pytest.mark.notyet(
["polars"], raises=PolarsInvalidOperationError, reason="not supported upstream"
)
def test_date_quantile(alltypes, func):
expr = func(alltypes.timestamp_col.date())
def test_date_quantile(alltypes):
expr = alltypes.timestamp_col.date().quantile(0.5)
result = expr.execute()
assert result == date(2009, 12, 31)

Expand Down
7 changes: 0 additions & 7 deletions ibis/backends/tests/test_dot_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
ExaQueryError,
GoogleBadRequest,
OracleDatabaseError,
PolarsComputeError,
)

dot_sql_never = pytest.mark.never(
Expand Down Expand Up @@ -119,11 +118,6 @@ def test_table_dot_sql(backend):


@dot_sql_never
@pytest.mark.notyet(
["polars"],
raises=PolarsComputeError,
reason="polars doesn't support aliased tables",
)
@pytest.mark.notyet(
["bigquery"], raises=GoogleBadRequest, reason="requires a qualified name"
)
Expand Down Expand Up @@ -287,7 +281,6 @@ def test_order_by_no_projection(backend):


@dot_sql_never
@pytest.mark.notyet(["polars"], raises=PolarsComputeError)
def test_dot_sql_limit(con):
expr = con.sql('SELECT * FROM (SELECT \'abc\' "ts") "x"', dialect="duckdb").limit(1)
result = expr.execute()
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players):
],
)
def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype):
if backend.name() == "polars":
pytest.skip("polars crashes the interpreter")

result = (
backend.functional_alltypes.limit(1)
.double_col.cast(dtype)
Expand Down
15 changes: 10 additions & 5 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,11 +795,6 @@ def test_table_info_large(con):
raises=com.OperationNotDefinedError,
reason="Mode and StandardDev is not supported",
)
@pytest.mark.notimpl(
["polars"],
raises=PolarsSchemaError,
reason="cannot extend/append Float64 with Float32",
)
@pytest.mark.notyet(
["druid"],
raises=PyDruidProgrammingError,
Expand Down Expand Up @@ -863,6 +858,11 @@ def test_table_info_large(con):
condition=is_newer_than("pandas", "2.1.0"),
reason="FutureWarning: concat empty or all-NA entries is deprecated",
),
pytest.mark.notyet(
["polars"],
raises=PolarsSchemaError,
reason="type Float32 is incompatible with expected type Float64",
),
],
id="all_cols",
),
Expand Down Expand Up @@ -894,6 +894,11 @@ def test_table_info_large(con):
raises=OracleDatabaseError,
reason="Mode is not supported and ORA-02000: missing AS keyword",
),
pytest.mark.notimpl(
["polars"],
raises=PolarsSchemaError,
reason="type Float32 is incompatible with expected type Float64",
),
],
id="numeric_col",
),
Expand Down
7 changes: 6 additions & 1 deletion ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ def test_numeric_literal(con, backend, expr, expected_types):
"datafusion": decimal.Decimal("1.1"),
"oracle": decimal.Decimal("1.1"),
"flink": decimal.Decimal("1.1"),
"polars": decimal.Decimal("1.1"),
},
{
"bigquery": "NUMERIC",
Expand Down Expand Up @@ -303,6 +304,7 @@ def test_numeric_literal(con, backend, expr, expected_types):
"datafusion": decimal.Decimal("1.1"),
"oracle": decimal.Decimal("1.1"),
"flink": decimal.Decimal("1.1"),
"polars": decimal.Decimal("1.1"),
},
{
"bigquery": "NUMERIC",
Expand Down Expand Up @@ -371,6 +373,7 @@ def test_numeric_literal(con, backend, expr, expected_types):
raises=Py4JJavaError,
),
pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError),
pytest.mark.notyet(["polars"], raises=RuntimeError),
],
id="decimal-big",
),
Expand Down Expand Up @@ -435,6 +438,7 @@ def test_numeric_literal(con, backend, expr, expected_types):
),
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.notyet(["exasol"], raises=ExaQueryError),
pytest.mark.broken(["polars"], reason="panic", raises=BaseException),
],
id="decimal-infinity+",
),
Expand Down Expand Up @@ -499,6 +503,7 @@ def test_numeric_literal(con, backend, expr, expected_types):
),
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.notyet(["exasol"], raises=ExaQueryError),
pytest.mark.broken(["polars"], reason="panic", raises=BaseException),
],
id="decimal-infinity-",
),
Expand Down Expand Up @@ -566,12 +571,12 @@ def test_numeric_literal(con, backend, expr, expected_types):
),
pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
pytest.mark.notyet(["exasol"], raises=ExaQueryError),
pytest.mark.broken(["polars"], reason="panic", raises=BaseException),
],
id="decimal-NaN",
),
],
)
@pytest.mark.notimpl(["polars"], raises=TypeError)
def test_decimal_literal(con, backend, expr, expected_types, expected_result):
backend_name = backend.name()
result = con.execute(expr)
Expand Down
17 changes: 5 additions & 12 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
MySQLOperationalError,
MySQLProgrammingError,
OracleDatabaseError,
PolarsComputeError,
PolarsInvalidOperationError,
PolarsPanicException,
PsycoPg2InternalError,
Py4JJavaError,
Expand Down Expand Up @@ -1447,7 +1447,7 @@ def test_integer_to_timestamp(backend, con, unit):
pytest.mark.never(
["polars"],
reason="datetime formatting style not supported",
raises=PolarsComputeError,
raises=PolarsInvalidOperationError,
),
pytest.mark.never(
["duckdb"],
Expand Down Expand Up @@ -1526,7 +1526,7 @@ def test_string_to_timestamp(alltypes, fmt):
pytest.mark.never(
["polars"],
reason="datetime formatting style not supported",
raises=PolarsComputeError,
raises=PolarsInvalidOperationError,
),
pytest.mark.never(
["duckdb"],
Expand Down Expand Up @@ -2073,7 +2073,7 @@ def test_integer_cast_to_timestamp_scalar(alltypes, df):
["flink"],
raises=ArrowInvalid,
)
@pytest.mark.notyet(["polars"], raises=PolarsComputeError)
@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError)
def test_big_timestamp(con):
# TODO: test with a timezone
ts = "2419-10-11 10:10:25"
Expand Down Expand Up @@ -2135,14 +2135,7 @@ def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn):
raises=AssertionError,
)
@pytest.mark.notimpl(["pyspark"], raises=pd.errors.OutOfBoundsDatetime)
@pytest.mark.notimpl(
["polars"],
raises=PolarsPanicException,
reason=(
"called `Result::unwrap()` on an `Err` value: PyErr { type: <class 'OverflowError'>, "
"value: OverflowError('int too big to convert'), traceback: None }"
),
)
@pytest.mark.broken(["polars"], raises=AssertionError, reason="returns NaT")
@pytest.mark.broken(
["flink"],
reason="Casting from timestamp[s] to timestamp[ns] would result in out of bounds timestamp: 81953424000",
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def infer_pyarrow_table(table):
def infer_polars_dataframe(df):
from ibis.formats.polars import PolarsSchema

return PolarsSchema.to_ibis(df.schema)
return PolarsSchema.to_ibis(df.collect_schema())


# lock the dispatchers to avoid adding new implementations
Expand Down
Loading

0 comments on commit 62a1864

Please sign in to comment.