Skip to content

Commit

Permalink
feat(sql): fuse distinct with other select nodes when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Aug 27, 2024
1 parent 335a538 commit 913808e
Show file tree
Hide file tree
Showing 13 changed files with 127 additions and 33 deletions.
14 changes: 7 additions & 7 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,9 +1257,11 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
else:
yield value.as_(name, quoted=self.quoted, copy=False)

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
def visit_Select(
self, op, *, parent, selections, predicates, qualified, sort_keys, distinct
):
# if we've constructed a useless projection return the parent relation
if not (selections or predicates or qualified or sort_keys):
if not (selections or predicates or qualified or sort_keys or distinct):
return parent

result = parent
Expand All @@ -1286,6 +1288,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
if sort_keys:
result = result.order_by(*sort_keys, copy=False)

if distinct:
result = result.distinct()

return result

def visit_DummyTable(self, op, *, values):
Expand Down Expand Up @@ -1470,11 +1475,6 @@ def visit_Limit(self, op, *, parent, n, offset):
return result.subquery(alias, copy=False)
return result

def visit_Distinct(self, op, *, parent):
    """Compile a Distinct node as ``SELECT DISTINCT * FROM <parent>``.

    Builds a star select, marks it distinct, and uses the already
    compiled ``parent`` relation as its FROM clause. Every builder call
    passes ``copy=False``.
    """
    star_select = sg.select(STAR, copy=False)
    distinct_select = star_select.distinct(copy=False)
    return distinct_select.from_(parent, copy=False)

def visit_CTE(self, op, *, parent):
    """Compile a CTE node to a table reference named after ``parent``.

    The table name is ``parent.alias_or_name``; quoting follows
    ``self.quoted``.
    """
    cte_name = parent.alias_or_name
    return sg.table(cte_name, quoted=self.quoted)

Expand Down
33 changes: 33 additions & 0 deletions ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class Select(ops.Relation):
predicates: VarTuple[ops.Value[dt.Boolean]] = ()
qualified: VarTuple[ops.Value[dt.Boolean]] = ()
sort_keys: VarTuple[ops.SortKey] = ()
distinct: bool = False

def is_star_selection(self):
return tuple(self.values.items()) == tuple(self.parent.fields.items())
Expand Down Expand Up @@ -128,6 +129,16 @@ def sort_to_select(_, **kwargs):
return Select(_.parent, selections=_.values, sort_keys=_.keys)


@replace(p.Distinct)
def distinct_to_select(_, **kwargs):
    """Convert a Distinct node to a Select node.

    When the parent is not yet a ``Select``, wrap it in a new ``Select``
    that carries the Distinct node's values as selections and has
    ``distinct=True``. When the parent already is a ``Select`` (the
    common ``some_relation_op().distinct()`` case), fuse eagerly by
    returning a copy of that parent with ``distinct=True``.
    """
    parent = _.parent
    if not isinstance(parent, Select):
        return Select(parent, selections=_.values, distinct=True)
    # Parent is already a Select: set the flag on a copy instead of nesting.
    return parent.copy(distinct=True)


@replace(p.DropColumns)
def drop_columns_to_select(_, **kwargs):
"""Convert a DropColumns node to a Select node."""
Expand Down Expand Up @@ -244,6 +255,26 @@ def merge_select_select(_, **kwargs):
if _.parent.find_below(blocking, filter=ops.Value):
return _

if _.parent.distinct:
# If the inner query is distinct, we can only fuse if the outer query
# is a trivial (SELECT [DISTINCT] a, b FROM (...)) select.
# - If the outer query is not distinct, it's not safe to fuse since new
# columns might be non-distinct, where applying a distinct would change
# the result.
# - If the outer query is distinct, it's not safe to fuse if the outer
# query uses any non-deterministic functions. Fusing would be correct
# if only deterministic functions were used, but if any function calls
# were expensive this may result in a performance issue as dropping
# the subquery may result in many more calls of that function.
if all(isinstance(v, ops.Field) for v in _.selections.values()):
distinct = True
else:
return _
else:
# Can always fuse if the inner query isn't distinct, using the distinct
# value of the outer query.
distinct = _.distinct

subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()}
selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()}

Expand All @@ -266,6 +297,7 @@ def merge_select_select(_, **kwargs):
predicates=unique_predicates,
qualified=unique_qualified,
sort_keys=unique_sort_keys,
distinct=distinct,
)
return result if complexity(result) <= complexity(_) else _

Expand Down Expand Up @@ -322,6 +354,7 @@ def sqlize(
| project_to_select
| filter_to_select
| sort_to_select
| distinct_to_select
| fill_null_to_select
| drop_null_to_select
| drop_columns_to_select
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT DISTINCT
"t1"."a",
"t1"."b" % 2 AS "d"
FROM (
SELECT DISTINCT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
SELECT
"t1"."a",
"t1"."b" % 2 AS "d"
FROM (
SELECT DISTINCT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10
) AS "t1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT DISTINCT
"t0"."a",
"t0"."b"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT DISTINCT
"t0"."a",
"t0"."b"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT DISTINCT
"t0"."a",
"t0"."b" % 2 AS "d"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT DISTINCT
"t0"."a",
"t0"."b"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
29 changes: 29 additions & 0 deletions ibis/backends/tests/sql/test_select_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,35 @@ def test_projection_filter_fuse(projection_fuse_filter, snapshot):
snapshot.assert_match(to_sql(expr3), "out.sql")


@pytest.mark.parametrize(
    "transform",
    [
        param(lambda t: t.select("a", "b").distinct(), id="select-distinct"),
        param(lambda t: t.distinct().select("a", "b"), id="distinct-select"),
        param(
            lambda t: t.distinct().select("a", "b").distinct(),
            id="distinct-select-distinct",
        ),
        param(
            lambda t: t.distinct().select("a", d=(_.b % 2)),
            id="distinct-non-trivial-select",
        ),
        param(
            lambda t: t.select("a", d=(_.b % 2)).distinct(),
            id="non-trivial-select-distinct",
        ),
        param(
            lambda t: t.distinct().select("a", d=(_.b % 2)).distinct(),
            id="distinct-non-trivial-select-distinct",
        ),
    ],
)
def test_fuse_distinct(snapshot, transform):
    """Snapshot the SQL compiled for each distinct/select combination.

    Each ``transform`` is applied to the same base expression: a
    projection of columns ``a``, ``b`` and ``c`` filtered on ``c > 10``.
    The resulting SQL is compared against the stored ``out.sql`` snapshot.
    """
    table = ibis.table({"a": "int", "b": "int", "c": "int", "d": "int"}, name="test")
    base = table.select("a", "b", "c").filter(table.c > 10)
    compiled_sql = to_sql(transform(base))
    snapshot.assert_match(compiled_sql, "out.sql")


def test_bug_project_multiple_times(customer, nation, region, snapshot):
# GH: 108
joined = customer.inner_join(
Expand Down

0 comments on commit 913808e

Please sign in to comment.