Skip to content

Commit

Permalink
feat(table): implement pivot_wider API
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Mar 27, 2023
1 parent 2f63ce1 commit 60e7731
Show file tree
Hide file tree
Showing 5 changed files with 494 additions and 3 deletions.
17 changes: 17 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,23 @@ def test_pivot_longer(backend):
assert not df.empty


@pytest.mark.notyet(["datafusion"], raises=com.OperationNotDefinedError)
def test_pivot_wider(backend):
    """Pivot per-(cut, color) mean carat into one column per cut.

    Checks that the result has a `color` column plus one column for each
    distinct `cut` value, and one row for each distinct `color`.
    """
    diamonds = backend.diamonds

    # Aggregate first so the pivot input has one row per (cut, color) pair.
    mean_carat = diamonds.group_by(["cut", "color"]).agg(carat=_.carat.mean())
    expr = mean_carat.pivot_wider(
        names_from="cut",
        values_from="carat",
        names_sort=True,
        values_agg="mean",
    )
    df = expr.execute()

    distinct_cuts = diamonds[["cut"]].distinct().cut.execute()
    assert set(df.columns) == {"color"} | set(distinct_cuts)

    color_count = diamonds.color.nunique().execute()
    assert len(df) == color_count


@pytest.mark.parametrize(
"on",
[
Expand Down
10 changes: 10 additions & 0 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,3 +846,13 @@ def finder(node):
)

return g.traverse(finder, nodes, filter=ops.Node)


def find_toplevel_aggs(nodes: Iterable[ops.Node]) -> Iterator[ops.Reduction]:
    """Find `ops.Reduction` nodes by traversing `nodes`.

    Parameters
    ----------
    nodes
        Operation nodes from which the traversal starts.

    Returns
    -------
    Iterator[ops.Reduction]
        The result of `g.traverse` driven by the finder below. The finder
        only ever reports nodes that are instances of `ops.Reduction`.
    """

    def finder(node):
        # NOTE(review): presumably `g.traverse` treats the first element as
        # "keep descending into this node" and yields the second element
        # when it is not None -- confirm against `g.traverse`.
        return (
            # Continue only through value expressions.
            isinstance(node, ops.Value),
            # Report the node itself when it is a reduction.
            node if isinstance(node, ops.Reduction) else None,
        )

    return g.traverse(finder, nodes, filter=ops.Node)
11 changes: 8 additions & 3 deletions ibis/expr/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]:
return [col for column in table.columns if self.predicate(col := table[column])]

def __and__(self, other: Selector) -> Predicate:
"""Compute the conjunction of two `Selectors`.
"""Compute the conjunction of two `Selector`s.
Parameters
----------
Expand All @@ -104,7 +104,7 @@ def __and__(self, other: Selector) -> Predicate:
return self.__class__(lambda col: self.predicate(col) and other.predicate(col))

def __or__(self, other: Selector) -> Predicate:
"""Compute the disjunction of two `Selectors`.
"""Compute the disjunction of two `Selector`s.
Parameters
----------
Expand All @@ -114,7 +114,7 @@ def __or__(self, other: Selector) -> Predicate:
return self.__class__(lambda col: self.predicate(col) or other.predicate(col))

def __invert__(self) -> Predicate:
    """Compute the logical negation of a `Selector`.

    Returns
    -------
    Predicate
        A new instance of this selector's class whose predicate is true
        for exactly the columns on which this selector's predicate is
        false.
    """
    return self.__class__(lambda col: not self.predicate(col))


Expand Down Expand Up @@ -457,3 +457,8 @@ def last() -> Predicate:
def all() -> Predicate:
    """Select all of a table's columns.

    Implemented as an unbounded slice (`[:]`) of the `r` helper.
    """
    unbounded = slice(None)
    return r[unbounded]


def _to_selector(obj: str | Selector) -> Selector:
    """Coerce `obj` into a `Selector`.

    A string is wrapped with `c`; any other object is returned unchanged.
    """
    if not isinstance(obj, str):
        return obj
    return c(obj)
Loading

0 comments on commit 60e7731

Please sign in to comment.