diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py
index 6a7f9cf456ee..f064205d861b 100644
--- a/ibis/backends/sql/compilers/flink.py
+++ b/ibis/backends/sql/compilers/flink.py
@@ -72,7 +72,6 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ArgMax,
         ops.ArgMin,
         ops.ArrayFlatten,
-        ops.ArraySort,
         ops.ArrayStringJoin,
         ops.Correlation,
         ops.CountDistinctStar,
@@ -102,6 +101,7 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ArrayLength: "cardinality",
         ops.ArrayPosition: "array_position",
         ops.ArrayRemove: "array_remove",
+        ops.ArraySort: "array_sort",
         ops.ArrayUnion: "array_union",
         ops.ExtractDayOfYear: "dayofyear",
         ops.MapKeys: "map_keys",
@@ -576,10 +576,20 @@ def visit_StructColumn(self, op, *, names, values):
         return self.cast(sge.Struct(expressions=list(values)), op.dtype)
 
     def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
+        if order_by:
+            raise com.UnsupportedOperationError(
+                "ordering of order-sensitive aggregations via `order_by` is "
+                "not supported for this backend"
+            )
+        # the only way to get filtering *and* respecting nulls is to use
+        # `FILTER` syntax, but it's broken in various ways for other aggregates
+        out = self.f.array_agg(arg)
         if not include_null:
             cond = arg.is_(sg.not_(NULL, copy=False))
             where = cond if where is None else sge.And(this=cond, expression=where)
-        return self.agg.array_agg(arg, where=where, order_by=order_by)
+        if where is not None:
+            out = sge.Filter(this=out, expression=sge.Where(this=where))
+        return out
 
 
 compiler = FlinkCompiler()
diff --git a/ibis/backends/sql/dialects.py b/ibis/backends/sql/dialects.py
index e90c0a16787b..c7ad3c479ec4 100644
--- a/ibis/backends/sql/dialects.py
+++ b/ibis/backends/sql/dialects.py
@@ -212,6 +212,7 @@ class Generator(Hive.Generator):
             sge.ArrayConcat: rename_func("array_concat"),
             sge.ArraySize: rename_func("cardinality"),
             sge.ArrayAgg: rename_func("array_agg"),
+            sge.ArraySort: rename_func("array_sort"),
             sge.Length: rename_func("char_length"),
             sge.TryCast: lambda self,
             e: f"TRY_CAST({e.this.sql(self.dialect)} AS {e.to.sql(self.dialect)})",
diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py
index a139359e2480..679e3c155e62 100644
--- a/ibis/backends/tests/test_aggregation.py
+++ b/ibis/backends/tests/test_aggregation.py
@@ -1480,13 +1480,13 @@ def test_collect_ordered(alltypes, df, filtered):
 def test_collect(alltypes, df, filtered, include_null):
     ibis_cond = (_.id % 13 == 0) if filtered else None
     pd_cond = (df.id % 13 == 0) if filtered else slice(None)
-    res = (
+    expr = (
         alltypes.string_col.nullif("3")
         .collect(where=ibis_cond, include_null=include_null)
         .length()
-        .execute()
     )
-    vals = df.string_col if include_null else df.string_col[(df.string_col != "3")]
+    res = expr.execute()
+    vals = df.string_col if include_null else df.string_col[df.string_col != "3"]
     sol = len(vals[pd_cond])
     assert res == sol
 
diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py
index d4ca266629ca..7c373fd59f53 100644
--- a/ibis/backends/tests/test_array.py
+++ b/ibis/backends/tests/test_array.py
@@ -316,13 +316,13 @@ def test_unnest_idempotent(backend):
             ["scalar_column", array_types.x.cast("!array<int64>").unnest().name("x")]
         )
         .group_by("scalar_column")
-        .aggregate(x=lambda t: t.x.collect())
+        .aggregate(x=lambda t: t.x.collect().sort())
         .order_by("scalar_column")
     )
     result = expr.execute().reset_index(drop=True)
     expected = (
         df[["scalar_column", "x"]]
-        .assign(x=df.x.map(lambda arr: [i for i in arr if not pd.isna(i)]))
+        .assign(x=df.x.map(lambda arr: sorted(i for i in arr if not pd.isna(i))))
         .sort_values("scalar_column")
         .reset_index(drop=True)
     )
@@ -718,20 +718,34 @@ def test_array_unique(con, input, expected):
 
 
 @builtin_array
-@pytest.mark.notimpl(
-    ["flink", "polars"],
-    raises=com.OperationNotDefinedError,
-)
+@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
 @pytest.mark.notyet(
     ["risingwave"],
     raises=AssertionError,
     reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14735",
 )
-def test_array_sort(con):
-    t = ibis.memtable({"a": [[3, 2], [], [42, 42], []], "id": range(4)})
+@pytest.mark.parametrize(
+    "data",
+    (
+        param(
+            [[3, 2], [], [42, 42], []],
+            marks=[
+                pytest.mark.notyet(
+                    ["flink"],
+                    raises=Py4JJavaError,
+                    reason="flink cannot handle empty arrays",
+                )
+            ],
+        ),
+        [[3, 2], [42, 42]],
+    ),
+    ids=["empty", "nonempty"],
+)
+def test_array_sort(con, data):
+    t = ibis.memtable({"a": data, "id": range(len(data))})
     expr = t.mutate(a=t.a.sort()).order_by("id")
     result = con.execute(expr)
-    expected = pd.Series([[2, 3], [], [42, 42], []], dtype="object")
+    expected = pd.Series(list(map(sorted, data)), dtype="object")
 
     assert frozenset(map(tuple, result["a"].values)) == frozenset(
         map(tuple, expected.values)
diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py
index 5eb8b95d721f..3c37cb4234c8 100644
--- a/ibis/backends/tests/test_struct.py
+++ b/ibis/backends/tests/test_struct.py
@@ -116,9 +116,6 @@ def test_struct_column(alltypes, df):
 
 @pytest.mark.notimpl(["postgres", "risingwave", "polars"])
 @pytest.mark.notyet(["datafusion"], raises=Exception, reason="unsupported syntax")
-@pytest.mark.notyet(
-    ["flink"], reason="flink doesn't support creating struct columns from collect"
-)
 def test_collect_into_struct(alltypes):
     from ibis import _
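Note for reviewers: a minimal sketch of what the new compilation paths should emit, assuming a build of ibis that includes this patch. The table `t` and its column names are made up for illustration, and the SQL shown in the comments is illustrative only; exact quoting, aliasing, and clause order may differ.

```python
import ibis

t = ibis.table({"x": "int64", "a": "array<int64>"}, name="t")

# ArraySort now lowers to Flink's ARRAY_SORT via the rename_func entry
# added to the dialect's TRANSFORMS.
print(ibis.to_sql(t.select(sorted_a=t.a.sort()), dialect="flink"))
# e.g. SELECT ARRAY_SORT(`a`) AS `sorted_a` FROM `t`

# A filtered collect now compiles to ARRAY_AGG(...) FILTER (WHERE ...); with
# include_null=False the IS NOT NULL condition is folded into the filter.
expr = t.aggregate(xs=t.x.collect(where=t.x > 0, include_null=False))
print(ibis.to_sql(expr, dialect="flink"))
# e.g. SELECT ARRAY_AGG(`x`) FILTER(WHERE `x` IS NOT NULL AND `x` > 0) AS `xs` FROM `t`

# Per the new guard in visit_ArrayCollect, compiling a collect with order_by
# for Flink raises com.UnsupportedOperationError:
# ibis.to_sql(t.aggregate(xs=t.x.collect(order_by=t.x)), dialect="flink")
```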