Skip to content

Commit

Permalink
refactor(api): revamp asof join predicates
Browse files Browse the repository at this point in the history
Previously the `ASOF` join API was imprecise.

The backends supporting `asof` joins require exactly one nearest match
(inequality) predicate along with arbitrary number of ordinary join
predicates, see [ClickHouse ASOF](https://clickhouse.com/docs/en/sql-reference/statements/select/join#asof-join-usage),
[DuckDB ASOF](https://duckdb.org/docs/guides/sql_features/asof_join.html#asof-joins-with-the-using-keyword) and
[Pandas ASOF](https://pandas.pydata.org/docs/reference/api/pandas.merge_asof.html).

This change alters the API to
`table.asof_join(left, right, on, predicates, ...)` where `on` is the
nearest match predicate defaulting to `left[on] <= right[on]` if not an
expression is given. I kept the `by` argument for compatibility reasons,
but we should phase that out in favor of `predicates`.

Also ensure that all the join methods or `ir.Join` have the exact same
docstrings as `ir.Table`.

BREAKING CHANGE: `on` paremater of `table.asof_join()` is now only
accept a single predicate, use `predicates` to supply additional
join predicates.
  • Loading branch information
kszucs committed Feb 12, 2024
1 parent e3e17db commit 9fb3627
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 81 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,9 @@ def test_join_with_external_table(alltypes, df):
def test_asof_join(time_left, time_right):
expr = time_left.asof_join(
time_right,
on=time_left["time"] >= time_right["time"],
predicates=[
time_left["key"] == time_right["key"],
time_left["time"] >= time_right["time"],
],
).drop("time_right")
result = expr.execute()
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def schema(self):
def to_expr(self):
import ibis.expr.types as ir

return ir.JoinExpr(self)
return ir.Join(self)


@public
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ r1 := UnboundTable: right

JoinChain[r0]
JoinLink[asof, r1]
r0.time1 == r1.time2
r0.time1 <= r1.time2
JoinLink[inner, r1]
r0.value == r1.value2
values:
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/tests/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def test_fillna(snapshot):
def test_asof_join(snapshot):
left = ibis.table([("time1", "int32"), ("value", "double")], name="left")
right = ibis.table([("time2", "int32"), ("value2", "double")], name="right")
joined = left.asof_join(right, [("time1", "time2")]).inner_join(
joined = left.asof_join(right, ("time1", "time2")).inner_join(
right, left.value == right.value2
)

Expand Down
35 changes: 32 additions & 3 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,14 +487,14 @@ def test_join():
t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"})
joined = t1.join(t2, [t1.a == t2.c])

assert isinstance(joined, ir.JoinExpr)
assert isinstance(joined, ir.Join)
assert isinstance(joined.op(), JoinChain)
assert isinstance(joined.op().to_expr(), ir.JoinExpr)
assert isinstance(joined.op().to_expr(), ir.Join)

result = joined._finish()
assert isinstance(joined, ir.TableExpr)
assert isinstance(joined.op(), JoinChain)
assert isinstance(joined.op().to_expr(), ir.JoinExpr)
assert isinstance(joined.op().to_expr(), ir.Join)

with join_tables(t1, t2) as (t1, t2):
assert result.op() == JoinChain(
Expand Down Expand Up @@ -1264,3 +1264,32 @@ def test_join_between_joins():
},
)
assert expr.op() == expected


def test_join_method_docstrings():
t1 = ibis.table(name="t1", schema={"a": "int64", "b": "string"})
t2 = ibis.table(name="t2", schema={"c": "int64", "d": "string"})
joined = t1.join(t2, [t1.a == t2.c])

assert isinstance(t1, ir.Table)
assert isinstance(joined, ir.Join)
assert isinstance(joined, ir.Table)

method_names = [
"select",
"join",
"inner_join",
"left_join",
"outer_join",
"semi_join",
"anti_join",
"asof_join",
"cross_join",
"right_join",
"any_inner_join",
"any_left_join",
]
for method in method_names:
join_method = getattr(joined, method)
table_method = getattr(t1, method)
assert join_method.__doc__ == table_method.__doc__
145 changes: 120 additions & 25 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from ibis.expr.types.relations import (
bind,
dereference_values,
unwrap_aliases,
)
from __future__ import annotations

import functools
from public import public
import ibis.expr.operations as ops
from ibis.expr.types import Table, ValueExpr
from typing import Any, Optional
from collections.abc import Iterator, Mapping

import ibis
import ibis.expr.operations as ops

from ibis import util
from ibis.expr.types import Table, ValueExpr
from ibis.common.deferred import Deferred
from ibis.expr.analysis import flatten_predicates
from ibis.expr.operations.relations import JoinKind
from ibis.common.exceptions import ExpressionError, IntegrityError
from ibis import util
import functools
from ibis.expr.types.relations import dereference_mapping
import ibis

from ibis.expr.types.relations import (
bind,
dereference_values,
dereference_mapping,
unwrap_aliases,
)
from ibis.expr.operations.relations import JoinKind
from ibis.expr.rewrites import peel_join_field


Expand Down Expand Up @@ -91,9 +94,12 @@ def dereference_value(pred, deref_left, deref_right):
return pred.replace(deref_both, filter=ops.Value)


def prepare_predicates(left, right, predicates, deref_left, deref_right):
def prepare_predicates(
left, right, predicates, deref_left, deref_right, comparison=ops.Equals
):
"""Bind and dereference predicates to the left and right tables."""

left, right = left.to_expr(), right.to_expr()
for pred in util.promote_list(predicates):
if pred is True or pred is False:
yield ops.Literal(pred, dtype="bool")
Expand All @@ -120,7 +126,7 @@ def prepare_predicates(left, right, predicates, deref_left, deref_right):
left_value, right_value = dereference_sides(
left_value.op(), right_value.op(), deref_left, deref_right
)
yield ops.Equals(left_value, right_value).to_expr()
yield comparison(left_value, right_value)


def finished(method):
Expand All @@ -134,10 +140,18 @@ def wrapper(self, *args, **kwargs):


@public
class JoinExpr(Table):
class Join(Table):
__slots__ = ("_collisions",)

def __init__(self, arg, collisions=None):
assert isinstance(arg, ops.Node)
if not isinstance(arg, ops.JoinChain):
# coerce the input node to a join chain operation by first wrapping
# the input relation in a JoinTable so that we can join the same
# table with itself multiple times and to enable optimization
# passes later on
arg = ops.JoinTable(arg, index=0)
arg = ops.JoinChain(arg, rest=(), values=arg.fields)
super().__init__(arg)
object.__setattr__(self, "_collisions", collisions or set())

Expand All @@ -147,7 +161,8 @@ def _finish(self) -> Table:
raise IntegrityError(f"Name collisions: {self._collisions}")
return Table(self.op())

def join(
@functools.wraps(Table.join)
def join( # noqa: D102
self,
right,
predicates: Any,
Expand All @@ -156,10 +171,10 @@ def join(
lname: str = "",
rname: str = "{name}_right",
):
"""Join with another table."""
import pyarrow as pa
import pandas as pd

# TODO(kszucs): factor out to a helper function
if isinstance(right, (pd.DataFrame, pa.Table)):
right = ibis.memtable(right)
elif not isinstance(right, Table):
Expand All @@ -169,6 +184,8 @@ def join(

if how == "left_semi":
how = "semi"
elif how == "asof":
raise IbisInputError("use table.asof_join(...) instead")

left = self.op()
right = ops.JoinTable(right, index=left.length)
Expand All @@ -177,17 +194,17 @@ def join(

# bind and dereference the predicates
preds = prepare_predicates(
left.to_expr(),
right.to_expr(),
left,
right,
predicates,
deref_left=subs_left,
deref_right=subs_right,
)
preds = flatten_predicates(list(preds))

# if there are no predicates, default to every row matching unless the
# join is a cross join, because a cross join already has this behavior
if not preds and how != "cross":
# if there are no predicates, default to every row matching unless
# the join is a cross join, because a cross join already has this
# behavior
preds.append(ops.Literal(True, dtype="bool"))

# calculate the fields based in lname and rname, this should be a best
Expand All @@ -205,8 +222,83 @@ def join(
# return with a new JoinExpr wrapping the new join chain
return self.__class__(left, collisions=collisions)

def select(self, *args, **kwargs):
"""Select expressions."""
@functools.wraps(Table.asof_join)
def asof_join( # noqa: D102
self: Table,
right: Table,
on,
predicates=(),
by=(),
tolerance=None,
*,
lname: str = "",
rname: str = "{name}_right",
):
predicates = util.promote_list(predicates) + util.promote_list(by)
if tolerance is not None:
if not isinstance(on, str):
raise TypeError(
"tolerance can only be specified when predicates is a string"
)
# construct a predicate with two sides from the two tables
predicates.append(self[on] <= right[on] + tolerance)

left = self.op()
right = ops.JoinTable(right, index=left.length)
subs_left = dereference_mapping_left(left)
subs_right = dereference_mapping_right(right)

# TODO(kszucs): add extra validation for `on` with clear error messages
preds = list(
prepare_predicates(
left,
right,
[on],
deref_left=subs_left,
deref_right=subs_right,
comparison=ops.LessEqual,
)
)
preds += flatten_predicates(
list(
prepare_predicates(
left,
right,
predicates,
deref_left=subs_left,
deref_right=subs_right,
comparison=ops.Equals,
)
)
)
values, collisions = disambiguate_fields(
"asof", left.values, right.fields, lname, rname
)

# construct a new join link and add it to the join chain
link = ops.JoinLink("asof", table=right, predicates=preds)
left = left.copy(rest=left.rest + (link,), values=values)

# return with a new JoinExpr wrapping the new join chain
return self.__class__(left, collisions=collisions)

@functools.wraps(Table.cross_join)
def cross_join( # noqa: D102
self: Table,
right: Table,
*rest: Table,
lname: str = "",
rname: str = "{name}_right",
):
left = self.join(right, how="cross", predicates=(), lname=lname, rname=rname)
for right in rest:
left = left.join(
right, how="cross", predicates=(), lname=lname, rname=rname
)
return left

@functools.wraps(Table.select)
def select(self, *args, **kwargs): # noqa: D102
chain = self.op()
values = bind(self, (args, kwargs))
values = unwrap_aliases(values)
Expand Down Expand Up @@ -245,3 +337,6 @@ def select(self, *args, **kwargs):
unbind = finished(Table.unbind)
union = finished(Table.union)
view = finished(Table.view)


public(JoinExpr=Join)
Loading

0 comments on commit 9fb3627

Please sign in to comment.