Skip to content

Commit

Permalink
feat(pyspark): implement Table.unnest
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jun 26, 2024
1 parent 319f6a9 commit 6a8ab90
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
55 changes: 55 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,3 +452,58 @@ def visit_HexDigest(self, op, *, arg, how):
return self.f.sha2(arg, int(how[-3:]))
else:
raise NotImplementedError(f"No available hashing function for {how}")

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
):
quoted = self.quoted

column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted)

opcol = op.column
opname = opcol.name

parent_alias_or_name = parent.alias_or_name
table = sg.to_identifier(parent_alias_or_name, quoted=quoted)

selcols = []

if opname in (parent_schema := op.parent.schema):
column_alias_or_name = column.alias_or_name
selcols.extend(
sg.column(col, table=parent_alias_or_name, quoted=quoted)
if col != column_alias_or_name
else sg.column(column_alias, quoted=quoted).as_(col, quoted=quoted)
for col in parent_schema.names
)
else:
selcols.append(sge.Column(this=STAR, table=table))
selcols.append(column_alias.as_(opname, quoted=quoted))

alias_columns = []

if offset is not None:
offset = sg.column(offset, quoted=quoted)
selcols.append(offset)
alias_columns.append(offset)

alias_columns.append(column_alias)

funcname = (
("pos" if offset is not None else "")
+ "explode"
+ ("_outer" if keep_empty else "")
)

res = (
sg.select(*selcols)
.from_(parent)
.lateral(
sge.Lateral(
this=self.f[funcname](column),
view=True,
alias=sge.TableAlias(columns=alias_columns),
)
)
)
return res
12 changes: 6 additions & 6 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1391,7 +1391,7 @@ def test_zip_unnest_lift(con):


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize(
Expand All @@ -1407,7 +1407,7 @@ def test_table_unnest(backend, colspec):


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
def test_table_unnest_with_offset(backend):
Expand All @@ -1423,17 +1423,17 @@ def test_table_unnest_with_offset(backend):
idx = iter(df.idx.values)
expected = (
df.assign(**{col: df[col].map(lambda v: v[next(idx)])})
.sort_values(["idx", "y"])
.sort_values(["idx", col])
.reset_index(drop=True)[["idx", col]]
)

expr = t.unnest(col, offset="idx")[["idx", col]].order_by("idx", "y")
expr = t.unnest(col, offset="idx")[["idx", col]].order_by("idx", col)
result = expr.execute()
tm.assert_frame_equal(result, expected)


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
def test_table_unnest_with_keep_empty(con):
Expand All @@ -1444,7 +1444,7 @@ def test_table_unnest_with_keep_empty(con):


@pytest.mark.notimpl(
["datafusion", "pandas", "polars", "dask", "flink", "pyspark"],
["datafusion", "pandas", "polars", "dask", "flink"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(
Expand Down

0 comments on commit 6a8ab90

Please sign in to comment.