feat(flink): array sort
cpcloud authored and jcrist committed Aug 9, 2024
1 parent eb857e6 commit ca85ae2
Showing 5 changed files with 39 additions and 17 deletions.
14 changes: 12 additions & 2 deletions ibis/backends/sql/compilers/flink.py
@@ -72,7 +72,6 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ArgMax,
         ops.ArgMin,
         ops.ArrayFlatten,
-        ops.ArraySort,
         ops.ArrayStringJoin,
         ops.Correlation,
         ops.CountDistinctStar,
@@ -102,6 +101,7 @@ class FlinkCompiler(SQLGlotCompiler):
         ops.ArrayLength: "cardinality",
         ops.ArrayPosition: "array_position",
         ops.ArrayRemove: "array_remove",
+        ops.ArraySort: "array_sort",
         ops.ArrayUnion: "array_union",
         ops.ExtractDayOfYear: "dayofyear",
         ops.MapKeys: "map_keys",
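Taken together with the dialect change below, moving `ops.ArraySort` out of `UNSUPPORTED_OPS` and into the simple-ops table means the compiler now lowers array sorts to a plain `ARRAY_SORT` call. A minimal sketch of how to inspect the generated SQL (the unbound table and column names are made up; assumes an ibis build that includes this commit):

```python
import ibis

# hypothetical unbound table, enough to compile SQL without a running cluster
t = ibis.table({"xs": "array<int64>"}, name="t")

# previously raised OperationNotDefinedError on the flink backend;
# now expected to render with ARRAY_SORT(...)
print(ibis.to_sql(t.xs.sort(), dialect="flink"))
```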
@@ -576,10 +576,20 @@ def visit_StructColumn(self, op, *, names, values):
         return self.cast(sge.Struct(expressions=list(values)), op.dtype)
 
     def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
+        if order_by:
+            raise com.UnsupportedOperationError(
+                "ordering of order-sensitive aggregations via `order_by` is "
+                "not supported for this backend"
+            )
+        # the only way to get filtering *and* respecting nulls is to use
+        # `FILTER` syntax, but it's broken in various ways for other aggregates
+        out = self.f.array_agg(arg)
         if not include_null:
             cond = arg.is_(sg.not_(NULL, copy=False))
             where = cond if where is None else sge.And(this=cond, expression=where)
-        return self.agg.array_agg(arg, where=where, order_by=order_by)
+        if where is not None:
+            out = sge.Filter(this=out, expression=sge.Where(this=where))
+        return out
 
 
 compiler = FlinkCompiler()
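The rewritten `visit_ArrayCollect` builds a bare `ARRAY_AGG(arg)` and, when a predicate is present, wraps it in `FILTER (WHERE ...)`, folding an `IS NOT NULL` condition into the predicate when `include_null=False`; `order_by` is now rejected up front with `UnsupportedOperationError`. A hedged sketch of the expected SQL shape (schema and names are invented):

```python
import ibis

t = ibis.table({"g": "string", "x": "int64"}, name="t")

# with include_null=False the compiler folds x IS NOT NULL into the filter,
# so the expected shape is roughly:
#   ARRAY_AGG(x) FILTER (WHERE x IS NOT NULL AND x > 0)
expr = t.group_by("g").aggregate(
    xs=t.x.collect(where=t.x > 0, include_null=False)
)
print(ibis.to_sql(expr, dialect="flink"))
```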
1 change: 1 addition & 0 deletions ibis/backends/sql/dialects.py
@@ -212,6 +212,7 @@ class Generator(Hive.Generator):
             sge.ArrayConcat: rename_func("array_concat"),
             sge.ArraySize: rename_func("cardinality"),
             sge.ArrayAgg: rename_func("array_agg"),
+            sge.ArraySort: rename_func("array_sort"),
             sge.Length: rename_func("char_length"),
             sge.TryCast: lambda self,
             e: f"TRY_CAST({e.this.sql(self.dialect)} AS {e.to.sql(self.dialect)})",
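The `rename_func` entry teaches ibis's custom sqlglot Flink dialect to spell the generic `sge.ArraySort` node as `ARRAY_SORT`. A minimal sketch using sqlglot's expression helpers directly (assumes the dialect keeps living at `ibis.backends.sql.dialects`):

```python
import sqlglot.expressions as sge

from ibis.backends.sql.dialects import Flink

# construct the AST node by hand and render it with the Flink generator;
# with the rename_func entry above this should print ARRAY_SORT(xs)
node = sge.ArraySort(this=sge.column("xs"))
print(node.sql(dialect=Flink))
```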
6 changes: 3 additions & 3 deletions ibis/backends/tests/test_aggregation.py
@@ -1480,13 +1480,13 @@ def test_collect_ordered(alltypes, df, filtered):
 def test_collect(alltypes, df, filtered, include_null):
     ibis_cond = (_.id % 13 == 0) if filtered else None
     pd_cond = (df.id % 13 == 0) if filtered else slice(None)
-    res = (
+    expr = (
         alltypes.string_col.nullif("3")
         .collect(where=ibis_cond, include_null=include_null)
         .length()
-        .execute()
     )
-    vals = df.string_col if include_null else df.string_col[(df.string_col != "3")]
+    res = expr.execute()
+    vals = df.string_col if include_null else df.string_col[df.string_col != "3"]
     sol = len(vals[pd_cond])
     assert res == sol
 
32 changes: 23 additions & 9 deletions ibis/backends/tests/test_array.py
@@ -316,13 +316,13 @@ def test_unnest_idempotent(backend):
             ["scalar_column", array_types.x.cast("!array<int64>").unnest().name("x")]
         )
         .group_by("scalar_column")
-        .aggregate(x=lambda t: t.x.collect())
+        .aggregate(x=lambda t: t.x.collect().sort())
         .order_by("scalar_column")
     )
     result = expr.execute().reset_index(drop=True)
     expected = (
         df[["scalar_column", "x"]]
-        .assign(x=df.x.map(lambda arr: [i for i in arr if not pd.isna(i)]))
+        .assign(x=df.x.map(lambda arr: sorted(i for i in arr if not pd.isna(i))))
         .sort_values("scalar_column")
         .reset_index(drop=True)
     )
@@ -718,20 +718,34 @@ def test_array_unique(con, input, expected):
 
 
 @builtin_array
-@pytest.mark.notimpl(
-    ["flink", "polars"],
-    raises=com.OperationNotDefinedError,
-)
+@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
 @pytest.mark.notyet(
     ["risingwave"],
     raises=AssertionError,
     reason="Refer to https://github.com/risingwavelabs/risingwave/issues/14735",
 )
-def test_array_sort(con):
-    t = ibis.memtable({"a": [[3, 2], [], [42, 42], []], "id": range(4)})
+@pytest.mark.parametrize(
+    "data",
+    (
+        param(
+            [[3, 2], [], [42, 42], []],
+            marks=[
+                pytest.mark.notyet(
+                    ["flink"],
+                    raises=Py4JJavaError,
+                    reason="flink cannot handle empty arrays",
+                )
+            ],
+        ),
+        [[3, 2], [42, 42]],
+    ),
+    ids=["empty", "nonempty"],
+)
+def test_array_sort(con, data):
+    t = ibis.memtable({"a": data, "id": range(len(data))})
     expr = t.mutate(a=t.a.sort()).order_by("id")
     result = con.execute(expr)
-    expected = pd.Series([[2, 3], [], [42, 42], []], dtype="object")
+    expected = pd.Series(list(map(sorted, data)), dtype="object")
 
     assert frozenset(map(tuple, result["a"].values)) == frozenset(
         map(tuple, expected.values)
3 changes: 0 additions & 3 deletions ibis/backends/tests/test_struct.py
@@ -116,9 +116,6 @@ def test_struct_column(alltypes, df):
 
 @pytest.mark.notimpl(["postgres", "risingwave", "polars"])
 @pytest.mark.notyet(["datafusion"], raises=Exception, reason="unsupported syntax")
-@pytest.mark.notyet(
-    ["flink"], reason="flink doesn't support creating struct columns from collect"
-)
 def test_collect_into_struct(alltypes):
     from ibis import _
 
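Dropping the `notyet` mark means collecting into a struct is now expected to compile on Flink, following the `visit_ArrayCollect` rewrite above. A hedged sketch of the pattern the test exercises (table and column names here are invented, not the test's actual fixtures):

```python
import ibis

t = ibis.table({"g": "string", "x": "int64"}, name="t")

# collect per-group values and wrap the resulting array in a struct column
expr = t.group_by("g").aggregate(s=ibis.struct({"xs": t.x.collect()}))
print(ibis.to_sql(expr, dialect="flink"))
```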
