Skip to content

Commit

Permalink
fix(sql): support set operations wrapping subqueries
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Feb 21, 2024
1 parent ff5d078 commit 8d0e972
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
12 changes: 6 additions & 6 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,10 +1215,10 @@ def visit_Sort(self, op, *, parent, keys):
return sg.select(STAR).from_(parent).order_by(*keys)

def visit_Union(self, op, *, left, right, distinct):
if isinstance(left, sge.Table):
if isinstance(left, (sge.Table, sge.Subquery)):
left = sg.select(STAR).from_(left)

if isinstance(right, sge.Table):
if isinstance(right, (sge.Table, sge.Subquery)):
right = sg.select(STAR).from_(right)

return sg.union(
Expand All @@ -1228,10 +1228,10 @@ def visit_Union(self, op, *, left, right, distinct):
)

def visit_Intersection(self, op, *, left, right, distinct):
if isinstance(left, sge.Table):
if isinstance(left, (sge.Table, sge.Subquery)):
left = sg.select(STAR).from_(left)

if isinstance(right, sge.Table):
if isinstance(right, (sge.Table, sge.Subquery)):
right = sg.select(STAR).from_(right)

return sg.intersect(
Expand All @@ -1241,10 +1241,10 @@ def visit_Intersection(self, op, *, left, right, distinct):
)

def visit_Difference(self, op, *, left, right, distinct):
if isinstance(left, sge.Table):
if isinstance(left, (sge.Table, sge.Subquery)):
left = sg.select(STAR).from_(left)

if isinstance(right, sge.Table):
if isinstance(right, (sge.Table, sge.Subquery)):
right = sg.select(STAR).from_(right)

return sg.except_(
Expand Down
12 changes: 10 additions & 2 deletions ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,13 @@ def test_table_set_operations_api(alltypes, method):
False,
],
)
def test_top_level_union(backend, con, alltypes, distinct):
@pytest.mark.parametrize("ordered", [False, True])
def test_top_level_union(backend, con, alltypes, distinct, ordered):
t1 = alltypes.select(a="bigint_col").filter(lambda t: t.a == 10).distinct()
t2 = alltypes.select(a="bigint_col").filter(lambda t: t.a == 20).distinct()
if ordered:
t1 = t1.order_by("a")
t2 = t2.order_by("a")
expr = t1.union(t2, distinct=distinct).limit(2)
result = con.execute(expr)
expected = pd.DataFrame({"a": [10, 20]})
Expand Down Expand Up @@ -237,10 +241,11 @@ def test_top_level_union(backend, con, alltypes, distinct):
],
ids=["intersect", "difference"],
)
@pytest.mark.parametrize("ordered", [False, True])
@pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(["druid"], raises=PyDruidProgrammingError)
def test_top_level_intersect_difference(
backend, con, alltypes, distinct, opname, expected
backend, con, alltypes, distinct, opname, expected, ordered
):
t1 = (
alltypes.select(a="bigint_col")
Expand All @@ -252,6 +257,9 @@ def test_top_level_intersect_difference(
.filter(lambda t: (t.a == 20) | (t.a == 30))
.distinct()
)
if ordered:
t1 = t1.order_by("a")
t2 = t2.order_by("a")
op = getattr(t1, opname)
expr = op(t2, distinct=distinct).limit(2)
result = con.execute(expr)
Expand Down

0 comments on commit 8d0e972

Please sign in to comment.