Skip to content

Commit

Permalink
fix(sqlalchemy): fix correlated subquery compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jan 10, 2022
1 parent f503789 commit 43b9010
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _varargs_call(sa_func, t, expr):


def get_sqla_table(ctx, table):
- if ctx.has_ref(table):
+ if ctx.has_ref(table, parent_contexts=True):
ctx_level = ctx
sa_table = ctx_level.get_ref(table)
while sa_table is None and ctx_level.parent is not ctx_level:
Expand Down
81 changes: 81 additions & 0 deletions ibis/tests/sql/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,87 @@ def test_sort_aggregation_translation_failure(self):

self._compare_sqla(expr, ex)

def test_where_correlated_subquery_with_join(self):
    # GH3163
    # Filter a join by comparing a column against the minimum of a
    # correlated subquery that is itself built on a join, and check the
    # compiled SQLAlchemy expression against a hand-written one.

    # --- ibis expression ---------------------------------------------
    ibis_part = ibis.table([("p_partkey", "int64")], name="part")
    ibis_partsupp = ibis.table(
        [
            ("ps_partkey", "int64"),
            ("ps_supplycost", "float64"),
            ("ps_suppkey", "int64"),
        ],
        name="partsupp",
    )
    ibis_supplier = ibis.table([("s_suppkey", "int64")], name="supplier")

    # Outer query: part joined to partsupp, projected to two columns.
    outer_pred = ibis_part.p_partkey == ibis_partsupp.ps_partkey
    outer = ibis_part.join(ibis_partsupp, outer_pred)
    outer = outer[
        ibis_part.p_partkey,
        ibis_partsupp.ps_supplycost,
    ]

    # Inner query: partsupp joined to supplier, projected, then
    # correlated to the outer query on the part key.
    inner_pred = ibis_supplier.s_suppkey == ibis_partsupp.ps_suppkey
    inner = ibis_partsupp.join(ibis_supplier, inner_pred)
    inner = inner.projection(
        [ibis_partsupp.ps_partkey, ibis_partsupp.ps_supplycost]
    )
    inner = inner[inner.ps_partkey == outer.p_partkey]

    expr = outer[outer.ps_supplycost == inner.ps_supplycost.min()]

    # --- expected sqlalchemy expression ------------------------------
    sa_part = sa.table("part", sa.column("p_partkey"))
    sa_supplier = sa.table("supplier", sa.column("s_suppkey"))
    sa_partsupp = sa.table(
        "partsupp",
        sa.column("ps_partkey"),
        sa.column("ps_supplycost"),
        sa.column("ps_suppkey"),
    )

    t1 = sa_part.alias("t1")
    t2_outer = sa_partsupp.alias("t2")

    outer_join = t1.join(
        t2_outer,
        onclause=t1.c.p_partkey == t2_outer.c.ps_partkey,
    )
    outer_select = sa.select([t1.c.p_partkey, t2_outer.c.ps_supplycost])
    t0 = outer_select.select_from(outer_join).alias("t0")

    # A second, distinct alias of partsupp that is also named "t2".
    t2_inner = sa_partsupp.alias("t2")
    t5 = sa_supplier.alias("t5")

    inner_join = t2_inner.join(
        t5,
        onclause=t5.c.s_suppkey == t2_inner.c.ps_suppkey,
    )
    inner_select = sa.select(
        [t2_inner.c.ps_partkey, t2_inner.c.ps_supplycost]
    )
    t3 = inner_select.select_from(inner_join).alias("t3")

    # Scalar subquery yielding the minimum supply cost for the outer
    # row's part key.
    min_cost = (
        sa.select([sa.func.min(t3.c.ps_supplycost).label("min")])
        .select_from(t3)
        .where(t3.c.ps_partkey == t0.c.p_partkey)
        .as_scalar()
    )

    ex = (
        sa.select([t0.c.p_partkey, t0.c.ps_supplycost])
        .select_from(t0)
        .where(t0.c.ps_supplycost == min_cost)
    )

    self._compare_sqla(expr, ex)

def _compare_sqla(self, expr, sqla):
context = AlchemyContext(compiler=AlchemyCompiler)
result_sqla = AlchemyCompiler.to_sql(expr, context)
Expand Down

0 comments on commit 43b9010

Please sign in to comment.