Skip to content

Commit

Permalink
fix(pivot-wider): handle the case of empty id_cols (#9912)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored Aug 27, 2024
1 parent d7401e0 commit 4a4bc64
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
48 changes: 48 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,3 +2453,51 @@ def test_union_generates_predictable_aliases(con):
expr = ibis.union(sub1, sub2)
df = con.execute(expr)
assert len(df) == 2


@pytest.mark.parametrize("id_cols", [s.none(), [], s.c()])
def test_pivot_wider_empty_id_columns(con, backend, id_cols, monkeypatch):
    """Pivot wider with an empty ``id_cols`` selection.

    Each parametrized ``id_cols`` value selects no columns, so the pivot is
    expected to collapse the table into a single row holding one count per
    confusion-matrix outcome.
    """
    monkeypatch.setattr(ibis.options, "default_backend", con)

    frame = pd.DataFrame(
        {
            "id": range(10),
            "actual": [0, 1, 1, 0, 0, 1, 0, 0, 0, 1],
            "prediction": [1, 0, 0, 1, 0, 0, 0, 0, 0, 1],
        }
    )
    table = ibis.memtable(frame)

    # (actual, prediction) -> label; insertion order fixes the order of the
    # ``when`` branches in the generated case expression.
    outcome_labels = {
        (0, 0): "TN",
        (0, 1): "FP",
        (1, 0): "FN",
        (1, 1): "TP",
    }
    case_builder = ibis.case()
    for (actual_value, predicted_value), label in outcome_labels.items():
        case_builder = case_builder.when(
            (_["actual"] == actual_value) & (_["prediction"] == predicted_value),
            label,
        )
    labelled = table.mutate(outcome=case_builder.end())

    pivoted = labelled.pivot_wider(
        id_cols=id_cols,
        names_from="outcome",
        values_from="outcome",
        values_agg=_.count(),
        names_sort=True,
    )

    actual_frame = pivoted.to_pandas()
    expected_frame = pd.DataFrame({"FN": [3], "FP": [2], "TN": [4], "TP": [1]})
    backend.assert_frame_equal(actual_frame, expected_frame)


@pytest.mark.notyet(
    ["mysql", "risingwave", "impala", "mssql", "druid", "exasol", "oracle", "flink"],
    raises=com.OperationNotDefinedError,
    reason="backend doesn't support Arbitrary agg",
)
def test_simple_pivot_wider(con, backend, monkeypatch):
    """Pivot a two-row table without passing ``id_cols`` or ``values_agg``.

    The ``outcome`` values become the (sorted) column names and the
    ``counted`` values fill the single resulting row.
    """
    monkeypatch.setattr(ibis.options, "default_backend", con)

    source_frame = pd.DataFrame({"outcome": ["yes", "no"], "counted": [3, 4]})
    source = ibis.memtable(source_frame)

    pivoted = source.pivot_wider(
        names_from="outcome",
        values_from="counted",
        names_sort=True,
    )

    actual_frame = pivoted.to_pandas()
    expected_frame = pd.DataFrame({"no": [4], "yes": [3]})
    backend.assert_frame_equal(actual_frame, expected_frame)
29 changes: 28 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4075,6 +4075,27 @@ def pivot_wider(
│ … │ … │ … │ … │ … │ … │ … │ … │ … │
└───────┴─────────┴───────┴────────┴───────┴─────────┴───────┴───────┴───┘
You can do simple transpose-like operations using `pivot_wider`
>>> t = ibis.memtable(dict(outcome=["yes", "no"], counted=[3, 4]))
>>> t
┏━━━━━━━━━┳━━━━━━━━━┓
┃ outcome ┃ counted ┃
┡━━━━━━━━━╇━━━━━━━━━┩
│ string │ int64 │
├─────────┼─────────┤
│ yes │ 3 │
│ no │ 4 │
└─────────┴─────────┘
>>> t.pivot_wider(names_from="outcome", values_from="counted", names_sort=True)
┏━━━━━━━┳━━━━━━━┓
┃ no ┃ yes ┃
┡━━━━━━━╇━━━━━━━┩
│ int64 │ int64 │
├───────┼───────┤
│ 4 │ 3 │
└───────┴───────┘
Fill missing pivoted values using `values_fill`
>>> fish_encounters.pivot_wider(
Expand Down Expand Up @@ -4411,7 +4432,13 @@ def pivot_wider(
key = names_sep.join(filter(None, key_components))
aggs[key] = arg if values_fill is None else arg.coalesce(values_fill)

return self.group_by(id_cols).aggregate(**aggs)
grouping_keys = id_cols.expand(self)

# no id columns, so do an ungrouped aggregation
if not grouping_keys:
return self.aggregate(**aggs)

return self.group_by(*grouping_keys).aggregate(**aggs)

def relocate(
self,
Expand Down

0 comments on commit 4a4bc64

Please sign in to comment.