Skip to content

Commit

Permalink
fix(decompile): ensure that SelfReference is decompiled with a call…
Browse files Browse the repository at this point in the history
… to `.view()`
  • Loading branch information
cpcloud authored and kszucs committed Feb 12, 2024
1 parent 2a7ae3f commit 4a44c57
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 28 deletions.
22 changes: 11 additions & 11 deletions ibis/backends/tests/sql/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,45 @@

# from ibis.backends.base.sql.compiler import Compiler
from ibis.backends.tests.sql.conftest import to_sql
from ibis.tests.util import assert_decompile_roundtrip
from ibis.tests.util import assert_decompile_roundtrip, schemas_eq

pytestmark = pytest.mark.duckdb


def test_union(union, snapshot):
snapshot.assert_match(to_sql(union), "out.sql")
assert_decompile_roundtrip(union, snapshot, check_equality=False)
assert_decompile_roundtrip(union, snapshot, eq=schemas_eq)


def test_union_project_column(union_all, snapshot):
# select a column, get a subquery
expr = union_all[[union_all.key]]
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_table_intersect(intersect, snapshot):
snapshot.assert_match(to_sql(intersect), "out.sql")
assert_decompile_roundtrip(intersect, snapshot, check_equality=False)
assert_decompile_roundtrip(intersect, snapshot, eq=schemas_eq)


def test_table_difference(difference, snapshot):
snapshot.assert_match(to_sql(difference), "out.sql")
assert_decompile_roundtrip(difference, snapshot, check_equality=False)
assert_decompile_roundtrip(difference, snapshot, eq=schemas_eq)


def test_intersect_project_column(intersect, snapshot):
# select a column, get a subquery
expr = intersect[[intersect.key]]
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_difference_project_column(difference, snapshot):
# select a column, get a subquery
expr = difference[[difference.key]]
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_table_distinct(con, snapshot):
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_having_from_filter(snapshot):
expr = having.aggregate(filt.a.sum().name("sum"))
snapshot.assert_match(to_sql(expr), "out.sql")
# params get different auto incremented counter identifiers
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_simple_agg_filter(snapshot):
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_table_drop_with_filter(snapshot):
joined = joined[left.a]
expr = joined.filter(joined.a < 1.0)
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_table_drop_consistency():
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_subquery_where_location(snapshot):
out = Compiler.to_sql(expr, params={param: "20140101"})
snapshot.assert_match(out, "out.sql")
# params get different auto incremented counter identifiers
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_column_expr_retains_name(snapshot):
Expand All @@ -231,4 +231,4 @@ def test_union_order_by(snapshot):
t = ibis.table(dict(a="int", b="string"), name="t")
expr = t.order_by("b").union(t.order_by("b"))
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)
14 changes: 7 additions & 7 deletions ibis/backends/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# from ibis.backends.base.sql.compiler import Compiler
from ibis.backends.tests.sql.conftest import get_query, to_sql
from ibis.tests.util import assert_decompile_roundtrip
from ibis.tests.util import assert_decompile_roundtrip, schemas_eq

pytestmark = pytest.mark.duckdb

Expand Down Expand Up @@ -117,7 +117,7 @@ def test_join_between_joins(snapshot):
exprs = [left, right.value3, right.value4]
expr = joined.select(exprs)
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_join_just_materialized(nation, region, customer, snapshot):
Expand Down Expand Up @@ -178,7 +178,7 @@ def test_where_analyze_scalar_op(functional_alltypes, snapshot):
]
).count()
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_bug_duplicated_where(airlines, snapshot):
Expand Down Expand Up @@ -244,7 +244,7 @@ def test_fuse_projections(snapshot):
# fusion works even if there's a filter
table3_filtered = table2_filtered.select([table2, f2])
snapshot.assert_match(to_sql(table3_filtered), "project_filter.sql")
assert_decompile_roundtrip(table3_filtered, snapshot, check_equality=False)
assert_decompile_roundtrip(table3_filtered, snapshot, eq=schemas_eq)


def test_projection_filter_fuse(projection_fuse_filter, snapshot):
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_subquery_in_union(alltypes, snapshot):

expr = join1.union(join2)
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_limit_with_self_join(functional_alltypes, snapshot):
Expand All @@ -351,7 +351,7 @@ def test_limit_with_self_join(functional_alltypes, snapshot):

expr = t.join(t2, t.tinyint_col < t2.timestamp_col.minute()).count()
snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot)
assert_decompile_roundtrip(expr, snapshot, eq=lambda x, y: repr(x) == repr(y))


def test_topk_predicate_pushdown_bug(nation, customer, region, snapshot):
Expand Down Expand Up @@ -414,7 +414,7 @@ def test_case_in_projection(alltypes, snapshot):
expr = t[expr.name("col1"), expr2.name("col2"), t]

