Skip to content

Commit

Permalink
feat(api): allow single argument set operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed Mar 10, 2023
1 parent dc80512 commit bb0a6f0
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 27 deletions.
8 changes: 5 additions & 3 deletions ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import ibis
import ibis.common.exceptions as com
import ibis.expr.types as ir
from ibis import _


Expand Down Expand Up @@ -141,9 +142,10 @@ def test_difference(backend, alltypes, df, distinct):


@pytest.mark.parametrize("method", ["intersect", "difference", "union"])
def test_empty_set_op(alltypes, method):
with pytest.raises(com.IbisTypeError, match="requires a table or tables"):
getattr(alltypes, method)()
@pytest.mark.parametrize("source", [ibis, ir.Table], ids=["top_level", "method"])
def test_empty_set_op(alltypes, method, source):
result = getattr(source, method)(alltypes)
assert result.equals(alltypes)


@pytest.mark.parametrize(
Expand Down
123 changes: 99 additions & 24 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,10 @@ def difference(self, *tables: Table, distinct: bool = True) -> Table:
distinct
Only diff distinct rows not occurring in the calling table
See Also
--------
[`ibis.difference`][ibis.difference]
Returns
-------
Table
Expand Down Expand Up @@ -797,15 +801,36 @@ def difference(self, *tables: Table, distinct: bool = True) -> Table:
├───────┤
│ 1 │
└───────┘
Passing no arguments to `difference` returns the table expression unchanged.
This can be useful when you have a sequence of tables to process and
you don't know its length before running your program (for example, when the tables come from user input).
>>> t1
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.difference()
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.difference().equals(t1)
True
"""
left = self
if not tables:
raise com.IbisTypeError(
"difference requires a table or tables to compare against"
)
for right in tables:
left = ops.Difference(left, right, distinct=distinct)
return left.to_expr()
return functools.reduce(
functools.partial(ops.Difference, distinct=distinct), tables, self.op()
).to_expr()

def aggregate(
self,
Expand Down Expand Up @@ -1102,6 +1127,10 @@ def union(self, *tables: Table, distinct: bool = False) -> Table:
Table
A new table containing the union of all input tables.
See Also
--------
[`ibis.union`][ibis.union]
Examples
--------
>>> import ibis
Expand Down Expand Up @@ -1147,15 +1176,36 @@ def union(self, *tables: Table, distinct: bool = False) -> Table:
│ 2 │
│ 3 │
└───────┘
Passing no arguments to `union` returns the table expression unchanged.
This can be useful when you have a sequence of tables to process and
you don't know its length before running your program (for example, when the tables come from user input).
>>> t1
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.union()
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.union().equals(t1)
True
"""
left = self
if not tables:
raise com.IbisTypeError(
"union requires a table or tables to compare against"
)
for right in tables:
left = ops.Union(left, right, distinct=distinct)
return left.to_expr()
return functools.reduce(
functools.partial(ops.Union, distinct=distinct), tables, self.op()
).to_expr()

def intersect(self, *tables: Table, distinct: bool = True) -> Table:
"""Compute the set intersection of multiple table expressions.
Expand All @@ -1174,6 +1224,10 @@ def intersect(self, *tables: Table, distinct: bool = True) -> Table:
Table
A new table containing the intersection of all input tables.
See Also
--------
[`ibis.intersect`][ibis.intersect]
Examples
--------
>>> import ibis
Expand Down Expand Up @@ -1206,15 +1260,36 @@ def intersect(self, *tables: Table, distinct: bool = True) -> Table:
├───────┤
│ 2 │
└───────┘
Passing no arguments to `intersect` returns the table expression unchanged.
This can be useful when you have a sequence of tables to process and
you don't know its length before running your program (for example, when the tables come from user input).
>>> t1
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.intersect()
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.intersect().equals(t1)
True
"""
left = self
if not tables:
raise com.IbisTypeError(
"intersect requires a table or tables to compare against"
)
for right in tables:
left = ops.Intersection(left, right, distinct=distinct)
return left.to_expr()
return functools.reduce(
functools.partial(ops.Intersection, distinct=distinct), tables, self.op()
).to_expr()

def to_array(self) -> ir.Column:
"""View a single column table as an array.
Expand Down

0 comments on commit bb0a6f0

Please sign in to comment.