diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index d19fc668bc90..b0b59bf60559 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1003,6 +1003,23 @@ def test_pivot_longer(backend): assert not df.empty +@pytest.mark.notyet(["datafusion"], raises=com.OperationNotDefinedError) +def test_pivot_wider(backend): + diamonds = backend.diamonds + expr = ( + diamonds.group_by(["cut", "color"]) + .agg(carat=_.carat.mean()) + .pivot_wider( + names_from="cut", values_from="carat", names_sort=True, values_agg="mean" + ) + ) + df = expr.execute() + assert set(df.columns) == {"color"} | set( + diamonds[["cut"]].distinct().cut.execute() + ) + assert len(df) == diamonds.color.nunique().execute() + + @pytest.mark.parametrize( "on", [ diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index a631c822b03d..42f8f237654e 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -846,3 +846,13 @@ def finder(node): ) return g.traverse(finder, nodes, filter=ops.Node) + + +def find_toplevel_aggs(nodes: Iterable[ops.Node]) -> Iterator[ops.Table]: + def finder(node): + return ( + isinstance(node, ops.Value), + node if isinstance(node, ops.Reduction) else None, + ) + + return g.traverse(finder, nodes, filter=ops.Node) diff --git a/ibis/expr/selectors.py b/ibis/expr/selectors.py index 764cae3bf60f..2f6a38e745b8 100644 --- a/ibis/expr/selectors.py +++ b/ibis/expr/selectors.py @@ -94,7 +94,7 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]: return [col for column in table.columns if self.predicate(col := table[column])] def __and__(self, other: Selector) -> Predicate: - """Compute the conjunction of two `Selectors`. + """Compute the conjunction of two `Selector`s. Parameters ---------- @@ -104,7 +104,7 @@ def __and__(self, other: Selector) -> Predicate: return self.__class__(lambda col: self.predicate(col) and other.predicate(col)) def __or__(self, other: Selector) -> Predicate: - """Compute the disjunction of two `Selectors`. + """Compute the disjunction of two `Selector`s. Parameters ---------- @@ -114,7 +114,7 @@ def __or__(self, other: Selector) -> Predicate: return self.__class__(lambda col: self.predicate(col) or other.predicate(col)) def __invert__(self) -> Predicate: - """Compute the logical negation of two `Selectors`.""" + """Compute the logical negation of two `Selector`s.""" return self.__class__(lambda col: not self.predicate(col)) @@ -457,3 +457,8 @@ def last() -> Predicate: def all() -> Predicate: """Return every column from a table.""" return r[:] + + +def _to_selector(obj: str | Selector) -> Selector: + """Convert an object to a `Selector`.""" + return c(obj) if isinstance(obj, str) else obj diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index a4cf8611ef05..fe06e8d279f8 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -4,6 +4,7 @@ import contextlib import functools import itertools +import operator import re import warnings from keyword import iskeyword @@ -3090,6 +3091,452 @@ def pivot_longer( return self.select(~pivot_sel, **new_cols) + @util.experimental + def pivot_wider( + self, + *, + id_cols: s.Selector | None = None, + names_from: str | Iterable[str] | s.Selector = "name", + names_prefix: str = "", + names_sep: str = "_", + names_sort: bool = False, + names: Iterable[str] | None = None, + values_from: str | Iterable[str] | s.Selector = "value", + values_fill: int | float | str | ir.Scalar | None = None, + values_agg: str | Callable[[ir.Value], ir.Scalar] | Deferred = "arbitrary", + ): + """Pivot a table to a wider format. + + Parameters + ---------- + id_cols + A set of columns that uniquely identify each observation. + names_from + An argument describing which column or columns to use to get the + name of the output columns. + names_prefix + String added to the start of every column name. + names_sep + If `names_from` or `values_from` contains multiple columns, this + argument will be used to join their values together into a single + string to use as a column name. + names_sort + If [`True`][True] columns are sorted. If [`False`][False] column + names are ordered by appearance. + names + An explicit sequence of values to look for in columns matching + `names_from`. + + * When this value is `None`, the values will be computed from + `names_from`. + * When this value is not `None`, each element's length must match + the length of `names_from`. + + See examples below for more detail. + values_from + An argument describing which column or columns to get the cell + values from. + values_fill + A scalar value that specifies what each value should be filled with + when missing. + values_agg + A function applied to the value in each cell in the output. + + Returns + ------- + Table + Wider pivoted table + + Examples + -------- + >>> import ibis + >>> import ibis.expr.selectors as s + >>> from ibis import _ + >>> ibis.options.interactive = True + + Basic usage + + >>> fish_encounters = ibis.examples.fish_encounters.fetch() + >>> fish_encounters + ┏━━━━━━━┳━━━━━━━━━┳━━━━━━━┓ + ┃ fish ┃ station ┃ seen ┃ + ┡━━━━━━━╇━━━━━━━━━╇━━━━━━━┩ + │ int64 │ string │ int64 │ + ├───────┼─────────┼───────┤ + │ 4842 │ Release │ 1 │ + │ 4842 │ I80_1 │ 1 │ + │ 4842 │ Lisbon │ 1 │ + │ 4842 │ Rstr │ 1 │ + │ 4842 │ Base_TD │ 1 │ + │ 4842 │ BCE │ 1 │ + │ 4842 │ BCW │ 1 │ + │ 4842 │ BCE2 │ 1 │ + │ 4842 │ BCW2 │ 1 │ + │ 4842 │ MAE │ 1 │ + │ … │ … │ … │ + └───────┴─────────┴───────┘ + >>> fish_encounters.pivot_wider(names_from="station", values_from="seen") + ┏━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━┓ + ┃ fish ┃ Release ┃ I80_1 ┃ Lisbon ┃ Rstr ┃ Base_TD ┃ BCE ┃ BCW ┃ … ┃ + ┡━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━┩ + │ int64 │ int64 │ int64 │ int64 │ int64 │ int64 │ int64 │ int64 │ … │ + ├───────┼─────────┼───────┼────────┼───────┼─────────┼───────┼───────┼───┤ + │ 4842 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4843 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4844 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4845 │ 1 │ 1 │ 1 │ 1 │ 1 │ ∅ │ ∅ │ … │ + │ 4847 │ 1 │ 1 │ 1 │ ∅ │ ∅ │ ∅ │ ∅ │ … │ + │ 4848 │ 1 │ 1 │ 1 │ 1 │ ∅ │ ∅ │ ∅ │ … │ + │ 4849 │ 1 │ 1 │ ∅ │ ∅ │ ∅ │ ∅ │ ∅ │ … │ + │ 4850 │ 1 │ 1 │ ∅ │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4851 │ 1 │ 1 │ ∅ │ ∅ │ ∅ │ ∅ │ ∅ │ … │ + │ 4854 │ 1 │ 1 │ ∅ │ ∅ │ ∅ │ ∅ │ ∅ │ … │ + │ … │ … │ … │ … │ … │ … │ … │ … │ … │ + └───────┴─────────┴───────┴────────┴───────┴─────────┴───────┴───────┴───┘ + + Fill missing pivoted values using `values_fill` + + >>> fish_encounters.pivot_wider(names_from="station", values_from="seen", values_fill=0) + ┏━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━┓ + ┃ fish ┃ Release ┃ I80_1 ┃ Lisbon ┃ Rstr ┃ Base_TD ┃ BCE ┃ BCW ┃ … ┃ + ┡━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━┩ + │ int64 │ int64 │ int64 │ int64 │ int64 │ int64 │ int64 │ int64 │ … │ + ├───────┼─────────┼───────┼────────┼───────┼─────────┼───────┼───────┼───┤ + │ 4842 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4843 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4844 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4845 │ 1 │ 1 │ 1 │ 1 │ 1 │ 0 │ 0 │ … │ + │ 4847 │ 1 │ 1 │ 1 │ 0 │ 0 │ 0 │ 0 │ … │ + │ 4848 │ 1 │ 1 │ 1 │ 1 │ 0 │ 0 │ 0 │ … │ + │ 4849 │ 1 │ 1 │ 0 │ 0 │ 0 │ 0 │ 0 │ … │ + │ 4850 │ 1 │ 1 │ 0 │ 1 │ 1 │ 1 │ 1 │ … │ + │ 4851 │ 1 │ 1 │ 0 │ 0 │ 0 │ 0 │ 0 │ … │ + │ 4854 │ 1 │ 1 │ 0 │ 0 │ 0 │ 0 │ 0 │ … │ + │ … │ … │ … │ … │ … │ … │ … │ … │ … │ + └───────┴─────────┴───────┴────────┴───────┴─────────┴───────┴───────┴───┘ + + Compute multiple values columns + + >>> us_rent_income = ibis.examples.us_rent_income.fetch() + >>> us_rent_income + ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━┓ + ┃ geoid ┃ name ┃ variable ┃ estimate ┃ moe ┃ + ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━┩ + │ string │ string │ string │ int64 │ int64 │ + ├────────┼────────────┼──────────┼──────────┼───────┤ + │ 01 │ Alabama │ income │ 24476 │ 136 │ + │ 01 │ Alabama │ rent │ 747 │ 3 │ + │ 02 │ Alaska │ income │ 32940 │ 508 │ + │ 02 │ Alaska │ rent │ 1200 │ 13 │ + │ 04 │ Arizona │ income │ 27517 │ 148 │ + │ 04 │ Arizona │ rent │ 972 │ 4 │ + │ 05 │ Arkansas │ income │ 23789 │ 165 │ + │ 05 │ Arkansas │ rent │ 709 │ 5 │ + │ 06 │ California │ income │ 29454 │ 109 │ + │ 06 │ California │ rent │ 1358 │ 3 │ + │ … │ … │ … │ … │ … │ + └────────┴────────────┴──────────┴──────────┴───────┘ + >>> us_rent_income.pivot_wider(names_from="variable", values_from=["estimate", "moe"]) + ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━┓ + ┃ geoid ┃ name ┃ estimate_income ┃ moe_income ┃ … ┃ + ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━┩ + │ string │ string │ int64 │ int64 │ … │ + ├────────┼──────────────────────┼─────────────────┼────────────┼───┤ + │ 01 │ Alabama │ 24476 │ 136 │ … │ + │ 02 │ Alaska │ 32940 │ 508 │ … │ + │ 04 │ Arizona │ 27517 │ 148 │ … │ + │ 05 │ Arkansas │ 23789 │ 165 │ … │ + │ 06 │ California │ 29454 │ 109 │ … │ + │ 08 │ Colorado │ 32401 │ 109 │ … │ + │ 09 │ Connecticut │ 35326 │ 195 │ … │ + │ 10 │ Delaware │ 31560 │ 247 │ … │ + │ 11 │ District of Columbia │ 43198 │ 681 │ … │ + │ 12 │ Florida │ 25952 │ 70 │ … │ + │ … │ … │ … │ … │ … │ + └────────┴──────────────────────┴─────────────────┴────────────┴───┘ + + The column name separator can be changed using the `names_sep` parameter + + >>> us_rent_income.pivot_wider( + ... names_from="variable", + ... names_sep=".", + ... values_from=s.c("estimate", "moe"), + ... ) + ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━┓ + ┃ geoid ┃ name ┃ estimate.income ┃ moe.income ┃ … ┃ + ┡━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━┩ + │ string │ string │ int64 │ int64 │ … │ + ├────────┼──────────────────────┼─────────────────┼────────────┼───┤ + │ 01 │ Alabama │ 24476 │ 136 │ … │ + │ 02 │ Alaska │ 32940 │ 508 │ … │ + │ 04 │ Arizona │ 27517 │ 148 │ … │ + │ 05 │ Arkansas │ 23789 │ 165 │ … │ + │ 06 │ California │ 29454 │ 109 │ … │ + │ 08 │ Colorado │ 32401 │ 109 │ … │ + │ 09 │ Connecticut │ 35326 │ 195 │ … │ + │ 10 │ Delaware │ 31560 │ 247 │ … │ + │ 11 │ District of Columbia │ 43198 │ 681 │ … │ + │ 12 │ Florida │ 25952 │ 70 │ … │ + │ … │ … │ … │ … │ … │ + └────────┴──────────────────────┴─────────────────┴────────────┴───┘ + + Supply an alternative function to summarize values + + >>> warpbreaks = ibis.examples.warpbreaks.fetch().select("wool", "tension", "breaks") + >>> warpbreaks + ┏━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓ + ┃ wool ┃ tension ┃ breaks ┃ + ┡━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩ + │ string │ string │ int64 │ + ├────────┼─────────┼────────┤ + │ A │ L │ 26 │ + │ A │ L │ 30 │ + │ A │ L │ 54 │ + │ A │ L │ 25 │ + │ A │ L │ 70 │ + │ A │ L │ 52 │ + │ A │ L │ 51 │ + │ A │ L │ 26 │ + │ A │ L │ 67 │ + │ A │ M │ 18 │ + │ … │ … │ … │ + └────────┴─────────┴────────┘ + >>> warpbreaks.pivot_wider(names_from="wool", values_from="breaks", values_agg="mean") + ┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓ + ┃ tension ┃ A ┃ B ┃ + ┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩ + │ string │ float64 │ float64 │ + ├─────────┼───────────┼───────────┤ + │ L │ 44.555556 │ 28.222222 │ + │ M │ 24.000000 │ 28.777778 │ + │ H │ 24.555556 │ 18.777778 │ + └─────────┴───────────┴───────────┘ + + Passing `Deferred` objects to `values_agg` is supported + + >>> warpbreaks.pivot_wider( + ... names_from="tension", + ... values_from="breaks", + ... values_agg=_.sum(), + ... ) + ┏━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━━━━━┓ + ┃ wool ┃ L ┃ M ┃ H ┃ + ┡━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━━━━━┩ + │ string │ int64 │ int64 │ int64 │ + ├────────┼───────┼───────┼───────┤ + │ A │ 401 │ 216 │ 221 │ + │ B │ 254 │ 259 │ 169 │ + └────────┴───────┴───────┴───────┘ + + Use a custom aggregate function + + >>> warpbreaks.pivot_wider( + ... names_from="wool", + ... values_from="breaks", + ... values_agg=lambda col: col.std() / col.mean(), + ... ) + ┏━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┓ + ┃ tension ┃ A ┃ B ┃ + ┡━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━┩ + │ string │ float64 │ float64 │ + ├─────────┼──────────┼──────────┤ + │ L │ 0.406183 │ 0.349325 │ + │ M │ 0.360844 │ 0.327719 │ + │ H │ 0.418344 │ 0.260590 │ + └─────────┴──────────┴──────────┘ + + Generate some random data, setting the random seed for reproducibility + + >>> import random + >>> random.seed(0) + >>> raw = ibis.memtable( + ... [ + ... dict( + ... product=product, + ... country=country, + ... year=year, + ... production=random.random(), + ... ) + ... for product in "AB" + ... for country in ["AI", "EI"] + ... for year in range(2000, 2015) + ... ] + ... ) + >>> production = raw.filter( + ... ((_.product == "A") & (_.country == "AI")) | (_.product == "B") + ... ) + >>> production + ┏━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┓ + ┃ product ┃ country ┃ year ┃ production ┃ + ┡━━━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━┩ + │ string │ string │ int64 │ float64 │ + ├─────────┼─────────┼───────┼────────────┤ + │ B │ AI │ 2000 │ 0.477010 │ + │ B │ AI │ 2001 │ 0.865310 │ + │ B │ AI │ 2002 │ 0.260492 │ + │ B │ AI │ 2003 │ 0.805028 │ + │ B │ AI │ 2004 │ 0.548699 │ + │ B │ AI │ 2005 │ 0.014042 │ + │ B │ AI │ 2006 │ 0.719705 │ + │ B │ AI │ 2007 │ 0.398824 │ + │ B │ AI │ 2008 │ 0.824845 │ + │ B │ AI │ 2009 │ 0.668153 │ + │ … │ … │ … │ … │ + └─────────┴─────────┴───────┴────────────┘ + + Pivoting with multiple name columns + + >>> production.pivot_wider( + ... names_from=["product", "country"], + ... values_from="production", + ... ) + ┏━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┓ + ┃ year ┃ B_AI ┃ B_EI ┃ A_AI ┃ + ┡━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━┩ + │ int64 │ float64 │ float64 │ float64 │ + ├───────┼──────────┼──────────┼──────────┤ + │ 2000 │ 0.477010 │ 0.870471 │ 0.844422 │ + │ 2001 │ 0.865310 │ 0.191067 │ 0.757954 │ + │ 2002 │ 0.260492 │ 0.567511 │ 0.420572 │ + │ 2003 │ 0.805028 │ 0.238616 │ 0.258917 │ + │ 2004 │ 0.548699 │ 0.967540 │ 0.511275 │ + │ 2005 │ 0.014042 │ 0.803179 │ 0.404934 │ + │ 2006 │ 0.719705 │ 0.447970 │ 0.783799 │ + │ 2007 │ 0.398824 │ 0.080446 │ 0.303313 │ + │ 2008 │ 0.824845 │ 0.320055 │ 0.476597 │ + │ 2009 │ 0.668153 │ 0.507941 │ 0.583382 │ + │ … │ … │ … │ … │ + └───────┴──────────┴──────────┴──────────┘ + + Select a subset of names. This call incurs no computation when + constructing the expression. + + >>> production.pivot_wider( + ... names_from=["product", "country"], + ... names=[("A", "AI"), ("B", "AI")], + ... values_from="production", + ... ) + ┏━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┓ + ┃ year ┃ A_AI ┃ B_AI ┃ + ┡━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━┩ + │ int64 │ float64 │ float64 │ + ├───────┼──────────┼──────────┤ + │ 2000 │ 0.844422 │ 0.477010 │ + │ 2001 │ 0.757954 │ 0.865310 │ + │ 2002 │ 0.420572 │ 0.260492 │ + │ 2003 │ 0.258917 │ 0.805028 │ + │ 2004 │ 0.511275 │ 0.548699 │ + │ 2005 │ 0.404934 │ 0.014042 │ + │ 2006 │ 0.783799 │ 0.719705 │ + │ 2007 │ 0.303313 │ 0.398824 │ + │ 2008 │ 0.476597 │ 0.824845 │ + │ 2009 │ 0.583382 │ 0.668153 │ + │ … │ … │ … │ + └───────┴──────────┴──────────┘ + + Sort the new columns' names + + >>> production.pivot_wider( + ... names_from=["product", "country"], + ... values_from="production", + ... names_sort=True, + ... ) + ┏━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━┓ + ┃ year ┃ A_AI ┃ B_AI ┃ B_EI ┃ + ┡━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━┩ + │ int64 │ float64 │ float64 │ float64 │ + ├───────┼──────────┼──────────┼──────────┤ + │ 2000 │ 0.844422 │ 0.477010 │ 0.870471 │ + │ 2001 │ 0.757954 │ 0.865310 │ 0.191067 │ + │ 2002 │ 0.420572 │ 0.260492 │ 0.567511 │ + │ 2003 │ 0.258917 │ 0.805028 │ 0.238616 │ + │ 2004 │ 0.511275 │ 0.548699 │ 0.967540 │ + │ 2005 │ 0.404934 │ 0.014042 │ 0.803179 │ + │ 2006 │ 0.783799 │ 0.719705 │ 0.447970 │ + │ 2007 │ 0.303313 │ 0.398824 │ 0.080446 │ + │ 2008 │ 0.476597 │ 0.824845 │ 0.320055 │ + │ 2009 │ 0.583382 │ 0.668153 │ 0.507941 │ + │ … │ … │ … │ … │ + └───────┴──────────┴──────────┴──────────┘ + """ + import pandas as pd + import ibis.expr.selectors as s + import ibis.expr.analysis as an + from ibis import _ + + orig_names_from = util.promote_list(names_from) + + names_from = s.any_of(*map(s._to_selector, orig_names_from)) + values_from = s.any_of(*map(s._to_selector, util.promote_list(values_from))) + + if id_cols is None: + id_cols = ~(names_from | values_from) + else: + id_cols = s._to_selector(id_cols) + + if isinstance(values_agg, str): + values_agg = operator.methodcaller(values_agg) + elif isinstance(values_agg, Deferred): + values_agg = values_agg.resolve + + if names is None: + # no names provided, compute them from the data + names = self.select(names_from).distinct().execute() + else: + if not (columns := [col.get_name() for col in names_from.expand(self)]): + raise com.IbisInputError( + f"No matching names columns in `names_from`: {orig_names_from}" + ) + names = pd.DataFrame(list(map(util.promote_list, names)), columns=columns) + + if names_sort: + names = names.sort_values(by=names.columns.tolist()) + + values_cols = values_from.expand(self) + more_than_one_value = len(values_cols) > 1 + aggs = {} + + names_cols_exprs = [self[col] for col in names.columns] + + for keys in names.itertuples(index=False): + where = ibis.and_(*map(operator.eq, names_cols_exprs, keys)) + + for values_col in values_cols: + arg = values_agg(values_col) + + # add in the where clause to filter the appropriate values + # in/out + # + # this allows users to write the aggregate without having to deal with + # the filter themselves + existing_aggs = an.find_toplevel_aggs(arg.op()) + subs = { + agg: agg.copy( + where=( + where + if (existing := agg.where) is None + else where & existing + ) + ) + for agg in existing_aggs + } + arg = an.sub_for(arg.op(), subs).to_expr() + + # build the components of the group by key + key_components = ( + # user provided prefix + names_prefix, + # include the `values` column name if there's more than one + # `values` column + values_col.get_name() * more_than_one_value, + # values computed from `names`/`names_from` + *keys, + ) + key = names_sep.join(filter(None, key_components)) + aggs[key] = arg if values_fill is None else arg.coalesce(values_fill) + + return self.group_by(id_cols).aggregate(**aggs) + @public class CachedTable(Table): diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 207592bc6574..a10d217b8dd4 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -1811,6 +1811,18 @@ def test_pivot_longer_no_match(): ) +def test_pivot_wider(): + fish = ibis.table({"fish": "int", "station": "string", "seen": "int"}, name="fish") + res = fish.pivot_wider( + names=["Release", "Lisbon"], names_from="station", values_from="seen" + ) + assert res.schema().names == ("fish", "Release", "Lisbon") + with pytest.raises( + com.IbisInputError, match="No matching names columns in `names_from`" + ): + fish.pivot_wider(names=["Release", "Lisbon"], values_from="seen") + + def test_invalid_deferred(): t = ibis.table(dict(value="int", lagged_value="int"), name="t")