snapshot.assert_match(to_sql(expr), "out.sql")
assert_decompile_roundtrip(expr, snapshot, check_equality=False)
assert_decompile_roundtrip(expr, snapshot, eq=schemas_eq)


def test_identifier_quoting(snapshot):
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,9 @@ def test_gh_1045(test1, test2, test3, snapshot):
t2 = test2
t3 = test3

t3 = t3[[c for c in t3.columns if c != "id3"]].mutate(id3=t3.id3.cast("int64"))
t3 = t3.mutate(id3=t3.id3.cast("int64"))

t3 = t3[[c for c in t3.columns if c != "val2"]].mutate(t3_val2=t3.id3)
t3 = t3.mutate(t3_val2=t3.id3)
t4 = t3.join(t2, t2.id2b == t3.id3)

t1 = t1[[t1[c].name(f"t1_{c}") for c in t1.columns]]
Expand Down
14 changes: 12 additions & 2 deletions ibis/expr/decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,14 @@ def aggregation(op, parent, groups, metrics):
raise ValueError("No metrics to aggregate")


@translate.register(ops.Distinct)
def distinct(op, parent):
    """Render a ``Distinct`` operation as source code.

    The already-rendered ``parent`` code is suffixed with a ``.distinct()``
    method call; ``op`` itself is not inspected.
    """
    rendered = f"{parent}.distinct()"
    return rendered


@translate.register(ops.SelfReference)
def self_reference(op, parent, identifier):
return parent
return f"{parent}.view()"


@translate.register(ops.JoinTable)
Expand Down Expand Up @@ -330,7 +335,12 @@ def isin(op, value, options):


class CodeContext:
always_assign = (ops.ScalarParameter, ops.UnboundTable, ops.Aggregate)
always_assign = (
ops.ScalarParameter,
ops.UnboundTable,
ops.Aggregate,
ops.SelfReference,
)
always_ignore = (
ops.JoinTable,
ops.Field,
Expand Down
10 changes: 10 additions & 0 deletions ibis/expr/tests/test_decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,13 @@ def test_basic(expr, expected):
assert restored.equals(expected)
else:
assert restored == expected


def test_view():
    """Decompiling a table view must emit an explicit ``.view()`` call."""
    table = ibis.table({"x": "int"}, name="t")
    code = decompile(table.view())
    assert "t.view()" in code


def test_distinct():
    """Decompiling a distinct table must emit an explicit ``.distinct()`` call."""
    table = ibis.table({"x": "int"}, name="t")
    code = decompile(table.distinct())
    assert "t.distinct()" in code
33 changes: 27 additions & 6 deletions ibis/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
from __future__ import annotations

import pickle
from typing import Callable

import ibis
import ibis.expr.types as ir
from ibis import util


Expand All @@ -27,8 +29,30 @@ def assert_pickle_roundtrip(obj):
assert obj == loaded


def assert_decompile_roundtrip(expr, snapshot=None, check_equality=True):
"""Assert that an ibis expression remains the same after decompilation."""
def schemas_eq(left: ir.Expr, right: ir.Expr) -> bool:
    """Return whether two expressions have equal table schemas.

    Parameters
    ----------
    left
        The first expression; converted with ``as_table()`` before its
        schema is read.
    right
        The second expression; converted the same way.

    Returns
    -------
    bool
        The result of calling ``equals`` on the left schema with the right
        schema.
    """
    left_schema = left.as_table().schema()
    right_schema = right.as_table().schema()
    # Return the comparison instead of asserting it: this function is used as
    # a predicate (`assert eq(expr.unbind(), restored)`), and a bare `assert`
    # here would make it return None, which is falsy even for equal schemas.
    return left_schema.equals(right_schema)


def assert_decompile_roundtrip(
expr: ir.Expr,
snapshot=None,
eq: Callable[[ir.Expr, ir.Expr], bool] = ir.Expr.equals,
):
"""Assert that an ibis expression remains the same after decompilation.
Parameters
----------
expr
The expression to decompile.
snapshot
A snapshot fixture.
eq
A callable that returns whether two Ibis expressions are equal.
Defaults to `ibis.expr.types.Expr.equals`. Use this to adjust
comparison behavior for expressions that contain `SelfReference`
operations from table.view() calls, or other relations whose equality
is difficult to roundtrip.
"""
rendered = ibis.decompile(expr, format=True)
if snapshot is not None:
snapshot.assert_match(rendered, "decompiled.py")
Expand All @@ -38,7 +62,4 @@ def assert_decompile_roundtrip(expr, snapshot=None, check_equality=True):
exec(rendered, {}, locals_)
restored = locals_["result"]

if check_equality:
assert expr.unbind().equals(restored)
else:
assert expr.as_table().schema().equals(restored.as_table().schema())
assert eq(expr.unbind(), restored)

0 comments on commit 4a44c57

Please sign in to comment.