Skip to content

Commit

Permalink
refactor(sqlalchemy): remove the need for deferred columns
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jan 13, 2023
1 parent efa42bd commit e4011aa
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 31 deletions.
21 changes: 6 additions & 15 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,13 @@ def get_sqla_table(ctx, table):
return sa_table


def get_col_or_deferred_col(sa_table, colname):
"""Get a `Column`, or create a "deferred" column.
def get_col(sa_table, op: ops.TableColumn):
"""Get a `Column`."""
cols = sa_table.exported_columns
colname = op.name

This is to handle the case when selecting a column from a join, which
happens when a join expression is cached during join traversal.
We'd like to avoid generating a subquery just for selection, but in
sqlalchemy the Join object is not selectable. However, at this point we
know that the column can be referred to unambiguously.
Later the expression is assembled into
`sa.select([sa.column(colname)]).select_from(table_set)` (roughly),
where `table_set` is `sa_table` above.
"""
try:
return sa_table.exported_columns[colname]
return cols[colname]
except KeyError:
# cols is a sqlalchemy column collection which contains column
# names that are secretly prefixed by their containing table
Expand All @@ -116,7 +107,7 @@ def _table_column(t, op):

sa_table = get_sqla_table(ctx, table)

out_expr = get_col_or_deferred_col(sa_table, op.name)
out_expr = get_col(sa_table, op)
out_expr.quote = t._always_quote_columns

# If the column does not originate from the table set in the current SELECT
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
_bitwise_op,
_extract,
geospatial_functions,
get_col_or_deferred_col,
get_col,
)

operation_registry = sqlalchemy_operation_registry.copy()
Expand Down Expand Up @@ -302,7 +302,7 @@ def _table_column(t, op):
table = op.table

sa_table = get_sqla_table(ctx, table)
out_expr = get_col_or_deferred_col(sa_table, op.name)
out_expr = get_col(sa_table, op)

if op.output_dtype.is_timestamp():
timezone = op.output_dtype.timezone
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT ancestor_node_sort_order, 1 AS n
SELECT t1.ancestor_node_sort_order, 1 AS n
FROM facts AS t0 JOIN (SELECT t2.ancestor_level_name AS ancestor_level_name, t2.ancestor_level_number AS ancestor_level_number, t2.ancestor_node_sort_order AS ancestor_node_sort_order, t2.descendant_node_natural_key AS descendant_node_natural_key, concat(lpad('-', (t2.ancestor_level_number - 1) * 7, '-'), t2.ancestor_level_name) AS product_level_name
FROM products AS t2) AS t1 ON t0.product_id = t1.descendant_node_natural_key GROUP BY ancestor_node_sort_order ORDER BY ancestor_node_sort_order ASC
FROM products AS t2) AS t1 ON t0.product_id = t1.descendant_node_natural_key GROUP BY t1.ancestor_node_sort_order ORDER BY t1.ancestor_node_sort_order ASC
18 changes: 6 additions & 12 deletions ibis/tests/sql/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,32 +1095,26 @@ def test_tpc_h11(h11):
t4 = supplier.alias("t4")
t1 = (
sa.select(
sa.column("ps_partkey"),
sa.func.sum(sa.column("ps_supplycost") * sa.column("ps_availqty")).label(
"value"
),
t3.c.ps_partkey,
sa.func.sum(t3.c.ps_supplycost * t3.c.ps_availqty).label("value"),
)
.select_from(
t3.join(t4, onclause=t3.c.ps_suppkey == t4.c.s_suppkey).join(
t2, onclause=t2.c.n_nationkey == t4.c.s_nationkey
)
)
.where(sa.column("n_name") == NATION)
.group_by(sa.column("ps_partkey"))
.where(t2.c.n_name == NATION)
.group_by(t3.c.ps_partkey)
).alias("t1")

anon_1 = (
sa.select(
sa.func.sum(sa.column("ps_supplycost") * sa.column("ps_availqty")).label(
"total"
)
)
sa.select(sa.func.sum(t3.c.ps_supplycost * t3.c.ps_availqty).label("total"))
.select_from(
t3.join(t4, onclause=t3.c.ps_suppkey == t4.c.s_suppkey).join(
t2, onclause=t2.c.n_nationkey == t4.c.s_nationkey
)
)
.where(sa.column("n_name") == NATION)
.where(t2.c.n_name == NATION)
.alias("anon_1")
)

Expand Down

0 comments on commit e4011aa

Please sign in to comment.