Skip to content

Commit

Permalink
perf(api): rewrite union and intersection construction to support mor…
Browse files Browse the repository at this point in the history
…e operands (#9194)
  • Loading branch information
jitingxu1 authored May 31, 2024
1 parent 3506f40 commit 5d7aa55
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 103 deletions.
3 changes: 3 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ def test_table_info(alltypes):
),
[
"name",
"pos",
"type",
"count",
"nulls",
Expand Down Expand Up @@ -712,6 +713,7 @@ def test_table_info(alltypes):
s.of_type("numeric"),
[
"name",
"pos",
"type",
"count",
"nulls",
Expand Down Expand Up @@ -742,6 +744,7 @@ def test_table_info(alltypes):
s.of_type("string"),
[
"name",
"pos",
"type",
"count",
"nulls",
Expand Down
227 changes: 124 additions & 103 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import operator
import re
from collections import deque
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from keyword import iskeyword
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -32,7 +33,7 @@

import ibis.expr.types as ir
import ibis.selectors as s
from ibis.expr.operations.relations import JoinKind
from ibis.expr.operations.relations import JoinKind, Set
from ibis.expr.schema import SchemaLike
from ibis.expr.types import Table
from ibis.expr.types.groupby import GroupedTable
Expand Down Expand Up @@ -1005,67 +1006,6 @@ def view(self) -> Table:
else:
return ops.SelfReference(self).to_expr()

def difference(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
    """Compute the set difference of multiple table expressions.

    The input tables must have identical schemas.

    Parameters
    ----------
    table:
        A table expression
    *rest:
        Additional table expressions
    distinct
        Only diff distinct rows not occurring in the calling table

    See Also
    --------
    [`ibis.difference`](./expression-tables.qmd#ibis.difference)

    Returns
    -------
    Table
        The rows present in `self` that are not present in `tables`.

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t1 = ibis.memtable({"a": [1, 2]})
    >>> t1
    ┏━━━━━━━┓
    ┃ a     ┃
    ┡━━━━━━━┩
    │ int64 │
    ├───────┤
    │     1 │
    │     2 │
    └───────┘
    >>> t2 = ibis.memtable({"a": [2, 3]})
    >>> t2
    ┏━━━━━━━┓
    ┃ a     ┃
    ┡━━━━━━━┩
    │ int64 │
    ├───────┤
    │     2 │
    │     3 │
    └───────┘
    >>> t1.difference(t2)
    ┏━━━━━━━┓
    ┃ a     ┃
    ┡━━━━━━━┩
    │ int64 │
    ├───────┤
    │     1 │
    └───────┘
    """
    # Subtract the operands one after another, always keeping the running
    # result on the left-hand side: ((self - table) - rest[0]) - ...
    remaining = self
    for subtrahend in (table, *rest):
        remaining = ops.Difference(remaining, subtrahend, distinct=distinct)
    # Project back onto the caller's column order.
    return remaining.to_expr().select(self.columns)

def aggregate(
self,
metrics: Sequence[ir.Scalar] | None = (),
Expand Down Expand Up @@ -1680,6 +1620,31 @@ def order_by(
node = ops.Sort(self, keys.values())
return node.to_expr()

def _assemble_set_op(
    self, opcls: type[Set], table: Table, *rest: Table, distinct: bool
) -> Table:
    """Combine `self`, `table` and `rest` into one set-operation expression.

    Chaining the operands one after another nests the set operations
    deeply, and sqlglot's code generation can then exhaust the Python
    stack. To work around that, operands are paired off through a FIFO
    queue so the resulting operation tree is balanced.

    Parameters
    ----------
    opcls
        Set operation node class used to combine two operands; it is called
        as `opcls(left, right, distinct=distinct)`.
    table
        The first table to combine with `self`.
    *rest
        Additional tables to combine.
    distinct
        Passed unchanged to every `opcls` node that is built.

    Returns
    -------
    Table
        The combined expression, projected onto the columns of `self`.
    """
    pending = deque((self, table, *rest))

    # Take the two oldest entries, combine them, and put the combined node
    # at the back of the queue; this continues until a single root is left.
    while len(pending) > 1:
        lhs, rhs = pending.popleft(), pending.popleft()
        pending.append(opcls(lhs, rhs, distinct=distinct))

    root = pending.popleft()
    assert not pending, "items left in queue"
    return root.to_expr().select(*self.columns)

def union(self, table: Table, *rest: Table, distinct: bool = False) -> Table:
"""Compute the set union of multiple table expressions.
Expand Down Expand Up @@ -1749,10 +1714,7 @@ def union(self, table: Table, *rest: Table, distinct: bool = False) -> Table:
│ 3 │
└───────┘
"""
node = ops.Union(self, table, distinct=distinct)
for table in rest:
node = ops.Union(node, table, distinct=distinct)
return node.to_expr().select(self.columns)
return self._assemble_set_op(ops.Union, table, *rest, distinct=distinct)

def intersect(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
"""Compute the set intersection of multiple table expressions.
Expand Down Expand Up @@ -1810,9 +1772,67 @@ def intersect(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
│ 2 │
└───────┘
"""
node = ops.Intersection(self, table, distinct=distinct)
return self._assemble_set_op(ops.Intersection, table, *rest, distinct=distinct)

def difference(self, table: Table, *rest: Table, distinct: bool = True) -> Table:
    """Compute the set difference of multiple table expressions.

    The input tables must have identical schemas.

    Parameters
    ----------
    table:
        A table expression
    *rest:
        Additional table expressions
    distinct
        Only diff distinct rows not occurring in the calling table

    See Also
    --------
    [`ibis.difference`](./expression-tables.qmd#ibis.difference)

    Returns
    -------
    Table
        The rows present in `self` that are not present in `tables`.

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t1 = ibis.memtable({"a": [1, 2]})
    >>> t1
    ┏━━━━━━━┓
    ┃ a     ┃
    ┡━━━━━━━┩
    │ int64 │
    ├───────┤
    │     1 │
    │     2 │
    └───────┘
    >>> t2 = ibis.memtable({"a": [2, 3]})
    >>> t2
    ┏━━━━━━━┓
    ┃ a     ┃
    ┡━━━━━━━┩
    │ int64 │
    ├───────┤
    │     2 │
    │     3 │
    └───────┘
    >>> t1.difference(t2)
    ┏━━━━━━━┓
    ┃ a     ┃
    ┡━━━━━━━┩
    │ int64 │
    ├───────┤
    │     1 │
    └───────┘
    """
    # Set difference is not associative ((A - B) - C != A - (B - C)), so the
    # operands must be applied as a strict left-to-right chain rather than
    # being regrouped.
    node = ops.Difference(self, table, distinct=distinct)
    # Only `ops.Difference` may be applied here: intersecting with an operand
    # before subtracting that same operand would always give an empty result.
    # A separate loop name avoids shadowing the `table` parameter.
    for other in rest:
        node = ops.Difference(node, other, distinct=distinct)
    # Project back onto the caller's column order.
    return node.to_expr().select(self.columns)

@deprecated(as_of="9.0", instead="use table.as_scalar() instead")
Expand Down Expand Up @@ -2861,42 +2881,42 @@ def describe(
>>> ibis.options.interactive = True
>>> p = ibis.examples.penguins.fetch()
>>> p.describe()
┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━┓
┃ name ┃ type ┃ count ┃ nulls ┃ unique ┃ mode ┃ … ┃
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━┩
│ string │ string │ int64 │ int64 │ int64 │ string │ … │
├───────────────────┼─────────┼───────┼───────┼────────┼────────┼───┤
│ species │ string │ 344 │ 0 │ 3 │ Adelie │ … │
│ island │ string │ 344 │ 0 │ 3 │ Biscoe │ … │
│ bill_length_mm │ float64 │ 344 │ 2 │ 164 │ NULL │ … │
│ bill_depth_mm │ float64 │ 344 │ 2 │ 80 │ NULL │ … │
│ flipper_length_mm │ int64 │ 344 │ 2 │ 55 │ NULL │ … │
│ body_mass_g │ int64 │ 344 │ 2 │ 94 │ NULL │ … │
│ sex │ string │ 344 │ 11 │ 2 │ male │ … │
│ year │ int64 │ 344 │ 0 │ 3 │ NULL │ … │
└───────────────────┴─────────┴───────┴───────┴────────┴────────┴───┘
┏━━━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━┓
┃ name ┃ pos ┃ type ┃ count ┃ nulls ┃ unique ┃ mode ┃ … ┃
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━┩
│ string │ int8 │ string │ int64 │ int64 │ int64 │ string │ … │
├───────────────────┼──────┼─────────┼───────┼───────┼────────┼────────┼───┤
│ species │ 0 │ string │ 344 │ 0 │ 3 │ Adelie │ … │
│ island │ 1 │ string │ 344 │ 0 │ 3 │ Biscoe │ … │
│ bill_length_mm │ 2 │ float64 │ 344 │ 2 │ 164 │ NULL │ … │
│ bill_depth_mm │ 3 │ float64 │ 344 │ 2 │ 80 │ NULL │ … │
│ flipper_length_mm │ 4 │ int64 │ 344 │ 2 │ 55 │ NULL │ … │
│ body_mass_g │ 5 │ int64 │ 344 │ 2 │ 94 │ NULL │ … │
│ sex │ 6 │ string │ 344 │ 11 │ 2 │ male │ … │
│ year │ 7 │ int64 │ 344 │ 0 │ 3 │ NULL │ … │
└───────────────────┴──────┴─────────┴───────┴───────┴────────┴────────┴───┘
>>> p.select(s.of_type("numeric")).describe()
┏━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━┳━━━┓
┃ name ┃ type ┃ count ┃ nulls ┃ unique ┃ mean ┃ … ┃
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━╇━━━┩
│ string │ string │ int64 │ int64 │ int64 │ float64 │ … │
├───────────────────┼─────────┼───────┼───────┼────────┼─────────────┼───┤
│ bill_length_mm │ float64 │ 344 │ 2 │ 164 │ 43.921930 │ … │
│ bill_depth_mm │ float64 │ 344 │ 2 │ 80 │ 17.151170 │ … │
│ flipper_length_mm │ int64 │ 344 │ 2 │ 55 │ 200.915205 │ … │
│ body_mass_g │ int64 │ 344 │ 2 │ 94 │ 4201.754386 │ … │
│ year │ int64 │ 344 │ 0 │ 3 │ 2008.029070 │ … │
└───────────────────┴─────────┴───────┴───────┴────────┴─────────────┴───┘
┏━━━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━┓
┃ name ┃ pos ┃ type ┃ count ┃ nulls ┃ unique ┃ … ┃
┡━━━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━┩
│ string │ int8 │ string │ int64 │ int64 │ int64 │ … │
├───────────────────┼──────┼─────────┼───────┼───────┼────────┼───┤
│ flipper_length_mm │ 2 │ int64 │ 344 │ 2 │ 55 │ … │
│ body_mass_g │ 3 │ int64 │ 344 │ 2 │ 94 │ … │
│ year │ 4 │ int64 │ 344 │ 0 │ 3 │ … │
│ bill_length_mm │ 0 │ float64 │ 344 │ 2 │ 164 │ … │
│ bill_depth_mm │ 1 │ float64 │ 344 │ 2 │ 80 │ … │
└───────────────────┴──────┴─────────┴───────┴───────┴────────┴───┘
>>> p.select(s.of_type("string")).describe()
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┓
┃ name ┃ type ┃ count ┃ nulls ┃ unique ┃ mode ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━┩
│ string │ string │ int64 │ int64 │ int64 │ string │
├─────────┼────────┼───────┼───────┼────────┼────────┤
│ species │ string │ 344 │ 0 │ 3 │ Adelie │
│ island │ string │ 344 │ 0 │ 3 │ Biscoe │
│ sex │ string │ 344 │ 11 │ 2 │ male │
└─────────┴────────┴───────┴───────┴────────┴────────┘
┏━━━━━━━━━┳━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━┳━━━━━━━━┳━━━━━━━━┓
┃ name ┃ pos ┃ type ┃ count ┃ nulls ┃ unique ┃ mode ┃
┡━━━━━━━━━╇━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━╇━━━━━━━━╇━━━━━━━━┩
│ string │ int8 │ string │ int64 │ int64 │ int64 │ string │
├─────────┼──────┼────────┼───────┼───────┼────────┼────────┤
│ sex │ 2 │ string │ 344 │ 11 │ 2 │ male │
│ species │ 0 │ string │ 344 │ 0 │ 3 │ Adelie │
│ island │ 1 │ string │ 344 │ 0 │ 3 │ Biscoe │
└─────────┴──────┴────────┴───────┴───────┴────────┴────────┘
"""
import ibis.selectors as s
from ibis import literal as lit
Expand All @@ -2905,7 +2925,7 @@ def describe(
aggs = []
string_col = False
numeric_col = False
for colname in self.columns:
for pos, colname in enumerate(self.columns):
col = self[colname]
typ = col.type()

Expand Down Expand Up @@ -2942,6 +2962,7 @@ def describe(

agg = self.agg(
name=lit(colname),
pos=lit(pos),
type=lit(str(typ)),
count=col.isnull().count(),
nulls=col.isnull().sum(),
Expand Down
18 changes: 18 additions & 0 deletions ibis/tests/benchmarks/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,21 @@ def many_cols():
)
def test_column_access(benchmark, many_cols, getter):
benchmark(getter, many_cols)


@pytest.fixture(scope="module")
def many_tables():
    """Build 1000 table expressions that share one 10-column integer schema.

    The fixture is module-scoped, so the list is built once and reused by
    every benchmark in this module that requests it.
    """
    column_count = 10
    table_count = 1000

    tables = []
    for _ in range(table_count):
        # Columns are named c0..c9, each declared with the "int" type.
        schema = {f"c{idx}": "int" for idx in range(column_count)}
        tables.append(ibis.table(schema))
    return tables


def test_large_union_construct(benchmark, many_tables):
    """Benchmark building a union expression over all of `many_tables`."""

    def build_union(tables):
        return ibis.union(*tables)

    union_expr = benchmark(build_union, many_tables)
    assert union_expr is not None


def test_large_union_compile(benchmark, many_tables):
    """Benchmark compiling a union over all of `many_tables` to SQL."""
    # The expression is built outside the benchmark so that only the
    # `ibis.to_sql` call is timed.
    union_expr = ibis.union(*many_tables)
    compiled = benchmark(ibis.to_sql, union_expr)
    assert compiled is not None

0 comments on commit 5d7aa55

Please sign in to comment.