diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 14e121137e17f..7b8a20f6a2718 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -1215,10 +1215,10 @@ def visit_Sort(self, op, *, parent, keys): return sg.select(STAR).from_(parent).order_by(*keys) def visit_Union(self, op, *, left, right, distinct): - if isinstance(left, sge.Table): + if isinstance(left, (sge.Table, sge.Subquery)): left = sg.select(STAR).from_(left) - if isinstance(right, sge.Table): + if isinstance(right, (sge.Table, sge.Subquery)): right = sg.select(STAR).from_(right) return sg.union( @@ -1228,10 +1228,10 @@ def visit_Union(self, op, *, left, right, distinct): ) def visit_Intersection(self, op, *, left, right, distinct): - if isinstance(left, sge.Table): + if isinstance(left, (sge.Table, sge.Subquery)): left = sg.select(STAR).from_(left) - if isinstance(right, sge.Table): + if isinstance(right, (sge.Table, sge.Subquery)): right = sg.select(STAR).from_(right) return sg.intersect( @@ -1241,10 +1241,10 @@ def visit_Intersection(self, op, *, left, right, distinct): ) def visit_Difference(self, op, *, left, right, distinct): - if isinstance(left, sge.Table): + if isinstance(left, (sge.Table, sge.Subquery)): left = sg.select(STAR).from_(left) - if isinstance(right, sge.Table): + if isinstance(right, (sge.Table, sge.Subquery)): right = sg.select(STAR).from_(right) return sg.except_( diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 4df076da7f978..8e4528fd304c2 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -192,9 +192,13 @@ def test_table_set_operations_api(alltypes, method): False, ], ) -def test_top_level_union(backend, con, alltypes, distinct): +@pytest.mark.parametrize("ordered", [False, True]) +def test_top_level_union(backend, con, alltypes, distinct, ordered): t1 = alltypes.select(a="bigint_col").filter(lambda t: t.a == 10).distinct() t2 = alltypes.select(a="bigint_col").filter(lambda t: t.a == 20).distinct() + if ordered: + t1 = t1.order_by("a") + t2 = t2.order_by("a") expr = t1.union(t2, distinct=distinct).limit(2) result = con.execute(expr) expected = pd.DataFrame({"a": [10, 20]}) @@ -237,10 +241,11 @@ def test_top_level_union(backend, con, alltypes, distinct): ], ids=["intersect", "difference"], ) +@pytest.mark.parametrize("ordered", [False, True]) @pytest.mark.notimpl(["polars"], raises=com.OperationNotDefinedError) @pytest.mark.broken(["druid"], raises=PyDruidProgrammingError) def test_top_level_intersect_difference( - backend, con, alltypes, distinct, opname, expected + backend, con, alltypes, distinct, opname, expected, ordered ): t1 = ( alltypes.select(a="bigint_col") @@ -252,6 +257,9 @@ def test_top_level_intersect_difference( .filter(lambda t: (t.a == 20) | (t.a == 30)) .distinct() ) + if ordered: + t1 = t1.order_by("a") + t2 = t2.order_by("a") op = getattr(t1, opname) expr = op(t2, distinct=distinct).limit(2) result = con.execute(expr)