Skip to content

Commit

Permalink
feat(pyspark): implement Table.unnest
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jun 27, 2024
1 parent 9eb5f93 commit 268299b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
65 changes: 65 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,68 @@ def visit_HexDigest(self, op, *, arg, how):
return self.f.sha2(arg, int(how[-3:]))
else:
raise NotImplementedError(f"No available hashing function for {how}")

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
):
quoted = self.quoted

column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted)

opname = op.column.name
parent_schema = op.parent.schema
overlaps_with_parent = opname in parent_schema
computed_column = column_alias.as_(opname, quoted=quoted)

parent_alias = parent.alias_or_name

selcols = []

if overlaps_with_parent:
column_alias_or_name = column.alias_or_name
selcols.extend(
sg.column(col, table=parent_alias, quoted=quoted)
if col != column_alias_or_name
else computed_column
for col in parent_schema.names
)
else:
selcols.append(
sge.Column(
this=STAR, table=sg.to_identifier(parent_alias, quoted=quoted)
)
)
selcols.append(computed_column)

alias_columns = []

if offset is not None:
offset = sg.column(offset, quoted=quoted)
selcols.append(offset)
alias_columns.append(offset)

alias_columns.append(column_alias)

# four possible functions
#
# explode: unnest
# explode_outer: unnest preserving empties and nulls
# posexplode: unnest with index
# posexplode_outer: unnest with index preserving empties and nulls
funcname = (
("pos" if offset is not None else "")
+ "explode"
+ ("_outer" if keep_empty else "")
)

return (
sg.select(*selcols)
.from_(parent)
.lateral(
sge.Lateral(
this=self.f[funcname](column),
view=True,
alias=sge.TableAlias(columns=alias_columns),
)
)
)
12 changes: 6 additions & 6 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ def test_zip_unnest_lift(con):


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize(
Expand All @@ -1407,7 +1407,7 @@ def test_table_unnest(backend, colspec):


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
def test_table_unnest_with_offset(backend):
Expand All @@ -1423,17 +1423,17 @@ def test_table_unnest_with_offset(backend):
idx = iter(df.idx.values)
expected = (
df.assign(**{col: df[col].map(lambda v: v[next(idx)])})
.sort_values(["idx", "y"])
.sort_values(["idx", col])
.reset_index(drop=True)[["idx", col]]
)

expr = t.unnest(col, offset="idx")[["idx", col]].order_by("idx", "y")
expr = t.unnest(col, offset="idx")[["idx", col]].order_by("idx", col)
result = expr.execute()
tm.assert_frame_equal(result, expected)


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
def test_table_unnest_with_keep_empty(con):
Expand All @@ -1444,7 +1444,7 @@ def test_table_unnest_with_keep_empty(con):


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
Expand Down

0 comments on commit 268299b

Please sign in to comment.