Skip to content

Commit

Permalink
refactor(api): enforce at least one argument for Table set operations
Browse files Browse the repository at this point in the history
BREAKING CHANGE: `Table.difference()`, `Table.intersect()`, and `Table.union()` now require at least one argument.
  • Loading branch information
kszucs authored and cpcloud committed Jun 6, 2023
1 parent c87f695 commit 57e948f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 113 deletions.
12 changes: 9 additions & 3 deletions ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,17 @@ def test_difference(backend, alltypes, df, distinct):


@pytest.mark.parametrize("method", ["intersect", "difference", "union"])
@pytest.mark.parametrize("source", [ibis, ir.Table], ids=["top_level", "method"])
def test_empty_set_op(alltypes, method, source):
result = getattr(source, method)(alltypes)
def test_table_set_operations_api(alltypes, method):
# top level variadic
result = getattr(ibis, method)(alltypes)
assert result.equals(alltypes)

# table level methods require at least one argument
with pytest.raises(
TypeError, match="missing 1 required positional argument: 'table'"
):
getattr(ir.Table, method)(alltypes)


@pytest.mark.parametrize(
"distinct",
Expand Down
19 changes: 15 additions & 4 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,21 @@ def trailing_range_window(preceding, order_by, group_by=None):
)


# Top-level variadic form of `Table.union`. `functools.wraps` copies the
# method's name and docstring onto this function, so no separate docstring
# is written here.
@functools.wraps(ir.Table.union)
def union(table: ir.Table, *rest: ir.Table, distinct: bool = False):
    # With a single table there is nothing to combine: hand it back untouched
    # (`distinct` is not applied in that case).
    if not rest:
        return table
    return table.union(*rest, distinct=distinct)


# Top-level variadic form of `Table.intersect`. `functools.wraps` copies the
# method's name and docstring onto this function, so no separate docstring
# is written here.
@functools.wraps(ir.Table.intersect)
def intersect(table: ir.Table, *rest: ir.Table, distinct: bool = True):
    # With a single table there is nothing to intersect against: hand it back
    # untouched (`distinct` is not applied in that case).
    if not rest:
        return table
    return table.intersect(*rest, distinct=distinct)


# Top-level variadic form of `Table.difference`. `functools.wraps` copies the
# method's name and docstring onto this function, so no separate docstring
# is written here.
@functools.wraps(ir.Table.difference)
def difference(table: ir.Table, *rest: ir.Table, distinct: bool = True):
    # With a single table there is nothing to subtract: hand it back untouched
    # (`distinct` is not applied in that case).
    if not rest:
        return table
    return table.difference(*rest, distinct=distinct)


e = ops.E().to_expr()
pi = ops.Pi().to_expr()

Expand Down Expand Up @@ -1342,10 +1357,6 @@ def trailing_range_window(preceding, order_by, group_by=None):
join = ir.Table.join
asof_join = ir.Table.asof_join

union = ir.Table.union
intersect = ir.Table.intersect
difference = ir.Table.difference

_ = deferred = Deferred()
"""Deferred expression object.
Expand Down
134 changes: 28 additions & 106 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,15 +724,17 @@ def view(self) -> Table:
"""
return ops.SelfReference(self).to_expr()

def difference(self, *tables: Table, distinct: bool = True) -> Table:
def difference(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
"""Compute the set difference of multiple table expressions.
The input tables must have identical schemas.
Parameters
----------
tables
One or more table expressions
table
A table expression
*rest
Additional table expressions
distinct
Only diff distinct rows not occurring in the calling table
Expand Down Expand Up @@ -777,39 +779,11 @@ def difference(self, *tables: Table, distinct: bool = True) -> Table:
├───────┤
│ 1 │
└───────┘
Passing no arguments to `difference` returns the table expression
This can be useful when you have a sequence of tables to process, and
you don't know the length prior to running your program (for example, user input).
>>> t1
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.difference()
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.difference().equals(t1)
True
"""
t = functools.reduce(
functools.partial(ops.Difference, distinct=distinct), tables, self.op()
).to_expr()
if t.equals(self):
return t
return t.select(self.columns)
node = ops.Difference(self, table, distinct=distinct)
for table in rest:
node = ops.Difference(node, table, distinct=distinct)
return node.to_expr().select(self.columns)

def aggregate(
self,
Expand Down Expand Up @@ -1247,15 +1221,17 @@ def order_by(

return self.op().order_by(sort_keys).to_expr()

def union(self, *tables: Table, distinct: bool = False) -> Table:
def union(self, table: Table, *rest: Table, distinct: bool = False) -> Table:
"""Compute the set union of multiple table expressions.
The input tables must have identical schemas.
Parameters
----------
*tables
One or more table expressions
table
A table expression
*rest
Additional table expressions
distinct
Only return distinct rows
Expand Down Expand Up @@ -1313,49 +1289,23 @@ def union(self, *tables: Table, distinct: bool = False) -> Table:
│ 2 │
│ 3 │
└───────┘
Passing no arguments to `union` returns the table expression
This can be useful when you have a sequence of tables to process, and
you don't know the length prior to running your program (for example, user input).
>>> t1
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.union()
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.union().equals(t1)
True
"""
t = functools.reduce(
functools.partial(ops.Union, distinct=distinct), tables, self.op()
).to_expr()
if t.equals(self):
return t
return t.select(self.columns)

def intersect(self, *tables: Table, distinct: bool = True) -> Table:
node = ops.Union(self, table, distinct=distinct)
for table in rest:
node = ops.Union(node, table, distinct=distinct)
return node.to_expr().select(self.columns)

def intersect(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
"""Compute the set intersection of multiple table expressions.
The input tables must have identical schemas.
Parameters
----------
*tables
One or more table expressions
table
A table expression
*rest
Additional table expressions
distinct
Only return distinct rows
Expand Down Expand Up @@ -1400,39 +1350,11 @@ def intersect(self, *tables: Table, distinct: bool = True) -> Table:
├───────┤
│ 2 │
└───────┘
Passing no arguments to `intersect` returns the table expression.
This can be useful when you have a sequence of tables to process, and
you don't know the length prior to running your program (for example, user input).
>>> t1
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.intersect()
┏━━━━━━━┓
┃ a ┃
┡━━━━━━━┩
│ int64 │
├───────┤
│ 1 │
│ 2 │
└───────┘
>>> t1.intersect().equals(t1)
True
"""
t = functools.reduce(
functools.partial(ops.Intersection, distinct=distinct), tables, self.op()
).to_expr()
if t.equals(self):
return t
return t.select(self.columns)
node = ops.Intersection(self, table, distinct=distinct)
for table in rest:
node = ops.Intersection(node, table, distinct=distinct)
return node.to_expr().select(self.columns)

def to_array(self) -> ir.Column:
"""View a single column table as an array.
Expand Down

0 comments on commit 57e948f

Please sign in to comment.