From 522f3a4b5a9d1a3f39b6952ec5dc13fdd59c0855 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 13 Nov 2023 14:39:52 +0100 Subject: [PATCH] fix(ir): `ibis.parse_sql()` removes where clause --- ibis/expr/sql.py | 14 +++-- .../decompiled.py | 41 +++++++++++++ .../decompiled.py | 17 ++++++ .../decompiled.py | 22 +++++++ .../inner/decompiled.py | 34 +++++++++++ .../left/decompiled.py | 32 ++++++++++ .../right/decompiled.py | 34 +++++++++++ .../decompiled.py | 12 ++++ .../decompiled.py | 32 ++++++++++ .../decompiled.py | 38 ++++++++++++ .../decompiled.py | 16 +++++ .../decompiled.py | 15 +++++ .../decompiled.py | 9 +++ .../test_parse_sql_table_alias/decompiled.py | 9 +++ ibis/expr/tests/test_sql.py | 61 +++++++++++++------ pyproject.toml | 3 +- 16 files changed, 363 insertions(+), 26 deletions(-) create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py create mode 100644 ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py diff --git a/ibis/expr/sql.py b/ibis/expr/sql.py index 7d4521dd02bf..5e5241561b8b 100644 --- a/ibis/expr/sql.py +++ b/ibis/expr/sql.py @@ -125,8 +125,8 @@ def convert_join(join, catalog): catalog = catalog.overlay(join) left_name = join.name + left_table = catalog[left_name] for right_name, desc in join.joins.items(): - left_table = catalog[left_name] right_table = catalog[right_name] join_kind = _join_types[desc["side"]] @@ -139,11 +139,15 @@ def convert_join(join, catalog): else: predicate &= left_key == right_key - catalog[left_name] = left_table.join( - right_table, predicates=predicate, how=join_kind - ) + left_table = left_table.join(right_table, predicates=predicate, how=join_kind) - return catalog[left_name] + if join.condition: + predicate = convert(join.condition, catalog=catalog) + left_table = left_table.filter(predicate) + + catalog[left_name] = left_table + + return left_table @convert.register(sgp.Aggregate) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py new file mode 100644 index 000000000000..03fcc9f2791f --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_aggregation_with_multiple_joins/decompiled.py @@ -0,0 +1,41 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +call_outcome = ibis.table( + name="call_outcome", schema={"outcome_text": "string", "id": "int64"} +) +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +innerjoin = employee.inner_join(call, employee.id == call.employee_id) + +result = ( + innerjoin.inner_join(call_outcome, call.call_outcome_id == call_outcome.id) + .select( + [ + innerjoin.first_name, + innerjoin.last_name, + innerjoin.id, + innerjoin.start_time, + innerjoin.end_time, + innerjoin.employee_id, + innerjoin.call_outcome_id, + innerjoin.call_attempts, + call_outcome.outcome_text, + call_outcome.id.name("id_right"), + ] + ) + .group_by(call.employee_id) + .aggregate(call.call_attempts.mean().name("avg_attempts")) +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py new file mode 100644 index 000000000000..b5bd4842d48b --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation/decompiled.py @@ -0,0 +1,17 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) + +result = call.group_by(call.employee_id).aggregate( + call.call_attempts.sum().name("attempts") +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py new file mode 100644 index 000000000000..392f50271b0b --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_aggregation_with_join/decompiled.py @@ -0,0 +1,22 @@ +import ibis + + +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +leftjoin = employee.left_join(call, employee.id == call.employee_id) + +result = leftjoin.group_by(leftjoin.id).aggregate( + call.call_attempts.sum().name("attempts") +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py new file mode 100644 index 000000000000..05f419d668db --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/inner/decompiled.py @@ -0,0 +1,34 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +proj = employee.inner_join(call, employee.id == call.employee_id).filter( + employee.id < 5 +) + +result = proj.select( + [ + proj.first_name, + proj.last_name, + proj.id, + call.start_time, + call.end_time, + call.employee_id, + call.call_outcome_id, + call.call_attempts, + proj.first_name.name("first"), + ] +).order_by(proj.id.desc()) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py new file mode 100644 index 000000000000..2ed2c808d726 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/left/decompiled.py @@ -0,0 +1,32 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +proj = employee.left_join(call, employee.id == call.employee_id).filter(employee.id < 5) + +result = proj.select( + [ + proj.first_name, + proj.last_name, + proj.id, + call.start_time, + call.end_time, + call.employee_id, + call.call_outcome_id, + call.call_attempts, + proj.first_name.name("first"), + ] +).order_by(proj.id.desc()) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py new file mode 100644 index 000000000000..0f31dffd1532 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_join/right/decompiled.py @@ -0,0 +1,34 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +proj = employee.right_join(call, employee.id == call.employee_id).filter( + employee.id < 5 +) + +result = proj.select( + [ + proj.first_name, + proj.last_name, + proj.id, + call.start_time, + call.end_time, + call.employee_id, + call.call_outcome_id, + call.call_attempts, + proj.first_name.name("first"), + ] +).order_by(proj.id.desc()) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py new file mode 100644 index 000000000000..b6e37f2ab518 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_basic_projection/decompiled.py @@ -0,0 +1,12 @@ +import ibis + + +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +proj = employee.filter(employee.id < 5) + +result = proj.select( + [proj.first_name, proj.last_name, proj.id, proj.first_name.name("first")] +).order_by(proj.id.desc()) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py new file mode 100644 index 000000000000..2ed2c808d726 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_join_with_filter/decompiled.py @@ -0,0 +1,32 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +proj = employee.left_join(call, employee.id == call.employee_id).filter(employee.id < 5) + +result = proj.select( + [ + proj.first_name, + proj.last_name, + proj.id, + call.start_time, + call.end_time, + call.employee_id, + call.call_outcome_id, + call.call_attempts, + proj.first_name.name("first"), + ] +).order_by(proj.id.desc()) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py new file mode 100644 index 000000000000..ae6bfd9788f7 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_multiple_joins/decompiled.py @@ -0,0 +1,38 @@ +import ibis + + +call_outcome = ibis.table( + name="call_outcome", schema={"outcome_text": "string", "id": "int64"} +) +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +innerjoin = employee.inner_join(call, employee.id == call.employee_id) + +result = innerjoin.inner_join( + call_outcome, call.call_outcome_id == call_outcome.id +).select( + [ + innerjoin.first_name, + innerjoin.last_name, + innerjoin.id, + innerjoin.start_time, + innerjoin.end_time, + innerjoin.employee_id, + innerjoin.call_outcome_id, + innerjoin.call_attempts, + call_outcome.outcome_text, + call_outcome.id.name("id_right"), + ] +) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py new file mode 100644 index 000000000000..81f627719b17 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_scalar_subquery/decompiled.py @@ -0,0 +1,16 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) +agg = call.aggregate(call.call_attempts.mean().name("mean")) + +result = call.inner_join(agg, []) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py new file mode 100644 index 000000000000..8fa935b7c5e4 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_reduction/decompiled.py @@ -0,0 +1,15 @@ +import ibis + + +call = ibis.table( + name="call", + schema={ + "start_time": "timestamp", + "end_time": "timestamp", + "employee_id": "int64", + "call_outcome_id": "int64", + "call_attempts": "int64", + }, +) + +result = call.aggregate(call.call_attempts.mean().name("mean")) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py new file mode 100644 index 000000000000..d993ac2ac040 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_simple_select_count/decompiled.py @@ -0,0 +1,9 @@ +import ibis + + +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) + +result = employee.aggregate(employee.first_name.count().name("_col_0")) diff --git a/ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py new file mode 100644 index 000000000000..ec5df1972413 --- /dev/null +++ b/ibis/expr/tests/snapshots/test_sql/test_parse_sql_table_alias/decompiled.py @@ -0,0 +1,9 @@ +import ibis + + +employee = ibis.table( + name="employee", + schema={"first_name": "string", "last_name": "string", "id": "int64"}, +) + +result = employee.select([employee.first_name, employee.last_name, employee.id]) diff --git a/ibis/expr/tests/test_sql.py b/ibis/expr/tests/test_sql.py index d53b8fe997b3..9e2961401191 100644 --- a/ibis/expr/tests/test_sql.py +++ b/ibis/expr/tests/test_sql.py @@ -17,14 +17,15 @@ } -def test_parse_sql_basic_projection(): +def test_parse_sql_basic_projection(snapshot): sql = "SELECT *, first_name as first FROM employee WHERE id < 5 ORDER BY id DESC" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") @pytest.mark.parametrize("how", ["right", "left", "inner"]) -def test_parse_sql_basic_join(how): +def test_parse_sql_basic_join(how, snapshot): sql = f""" SELECT *, @@ -37,10 +38,11 @@ def test_parse_sql_basic_join(how): ORDER BY id DESC""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_multiple_joins(): +def test_parse_sql_multiple_joins(snapshot): sql = """ SELECT * FROM employee @@ -49,10 +51,11 @@ def test_parse_sql_multiple_joins(): JOIN call_outcome ON call.call_outcome_id = call_outcome.id""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_basic_aggregation(): +def test_parse_sql_basic_aggregation(snapshot): sql = """ SELECT employee_id, @@ -60,10 +63,11 @@ def test_parse_sql_basic_aggregation(): FROM call GROUP BY employee_id""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_basic_aggregation_with_join(): +def test_parse_sql_basic_aggregation_with_join(snapshot): sql = """ SELECT id, @@ -73,10 +77,11 @@ def test_parse_sql_basic_aggregation_with_join(): ON employee.id = call.employee_id GROUP BY id""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_aggregation_with_multiple_joins(): +def test_parse_sql_aggregation_with_multiple_joins(snapshot): sql = """ SELECT t.employee_id, @@ -87,16 +92,18 @@ def test_parse_sql_aggregation_with_multiple_joins(): ) AS t GROUP BY t.employee_id""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_simple_reduction(): +def test_parse_sql_simple_reduction(snapshot): sql = """SELECT AVG(call_attempts) AS mean FROM call""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_scalar_subquery(): +def test_parse_sql_scalar_subquery(snapshot): sql = """ SELECT * FROM call @@ -105,16 +112,30 @@ def test_parse_sql_scalar_subquery(): FROM call )""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_simple_select_count(): +def test_parse_sql_simple_select_count(snapshot): sql = """SELECT COUNT(first_name) FROM employee""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") -def test_parse_sql_table_alias(): +def test_parse_sql_table_alias(snapshot): sql = """SELECT e.* FROM employee AS e""" expr = ibis.parse_sql(sql, catalog) - code = ibis.decompile(expr, format=True) # noqa: F841 + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") + + +def test_parse_sql_join_with_filter(snapshot): + sql = """ +SELECT *, first_name as first FROM employee +LEFT JOIN call ON employee.id = call.employee_id +WHERE id < 5 +ORDER BY id DESC""" + expr = ibis.parse_sql(sql, catalog) + code = ibis.decompile(expr, format=True) + snapshot.assert_match(code, "decompiled.py") diff --git a/pyproject.toml b/pyproject.toml index 528f48e2ffb0..0b13f5924008 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -451,7 +451,7 @@ ignore = [ "SIM300", # yoda conditions "UP007", # Optional[str] -> str | None ] -exclude = ["*_py310.py", "ibis/tests/*/snapshots/*"] +exclude = ["*_py310.py"] target-version = "py39" # none of these codes will be automatically fixed by ruff unfixable = [ @@ -473,6 +473,7 @@ required-imports = ["from __future__ import annotations"] ] "ci/*.py" = ["INP001"] "docs/*.py" = ["INP001"] +"*/decompiled.py" = ["ALL"] "ci/release/verify_release.py" = ["T201"] # CLI tool that prints stuff [tool.blackdoc]