From 0b96b6829ef4634be8200a82f262a7938d2affdd Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Fri, 27 Oct 2023 22:48:23 +0200 Subject: [PATCH] feat(datafusion): add some array functions --- ibis/backends/datafusion/__init__.py | 3 +- ibis/backends/datafusion/compiler/values.py | 23 +++++++++++ ibis/backends/datafusion/tests/conftest.py | 4 +- ibis/backends/tests/test_array.py | 46 ++++++++++++++------- 4 files changed, 60 insertions(+), 16 deletions(-) diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 67c36b7b64b3..39549adfc50f 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -41,7 +41,7 @@ import pandas as pd -_exclude_exp = (exp.Pow,) +_exclude_exp = (exp.Pow, exp.ArrayContains) # the DataFusion dialect was created to skip the power function to operator transformation @@ -66,6 +66,7 @@ class Backend(BaseBackend, CanCreateDatabase, CanCreateSchema): dialect = "datafusion" builder = None supports_in_memory_tables = True + supports_arrays = True @property def version(self): diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py index a37ca10c8b63..32a76c4abcb1 100644 --- a/ibis/backends/datafusion/compiler/values.py +++ b/ibis/backends/datafusion/compiler/values.py @@ -105,6 +105,9 @@ def translate_val(op, **_): ops.Degrees: "degrees", ops.NullIf: "nullif", ops.Pi: "pi", + ops.ArrayContains: "array_contains", + ops.ArrayLength: "array_length", + ops.ArrayRemove: "array_remove_all", } for _op, _name in _simple_ops.items(): @@ -710,3 +713,23 @@ def _if_else(op, *, bool_expr, true_expr, false_null_expr, **_): @translate_val.register(ops.NotNull) def _not_null(op, *, arg, **_): return sg.not_(arg.is_(NULL)) + + +@translate_val.register(ops.ArrayColumn) +def array_column(op, *, cols, **_): + return F.make_array(*cols) + + +@translate_val.register(ops.ArrayRepeat) +def array_repeat(op, *, arg, times, **_): + return F.flatten(F.array_repeat(arg, times)) + + +@translate_val.register(ops.ArrayConcat) +def array_concat(op, *, arg, **_): + return F.array_concat(*arg) + + +@translate_val.register(ops.ArrayPosition) +def array_position(op, *, arg, other, **_): + return F.coalesce(F.array_position(arg, other), 0) diff --git a/ibis/backends/datafusion/tests/conftest.py b/ibis/backends/datafusion/tests/conftest.py index a7a5a7a899b5..a6656b9b81b7 100644 --- a/ibis/backends/datafusion/tests/conftest.py +++ b/ibis/backends/datafusion/tests/conftest.py @@ -7,6 +7,7 @@ import ibis from ibis.backends.conftest import TEST_TABLES from ibis.backends.tests.base import BackendTest, RoundAwayFromZero +from ibis.backends.tests.data import array_types class TestConf(BackendTest, RoundAwayFromZero): @@ -15,7 +16,7 @@ class TestConf(BackendTest, RoundAwayFromZero): # returned_timestamp_unit = 'ns' supports_structs = False supports_json = False - supports_arrays = False + supports_arrays = True stateful = False deps = ("datafusion",) @@ -24,6 +25,7 @@ def _load_data(self, **_: Any) -> None: for table_name in TEST_TABLES: path = self.data_dir / "parquet" / f"{table_name}.parquet" con.register(path, table_name=table_name) + con.register(array_types, table_name="array_types") @staticmethod def connect(*, tmpdir, worker_id, **kw): diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 67ca7058feee..3d69c9ecc730 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -54,7 +54,6 @@ # list. -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_array_column(backend, alltypes, df): expr = ibis.array([alltypes["double_col"], alltypes["double_col"]]) assert isinstance(expr, ir.ArrayColumn) @@ -91,7 +90,7 @@ def test_array_scalar(con, backend): assert con.execute(expr.typeof()) == ARRAY_BACKEND_TYPES[backend_name] -@pytest.mark.notimpl(["polars", "datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) def test_array_repeat(con): expr = ibis.array([1.0, 2.0]) * 2 @@ -102,7 +101,6 @@ def test_array_repeat(con): # Issues #2370 -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_array_concat(con): left = ibis.literal([1, 2, 3]) right = ibis.literal([2, 1]) @@ -113,7 +111,6 @@ def test_array_concat(con): # Issues #2370 -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_array_concat_variadic(con): left = ibis.literal([1, 2, 3]) right = ibis.literal([2, 1]) @@ -124,7 +121,7 @@ def test_array_concat_variadic(con): # Issues #2370 -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["datafusion"], raises=BaseException) @pytest.mark.notyet( ["postgres", "trino"], raises=sa.exc.ProgrammingError, @@ -139,7 +136,6 @@ def test_array_concat_some_empty(con): assert np.array_equal(result, expected) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_array_radd_concat(con): left = [1] right = ibis.literal([2]) @@ -150,7 +146,6 @@ def test_array_radd_concat(con): assert np.array_equal(result, expected) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_array_length(con): expr = ibis.literal([1, 2, 3]).length() assert con.execute(expr.name("tmp")) == 3 @@ -193,13 +188,21 @@ def test_array_index(con, idx): ["sqlite"], reason="array types are unsupported", raises=NotImplementedError ), # someone just needs to implement these - pytest.mark.notimpl(["datafusion"], raises=Exception), ) @builtin_array @pytest.mark.never( - ["clickhouse", "duckdb", "pandas", "pyspark", "snowflake", "polars", "trino"], + [ + "clickhouse", + "duckdb", + "pandas", + "pyspark", + "snowflake", + "polars", + "trino", + "datafusion", + ], reason="backend does not flatten array types", raises=AssertionError, ) @@ -234,7 +237,16 @@ def test_array_discovery_postgres(backend): raises=AssertionError, ) @pytest.mark.never( - ["duckdb", "pandas", "postgres", "pyspark", "snowflake", "polars", "trino"], + [ + "duckdb", + "pandas", + "postgres", + "pyspark", + "snowflake", + "polars", + "trino", + "datafusion", + ], reason="backend supports nullable nested types", raises=AssertionError, ) @@ -334,6 +346,7 @@ def test_array_discovery_snowflake(backend): raises=BadRequest, ) @pytest.mark.notimpl(["dask"], raises=ValueError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_unnest_simple(backend): array_types = backend.array_types expected = ( @@ -350,6 +363,7 @@ def test_unnest_simple(backend): @builtin_array @pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_unnest_complex(backend): array_types = backend.array_types df = array_types.execute() @@ -389,6 +403,7 @@ def test_unnest_complex(backend): raises=AssertionError, ) @pytest.mark.notimpl(["dask"], raises=ValueError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_unnest_idempotent(backend): array_types = backend.array_types df = array_types.execute() @@ -409,6 +424,7 @@ def test_unnest_idempotent(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_unnest_no_nulls(backend): array_types = backend.array_types df = array_types.execute() @@ -435,6 +451,7 @@ def test_unnest_no_nulls(backend): @builtin_array @pytest.mark.notimpl("dask", raises=ValueError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_unnest_default_name(backend): array_types = backend.array_types df = array_types.execute() @@ -561,10 +578,9 @@ def test_array_filter(backend, con, input, output): @builtin_array @pytest.mark.notimpl( - ["datafusion", "mssql", "pandas", "polars", "postgres"], + ["mssql", "pandas", "polars", "postgres"], raises=com.OperationNotDefinedError, ) -@pytest.mark.notimpl(["datafusion"], raises=Exception) @pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError) @pytest.mark.never(["impala"], reason="array_types table isn't defined") def test_array_contains(backend, con): @@ -577,7 +593,7 @@ def test_array_contains(backend, con): @builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], + ["dask", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, ) def test_array_position(backend, con): @@ -590,7 +606,7 @@ def test_array_position(backend, con): @builtin_array @pytest.mark.notimpl( - ["dask", "datafusion", "impala", "mssql", "pandas", "polars"], + ["dask", "impala", "mssql", "pandas", "polars"], raises=com.OperationNotDefinedError, ) def test_array_remove(backend, con): @@ -708,6 +724,7 @@ def test_array_intersect(con): reason="ClickHouse won't accept dicts for struct type values", ) @pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) def test_unnest_struct(con): data = {"value": [[{"a": 1}, {"a": 2}], [{"a": 3}, {"a": 4}]]} t = ibis.memtable(data, schema=ibis.schema({"value": "!array>"})) @@ -754,6 +771,7 @@ def test_zip(backend): reason="https://github.com/ClickHouse/ClickHouse/issues/41112", ) @pytest.mark.notimpl(["postgres"], raises=sa.exc.ProgrammingError) +@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl( ["polars"], raises=com.OperationNotDefinedError,