Skip to content

Commit

Permalink
fix(sql): outer order by should take precedence over inner order by
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Mar 14, 2024
1 parent 54c2c70 commit 4376c35
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 10 deletions.
19 changes: 15 additions & 4 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ def dtype(self):
return self.func.dtype


# TODO(kszucs): there is a better strategy to rewrite the relational operations
# to Select nodes by wrapping the leaf nodes in a Select node and then merging
# Project, Filter, Sort, etc. incrementally into the Select node. This way we
# can have tighter control over simplification logic.


@replace(p.Project)
def project_to_select(_, **kwargs):
"""Convert a Project node to a Select node."""
Expand Down Expand Up @@ -167,17 +173,22 @@ def merge_select_select(_, **kwargs):

subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()}
selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()}
predicates = tuple(p.replace(subs, filter=ops.Value) for p in _.predicates)
sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)

predicates = tuple(p.replace(subs, filter=ops.Value) for p in _.predicates)
unique_predicates = toolz.unique(_.parent.predicates + predicates)
unique_sort_keys = {s.expr: s for s in _.parent.sort_keys + sort_keys}

sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)
sort_key_exprs = {s.expr for s in sort_keys}
parent_sort_keys = tuple(
k for k in _.parent.sort_keys if k.expr not in sort_key_exprs
)
unique_sort_keys = sort_keys + parent_sort_keys

return Select(
_.parent.parent,
selections=selections,
predicates=unique_predicates,
sort_keys=unique_sort_keys.values(),
sort_keys=unique_sort_keys,
)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
"t0"."a",
"t0"."b"
FROM "t" AS "t0"
ORDER BY
"t0"."b" DESC,
"t0"."a" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "t" AS "t0"
ORDER BY
"t0"."b" ASC,
"t0"."a" DESC,
"t0"."c" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
SELECT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "t" AS "t0"
ORDER BY
"t0"."b" ASC,
"t0"."a" + CAST(1 AS TINYINT) DESC,
"t0"."a" ASC,
"t0"."c" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
"t0"."a",
"t0"."b"
FROM "t" AS "t0"
ORDER BY
"t0"."a" ASC,
"t0"."b" DESC
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "t" AS "t0"
ORDER BY
"t0"."b" ASC,
"t0"."a" DESC,
"t0"."c" ASC
41 changes: 41 additions & 0 deletions ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pytest import param

import ibis
from ibis import _
from ibis.backends.tests.sql.conftest import to_sql
from ibis.tests.util import assert_decompile_roundtrip

Expand Down Expand Up @@ -518,6 +519,46 @@ def test_order_by_expr(snapshot):
snapshot.assert_match(to_sql(expr), "out.sql")


def test_double_order_by(snapshot):
    """Check that a second ``order_by`` is emitted ahead of the first one.

    The table is sorted by ``a`` and then by ``b`` descending; the compiled
    ORDER BY clause must list the later key first.
    """
    table = ibis.table(dict(a="int", b="string"), name="t")
    inner_sorted = table.order_by(table.a)
    expr = inner_sorted.order_by(table.b.desc())

    # Outer key (b DESC) first, inner key (a ASC) kept afterwards.
    expected = '"t0"."b" DESC, "t0"."a" ASC'
    compact_sql = to_sql(expr, pretty=False)

    assert expected in compact_sql
    snapshot.assert_match(to_sql(expr), "out.sql")


def test_double_order_by_same_column(snapshot):
    """Check chained ``order_by`` calls that both sort on column ``a``.

    The expected clause contains ``a`` only once, with the direction given in
    the second call, while the remaining first-call key (``c``) comes last.
    """
    table = ibis.table(dict(a="int", b="string", c="float"), name="t")
    inner_sorted = table.order_by(table.a, table.c)
    expr = inner_sorted.order_by(table.b.asc(), table.a.desc())

    # Expected ordering: t.b ASC, t.a DESC, t.c ASC
    expected = '"t0"."b" ASC, "t0"."a" DESC, "t0"."c" ASC'
    compact_sql = to_sql(expr, pretty=False)

    assert expected in compact_sql
    snapshot.assert_match(to_sql(expr), "out.sql")


def test_double_order_by_deferred(snapshot):
    """Check chained ``order_by`` calls where one outer key is deferred.

    The second call uses ``_.a.desc()`` instead of a bound column; the
    expected clause still contains ``a`` only once, in descending order.
    """
    table = ibis.table(dict(a="int", b="string", c="float"), name="t")
    inner_sorted = table.order_by(table.a, table.c)
    expr = inner_sorted.order_by(table.b.asc(), _.a.desc())

    expected = '"t0"."b" ASC, "t0"."a" DESC, "t0"."c" ASC'
    compact_sql = to_sql(expr, pretty=False)

    assert expected in compact_sql
    snapshot.assert_match(to_sql(expr), "out.sql")


