From 3955ebea597494b526cb56d6d6616a6997261070 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Wed, 4 Sep 2024 13:51:54 -0500 Subject: [PATCH] fix(sql): properly parenthesize binary ops containing named expressions --- .../test_subquery_where_location/decompiled.py | 2 +- .../test_binop_with_alias_still_parenthesized/out.sql | 5 +++++ ibis/backends/tests/sql/test_compiler.py | 2 +- ibis/backends/tests/sql/test_sql.py | 6 ++++++ ibis/expr/operations/core.py | 8 ++++++++ ibis/tests/expr/test_value_exprs.py | 6 ++++++ 6 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql diff --git a/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py b/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py index aa735e73ec473..76f89e244beff 100644 --- a/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py +++ b/ibis/backends/tests/sql/snapshots/test_compiler/test_subquery_where_location/decompiled.py @@ -11,7 +11,7 @@ }, ) param = ibis.param("timestamp") -f = alltypes.filter((alltypes.timestamp_col < param.name("my_param"))) +f = alltypes.filter((alltypes.timestamp_col < param)) agg = f.aggregate([f.float_col.sum().name("foo")], by=[f.string_col]) result = agg.foo.count() diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql new file mode 100644 index 0000000000000..6ad547995b5c7 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_binop_with_alias_still_parenthesized/out.sql @@ -0,0 +1,5 @@ +SELECT + ( + "t0"."a" + "t0"."b" + ) * "t0"."c" AS "x" +FROM "t" AS "t0" \ No newline at end of file diff --git a/ibis/backends/tests/sql/test_compiler.py b/ibis/backends/tests/sql/test_compiler.py index bc789dba86ba8..dc27463647b8d 100644 --- a/ibis/backends/tests/sql/test_compiler.py +++ b/ibis/backends/tests/sql/test_compiler.py @@ -196,7 +196,7 @@ def test_subquery_where_location(snapshot): ], name="alltypes", ) - param = ibis.param("timestamp").name("my_param") + param = ibis.param("timestamp") expr = ( t[["float_col", "timestamp_col", "int_col", "string_col"]][ lambda t: t.timestamp_col < param diff --git a/ibis/backends/tests/sql/test_sql.py b/ibis/backends/tests/sql/test_sql.py index 01fd47220982e..f979b864d9162 100644 --- a/ibis/backends/tests/sql/test_sql.py +++ b/ibis/backends/tests/sql/test_sql.py @@ -143,6 +143,12 @@ def test_binop_parens(snapshot, opname, dtype, associative): snapshot.assert_match(combined, "out.sql") +def test_binop_with_alias_still_parenthesized(snapshot): + t = ibis.table({"a": "int", "b": "int", "c": "int"}, name="t") + sql = to_sql(((t.a + t.b).name("d") * t.c).name("x")) + snapshot.assert_match(sql, "out.sql") + + @pytest.mark.parametrize( "expr_fn", [ diff --git a/ibis/expr/operations/core.py b/ibis/expr/operations/core.py index 98d5a3c3c1168..2fdb0ee58ac3c 100644 --- a/ibis/expr/operations/core.py +++ b/ibis/expr/operations/core.py @@ -172,6 +172,14 @@ class Binary(Value): left: Value right: Value + def __init__(self, left, right, **kwargs): + # Dealias left and right, passing through any additional fields + if isinstance(left, Alias): + left = left.arg + if isinstance(right, Alias): + right = right.arg + super().__init__(left=left, right=right, **kwargs) + @attribute def shape(self) -> ds.DataShape: return max(self.left.shape, self.right.shape) diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index cc5a756437e57..0a581186abf04 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -731,6 +731,12 @@ def test_string_mul(table, left, right): assert isinstance(expr.op(), ops.Repeat) +def test_binop_strips_aliases(table): + assert (table.a.name("x") + table.b).equals(table.a + table.b) + assert (table.a + table.b.name("x")).equals(table.a + table.b) + assert (table.a.name("x") + 1).equals(table.a + 1) + + @pytest.mark.parametrize( ["op", "name", "case", "ex_type"], [