From 62a1864245a0fdbb0b8db36316e1ac6d3645697a Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sat, 13 Jul 2024 10:07:34 -0700 Subject: [PATCH] feat(polars): support version 1.0 and later (#9516) --- conda/environment-arm64-flink.yml | 2 +- conda/environment-arm64.yml | 2 +- conda/environment.yml | 2 +- ibis/backends/polars/__init__.py | 15 +++++----- ibis/backends/polars/compiler.py | 20 +++++++++++-- ibis/backends/tests/errors.py | 4 +-- ibis/backends/tests/test_aggregation.py | 15 +++------- ibis/backends/tests/test_dot_sql.py | 7 ----- ibis/backends/tests/test_export.py | 3 ++ ibis/backends/tests/test_generic.py | 15 ++++++---- ibis/backends/tests/test_numeric.py | 7 ++++- ibis/backends/tests/test_temporal.py | 17 ++++------- ibis/expr/schema.py | 2 +- poetry.lock | 39 +++++++++++++------------ pyproject.toml | 2 +- requirements-dev.txt | 2 +- 16 files changed, 81 insertions(+), 73 deletions(-) diff --git a/conda/environment-arm64-flink.yml b/conda/environment-arm64-flink.yml index 477e215e0bf3..fdeb90aaa0f1 100644 --- a/conda/environment-arm64-flink.yml +++ b/conda/environment-arm64-flink.yml @@ -27,7 +27,7 @@ dependencies: - pins >=0.8.2 - poetry-core >=1.0.0 - poetry-dynamic-versioning >=0.18.0 - - polars >=0.20.17 + - polars >=1,<2 - psycopg2 >=2.8.4 - pyarrow =11.0.0 - pyarrow-tests diff --git a/conda/environment-arm64.yml b/conda/environment-arm64.yml index 9b733687aaf3..ff3eb085cee4 100644 --- a/conda/environment-arm64.yml +++ b/conda/environment-arm64.yml @@ -27,7 +27,7 @@ dependencies: - pins >=0.8.2 - poetry-core >=1.0.0 - poetry-dynamic-versioning >=0.18.0 - - polars >=0.20.17 + - polars >=1,<2 - psycopg2 >=2.8.4 - pyarrow >=10.0.1 - pyarrow-tests diff --git a/conda/environment.yml b/conda/environment.yml index ee14f8e7d5b4..4be800d63f56 100644 --- a/conda/environment.yml +++ b/conda/environment.yml @@ -28,7 +28,7 @@ dependencies: - pip - poetry-core >=1.0.0 - poetry-dynamic-versioning >=0.18.0 - - polars >=0.20.17 + - polars >=1,<2 - psycopg2 >=2.8.4 - pyarrow >=10.0.1 - pyarrow-hotfix >=0.4 diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index 9f2290ba8414..aca178ccb254 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -71,8 +71,8 @@ def version(self) -> str: def list_tables(self, like=None, database=None): return self._filter_with_like(list(self._tables.keys()), like) - def table(self, name: str, _schema: sch.Schema | None = None) -> ir.Table: - schema = PolarsSchema.to_ibis(self._tables[name].schema) + def table(self, name: str) -> ir.Table: + schema = sch.infer(self._tables[name]) return ops.DatabaseTable(name, schema, self).to_expr() @deprecated( @@ -198,7 +198,7 @@ def read_csv( table = pl.scan_csv(source_list, **kwargs) # triggers a schema computation to handle compressed csv inference # and raise a compute error - table.schema # noqa: B018 + table.collect_schema() except pl.exceptions.ComputeError: # handles compressed csvs table = pl.read_csv(source_list, **kwargs) @@ -463,7 +463,8 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema: return self._get_schema_using_query(sql) def _get_schema_using_query(self, query: str) -> sch.Schema: - return PolarsSchema.to_ibis(self._context.execute(query, eager=False).schema) + lazy_frame = self._context.execute(query, eager=False) + return sch.infer(lazy_frame) def _to_dataframe( self, @@ -477,10 +478,8 @@ def _to_dataframe( if limit == "default": limit = ibis.options.sql.default_limit if limit is not None: - df = lf.fetch(limit, streaming=streaming) - else: - df = lf.collect(streaming=streaming) - return df + lf = lf.limit(limit) + return lf.collect(streaming=streaming) def execute( self, diff --git a/ibis/backends/polars/compiler.py b/ibis/backends/polars/compiler.py index 4cdfd8714dfe..a05d28212f21 100644 --- a/ibis/backends/polars/compiler.py +++ b/ibis/backends/polars/compiler.py @@ -739,7 +739,6 @@ def struct_column(op, **kw): ops.Mean: "mean", ops.Median: "median", ops.Min: "min", - ops.Mode: "mode", ops.StandardDev: "std", ops.Sum: "sum", ops.Variance: "var", @@ -768,6 +767,23 @@ def reduction(op, **kw): ) +@translate.register(ops.Mode) +def execute_mode(op, **kw): + arg = translate(op.arg, **kw) + + predicate = arg.is_not_null() + if (where := op.where) is not None: + predicate &= translate(where, **kw) + + dtype = PolarsType.from_ibis(op.dtype) + # `mode` can return more than one value so the additional `get(0)` call is + # necessary to enforce aggregation behavior of a scalar value per group + # + # eventually we may want to support an Ibis API like `modes` that returns a + # list of all the modes per group. + return arg.filter(predicate).mode().get(0).cast(dtype) + + @translate.register(ops.Quantile) def execute_quantile(op, **kw): arg = translate(op.arg, **kw) @@ -1228,7 +1244,7 @@ def _arg_min_max(op, func, **kw): translate_key = translate(key, **kw) not_null_mask = translate_arg.is_not_null() & translate_key.is_not_null() - return translate_arg.filter(not_null_mask).gather( + return translate_arg.filter(not_null_mask).get( func(translate_key.filter(not_null_mask)) ) diff --git a/ibis/backends/tests/errors.py b/ibis/backends/tests/errors.py index b11cdf66a9c5..044b85dc981f 100644 --- a/ibis/backends/tests/errors.py +++ b/ibis/backends/tests/errors.py @@ -59,10 +59,10 @@ GoogleBadRequest = None try: - from polars import ComputeError as PolarsComputeError - from polars import PanicException as PolarsPanicException from polars.exceptions import ColumnNotFoundError as PolarsColumnNotFoundError + from polars.exceptions import ComputeError as PolarsComputeError from polars.exceptions import InvalidOperationError as PolarsInvalidOperationError + from polars.exceptions import PanicException as PolarsPanicException from polars.exceptions import SchemaError as PolarsSchemaError except ImportError: PolarsComputeError = PolarsPanicException = PolarsInvalidOperationError = ( diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 5b9627181514..7ed2819adaad 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1180,19 +1180,12 @@ def test_string_quantile(alltypes, func): reason="doesn't support median of dates", ) @pytest.mark.notimpl(["dask"], raises=(AssertionError, NotImplementedError, TypeError)) -@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError) @pytest.mark.notyet(["datafusion"], raises=Exception, reason="not supported upstream") -@pytest.mark.parametrize( - "func", - [ - param( - methodcaller("quantile", 0.5), - id="quantile", - ), - ], +@pytest.mark.notyet( + ["polars"], raises=PolarsInvalidOperationError, reason="not supported upstream" ) -def test_date_quantile(alltypes, func): - expr = func(alltypes.timestamp_col.date()) +def test_date_quantile(alltypes): + expr = alltypes.timestamp_col.date().quantile(0.5) result = expr.execute() assert result == date(2009, 12, 31) diff --git a/ibis/backends/tests/test_dot_sql.py b/ibis/backends/tests/test_dot_sql.py index 0b42e68c7b5a..1d59b6ec5aff 100644 --- a/ibis/backends/tests/test_dot_sql.py +++ b/ibis/backends/tests/test_dot_sql.py @@ -19,7 +19,6 @@ ExaQueryError, GoogleBadRequest, OracleDatabaseError, - PolarsComputeError, ) dot_sql_never = pytest.mark.never( @@ -119,11 +118,6 @@ def test_table_dot_sql(backend): @dot_sql_never -@pytest.mark.notyet( - ["polars"], - raises=PolarsComputeError, - reason="polars doesn't support aliased tables", -) @pytest.mark.notyet( ["bigquery"], raises=GoogleBadRequest, reason="requires a qualified name" ) @@ -287,7 +281,6 @@ def test_order_by_no_projection(backend): @dot_sql_never -@pytest.mark.notyet(["polars"], raises=PolarsComputeError) def test_dot_sql_limit(con): expr = con.sql('SELECT * FROM (SELECT \'abc\' "ts") "x"', dialect="duckdb").limit(1) result = expr.execute() diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 0960016eccf7..12c6aa2d7e49 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -371,6 +371,9 @@ def test_table_to_csv_writer_kwargs(delimiter, tmp_path, awards_players): ], ) def test_to_pyarrow_decimal(backend, dtype, pyarrow_dtype): + if backend.name() == "polars": + pytest.skip("polars crashes the interpreter") + result = ( backend.functional_alltypes.limit(1) .double_col.cast(dtype) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index ac0f18516ec8..c4670fd69a97 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -795,11 +795,6 @@ def test_table_info_large(con): raises=com.OperationNotDefinedError, reason="Mode and StandardDev is not supported", ) -@pytest.mark.notimpl( - ["polars"], - raises=PolarsSchemaError, - reason="cannot extend/append Float64 with Float32", -) @pytest.mark.notyet( ["druid"], raises=PyDruidProgrammingError, @@ -863,6 +858,11 @@ def test_table_info_large(con): condition=is_newer_than("pandas", "2.1.0"), reason="FutureWarning: concat empty or all-NA entries is deprecated", ), + pytest.mark.notyet( + ["polars"], + raises=PolarsSchemaError, + reason="type Float32 is incompatible with expected type Float64", + ), ], id="all_cols", ), @@ -894,6 +894,11 @@ def test_table_info_large(con): raises=OracleDatabaseError, reason="Mode is not supported and ORA-02000: missing AS keyword", ), + pytest.mark.notimpl( + ["polars"], + raises=PolarsSchemaError, + reason="type Float32 is incompatible with expected type Float64", + ), ], id="numeric_col", ), diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index b5699a089194..2019b08c406d 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -259,6 +259,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "datafusion": decimal.Decimal("1.1"), "oracle": decimal.Decimal("1.1"), "flink": decimal.Decimal("1.1"), + "polars": decimal.Decimal("1.1"), }, { "bigquery": "NUMERIC", @@ -303,6 +304,7 @@ def test_numeric_literal(con, backend, expr, expected_types): "datafusion": decimal.Decimal("1.1"), "oracle": decimal.Decimal("1.1"), "flink": decimal.Decimal("1.1"), + "polars": decimal.Decimal("1.1"), }, { "bigquery": "NUMERIC", @@ -371,6 +373,7 @@ def test_numeric_literal(con, backend, expr, expected_types): raises=Py4JJavaError, ), pytest.mark.notyet(["mssql"], raises=PyODBCProgrammingError), + pytest.mark.notyet(["polars"], raises=RuntimeError), ], id="decimal-big", ), @@ -435,6 +438,7 @@ def test_numeric_literal(con, backend, expr, expected_types): ), pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.broken(["polars"], reason="panic", raises=BaseException), ], id="decimal-infinity+", ), @@ -499,6 +503,7 @@ def test_numeric_literal(con, backend, expr, expected_types): ), pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.broken(["polars"], reason="panic", raises=BaseException), ], id="decimal-infinity-", ), @@ -566,12 +571,12 @@ def test_numeric_literal(con, backend, expr, expected_types): ), pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest), pytest.mark.notyet(["exasol"], raises=ExaQueryError), + pytest.mark.broken(["polars"], reason="panic", raises=BaseException), ], id="decimal-NaN", ), ], ) -@pytest.mark.notimpl(["polars"], raises=TypeError) def test_decimal_literal(con, backend, expr, expected_types, expected_result): backend_name = backend.name() result = con.execute(expr) diff --git a/ibis/backends/tests/test_temporal.py b/ibis/backends/tests/test_temporal.py index b3fb08dd914d..a1e5667dc7e6 100644 --- a/ibis/backends/tests/test_temporal.py +++ b/ibis/backends/tests/test_temporal.py @@ -30,7 +30,7 @@ MySQLOperationalError, MySQLProgrammingError, OracleDatabaseError, - PolarsComputeError, + PolarsInvalidOperationError, PolarsPanicException, PsycoPg2InternalError, Py4JJavaError, @@ -1447,7 +1447,7 @@ def test_integer_to_timestamp(backend, con, unit): pytest.mark.never( ["polars"], reason="datetime formatting style not supported", - raises=PolarsComputeError, + raises=PolarsInvalidOperationError, ), pytest.mark.never( ["duckdb"], @@ -1526,7 +1526,7 @@ def test_string_to_timestamp(alltypes, fmt): pytest.mark.never( ["polars"], reason="datetime formatting style not supported", - raises=PolarsComputeError, + raises=PolarsInvalidOperationError, ), pytest.mark.never( ["duckdb"], @@ -2073,7 +2073,7 @@ def test_integer_cast_to_timestamp_scalar(alltypes, df): ["flink"], raises=ArrowInvalid, ) -@pytest.mark.notyet(["polars"], raises=PolarsComputeError) +@pytest.mark.notyet(["polars"], raises=PolarsInvalidOperationError) def test_big_timestamp(con): # TODO: test with a timezone ts = "2419-10-11 10:10:25" @@ -2135,14 +2135,7 @@ def test_timestamp_date_comparison(backend, alltypes, df, left_fn, right_fn): raises=AssertionError, ) @pytest.mark.notimpl(["pyspark"], raises=pd.errors.OutOfBoundsDatetime) -@pytest.mark.notimpl( - ["polars"], - raises=PolarsPanicException, - reason=( - "called `Result::unwrap()` on an `Err` value: PyErr { type: , " - "value: OverflowError('int too big to convert'), traceback: None }" - ), -) +@pytest.mark.broken(["polars"], raises=AssertionError, reason="returns NaT") @pytest.mark.broken( ["flink"], reason="Casting from timestamp[s] to timestamp[ns] would result in out of bounds timestamp: 81953424000", diff --git a/ibis/expr/schema.py b/ibis/expr/schema.py index 472834c8a323..6f3553dd28f5 100644 --- a/ibis/expr/schema.py +++ b/ibis/expr/schema.py @@ -290,7 +290,7 @@ def infer_pyarrow_table(table): def infer_polars_dataframe(df): from ibis.formats.polars import PolarsSchema - return PolarsSchema.to_ibis(df.schema) + return PolarsSchema.to_ibis(df.collect_schema()) # lock the dispatchers to avoid adding new implementations diff --git a/poetry.lock b/poetry.lock index 814650806885..93294fa1b274 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4518,39 +4518,40 @@ poetry-core = ">=1.7.0,<3.0.0" [[package]] name = "polars" -version = "0.20.31" +version = "1.0.0" description = "Blazingly fast DataFrame library" optional = true python-versions = ">=3.8" files = [ - {file = "polars-0.20.31-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:86454ade5ed302bbf87f145cfcb1b14f7a5765a9440e448659e1f3dba6ac4e79"}, - {file = "polars-0.20.31-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:67f2fe842262b7e1b9371edad21b760f6734d28b74c78dda88dff1bf031b9499"}, - {file = "polars-0.20.31-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24b82441f93409e0e8abd6f427b029db102f02b8de328cee9a680f84b84e3736"}, - {file = "polars-0.20.31-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:87f43bce4d41abf8c8c5658d881e4b8378e5c61010a696bfea8b4106b908e916"}, - {file = "polars-0.20.31-cp38-abi3-win_amd64.whl", hash = "sha256:2d7567c9fd9d3b9aa93387ca9880d9e8f7acea3c0a0555c03d8c0c2f0715d43c"}, - {file = "polars-0.20.31.tar.gz", hash = "sha256:00f62dec6bf43a4e2a5db58b99bf0e79699fe761c80ae665868eaea5168f3bbb"}, + {file = "polars-1.0.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:cf454ee75a2346cd7f44fb536cc69af7a26d8a243ea58bda50f6c810742c76ad"}, + {file = "polars-1.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:8191d8b5cf68d5ebaf9efb497120ff6d7e607a57a116bcce43618d50a536fe1c"}, + {file = "polars-1.0.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5b58575fd7ddc12bc53adfde933da3b40c2841fdc5396fecbd85e80dfc9332e"}, + {file = "polars-1.0.0-cp38-abi3-manylinux_2_24_aarch64.whl", hash = "sha256:44475877179f261f4ce1a6cfa0fc955392798b9987c17fc2b1a4b294602ace8a"}, + {file = "polars-1.0.0-cp38-abi3-win_amd64.whl", hash = "sha256:bd483045c0629afced9e9ebc83b58550640022db5924d553a068a57621260a22"}, + {file = "polars-1.0.0.tar.gz", hash = "sha256:144a63d6d61dc5d675304673c4261ceccf4cfc75277431389d4afe9a5be0f70b"}, ] [package.extras] -adbc = ["adbc-driver-manager", "adbc-driver-sqlite"] -all = ["polars[adbc,async,cloudpickle,connectorx,deltalake,fastexcel,fsspec,gevent,iceberg,numpy,pandas,plot,pyarrow,pydantic,sqlalchemy,timezone,xlsx2csv,xlsxwriter]"] -async = ["nest-asyncio"] +adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"] +all = ["polars[async,cloudpickle,database,deltalake,excel,fsspec,graph,iceberg,numpy,pandas,plot,pyarrow,pydantic,style,timezone]"] +async = ["gevent"] +calamine = ["fastexcel (>=0.9)"] cloudpickle = ["cloudpickle"] connectorx = ["connectorx (>=0.3.2)"] +database = ["nest-asyncio", "polars[adbc,connectorx,sqlalchemy]"] deltalake = ["deltalake (>=0.15.0)"] -fastexcel = ["fastexcel (>=0.9)"] +excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"] fsspec = ["fsspec"] -gevent = ["gevent"] +graph = ["matplotlib"] iceberg = ["pyiceberg (>=0.5.0)"] -matplotlib = ["matplotlib"] -numpy = ["numpy (>=1.16.0)"] +numpy = ["numpy (>=1.16.0,<2.0.0)"] openpyxl = ["openpyxl (>=3.0.0)"] -pandas = ["pandas", "pyarrow (>=7.0.0)"] -plot = ["hvplot (>=0.9.1)"] +pandas = ["pandas", "polars[pyarrow]"] +plot = ["hvplot (>=0.9.1)", "polars[pandas]"] pyarrow = ["pyarrow (>=7.0.0)"] pydantic = ["pydantic"] -pyxlsb = ["pyxlsb (>=1.0)"] -sqlalchemy = ["pandas", "sqlalchemy"] +sqlalchemy = ["polars[pandas]", "sqlalchemy"] +style = ["great-tables (>=0.8.0)"] timezone = ["backports-zoneinfo", "tzdata"] xlsx2csv = ["xlsx2csv (>=0.8.0)"] xlsxwriter = ["xlsxwriter"] @@ -7693,4 +7694,4 @@ visualization = ["graphviz"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "c9c2d6ae775188d8776fd6d6ec563edf98377b56ddd59d934161d8712dca9d90" +content-hash = "860d081f536ce48d748d67ad828aac337b0f08fee7cf772d172db2b8f6fd37c2" diff --git a/pyproject.toml b/pyproject.toml index 865f03251916..7c9eed91bcee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ oracledb = { version = ">=1.3.1,<3", optional = true } packaging = { version = ">=21.3,<25", optional = true } pins = { version = ">=0.8.3,<1", extras = ["gcs"], optional = true } fsspec = { version = "<2024.6.2", optional = true } -polars = { version = ">=0.20.17,<1", optional = true } +polars = { version = ">=1,<2", optional = true } psycopg2 = { version = ">=2.8.4,<3", optional = true } pydata-google-auth = { version = ">=1.4.0,<2", optional = true } pydruid = { version = ">=0.6.7,<1", optional = true } diff --git a/requirements-dev.txt b/requirements-dev.txt index 12f3549416c9..8dcc0da9c755 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -170,7 +170,7 @@ poetry-core==1.9.0 ; python_version >= "3.10" and python_version < "4.0" poetry-dynamic-versioning==1.4.0 ; python_version >= "3.10" and python_version < "4.0" poetry-plugin-export==1.8.0 ; python_version >= "3.10" and python_version < "4.0" poetry==1.8.3 ; python_version >= "3.10" and python_version < "4.0" -polars==0.20.31 ; python_version >= "3.10" and python_version < "4.0" +polars==1.0.0 ; python_version >= "3.10" and python_version < "4.0" pprintpp==0.4.0 ; python_version >= "3.10" and python_version < "4.0" pre-commit==3.7.1 ; python_version >= "3.10" and python_version < "4.0" prometheus-client==0.20.0 ; python_version >= "3.10" and python_version < "3.13"