def test_double_order_by_different_expression(snapshot):
    """Check chained ``order_by`` calls whose keys are distinct expressions.

    The second call sorts on ``a + 1`` rather than ``a`` itself, so the
    expected clause keeps both: the second call's keys first, then the first
    call's ``a`` and ``c``.
    """
    table = ibis.table(dict(a="int", b="string", c="float"), name="t")
    inner_sorted = table.order_by(table.a, table.c)
    shifted_a_desc = (table.a + 1).desc()
    expr = inner_sorted.order_by(table.b.asc(), shifted_a_desc)

    expected = (
        '"t0"."b" ASC, "t0"."a" + CAST(1 AS TINYINT) DESC, "t0"."a" ASC, "t0"."c" ASC'
    )
    compact_sql = to_sql(expr, pretty=False)

    assert expected in compact_sql
    snapshot.assert_match(to_sql(expr), "out.sql")


def test_no_cartesian_join(snapshot):
customers = ibis.table(
dict(customer_id="int64", first_name="string", last_name="string"),
Expand Down
35 changes: 32 additions & 3 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,38 @@ def test_dropna_table(backend, alltypes, how, subset):
backend.assert_frame_equal(result, expected)


def test_select_sort_sort(alltypes):
query = alltypes[alltypes.year, alltypes.bool_col]
query = query.order_by(query.year).order_by(query.bool_col)
def test_select_sort_sort(backend, alltypes, df):
    """Compare two chained ``order_by`` calls against a pandas sort.

    The reference frame is sorted by the second call's key (``bool_col``)
    first, followed by the first call's keys (``year``, then ``id``
    descending).
    """
    inner_sorted = alltypes.order_by(alltypes.year, alltypes.id.desc())
    expr = inner_sorted.order_by(alltypes.bool_col)

    executed = expr.execute()
    result = executed.reset_index(drop=True)

    sort_columns = ["bool_col", "year", "id"]
    sort_ascending = [True, True, False]
    expected = df.sort_values(sort_columns, ascending=sort_ascending).reset_index(
        drop=True
    )

    backend.assert_frame_equal(result, expected)


def test_select_sort_sort_deferred(backend, alltypes, df):
    """Compare chained ``order_by`` calls with deferred keys to a pandas sort.

    The second call sorts on ``bool_col`` ascending and ``tinyint_col + 1``
    descending using deferred expressions. The reference frame is built with
    a temporary ``tinyint_col_plus`` column that is dropped after sorting.
    """
    inner_sorted = alltypes.order_by(
        alltypes.tinyint_col, alltypes.bool_col, alltypes.id
    )
    expr = inner_sorted.order_by(_.bool_col.asc(), (_.tinyint_col + 1).desc())
    result = expr.execute().reset_index(drop=True)

    # Column -> ascending flag, in sort priority order.
    sort_spec = {
        "bool_col": True,
        "tinyint_col_plus": False,
        "tinyint_col": True,
        "id": True,
    }
    with_helper = df.assign(tinyint_col_plus=df.tinyint_col + 1)
    sorted_frame = with_helper.sort_values(
        list(sort_spec), ascending=list(sort_spec.values())
    )
    expected = sorted_frame.drop(columns=["tinyint_col_plus"]).reset_index(drop=True)

    backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
Expand Down
10 changes: 7 additions & 3 deletions ibis/expr/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ def _repr_pretty_(self, p, cycle) -> str:


@public
def to_sql(expr: ir.Expr, dialect: str | None = None, **kwargs) -> SQLString:
def to_sql(
expr: ir.Expr, dialect: str | None = None, pretty: bool = True, **kwargs
) -> SQLString:
"""Return the formatted SQL string for an expression.
Parameters
Expand All @@ -349,6 +351,8 @@ def to_sql(expr: ir.Expr, dialect: str | None = None, **kwargs) -> SQLString:
Ibis expression.
dialect
SQL dialect to use for compilation.
pretty
Whether to use pretty formatting.
kwargs
Scalar parameters
Expand Down Expand Up @@ -380,5 +384,5 @@ def to_sql(expr: ir.Expr, dialect: str | None = None, **kwargs) -> SQLString:
read = write = getattr(backend, "dialect", dialect)

sql = backend._to_sql(expr.unbind(), **kwargs)
(pretty,) = sg.transpile(sql, read=read, write=write, pretty=True)
return SQLString(pretty)
(transpiled,) = sg.transpile(sql, read=read, write=write, pretty=pretty)
return SQLString(transpiled)

0 comments on commit 4376c35

Please sign in to comment.