diff --git a/ibis/backends/base/sql/alchemy/query_builder.py b/ibis/backends/base/sql/alchemy/query_builder.py index 3155aa4d2a21..f8cac84ca81d 100644 --- a/ibis/backends/base/sql/alchemy/query_builder.py +++ b/ibis/backends/base/sql/alchemy/query_builder.py @@ -82,79 +82,79 @@ def _get_join_type(self, op): def _format_table(self, op): ctx = self.context - ref_op = op + orig_op = op if isinstance(op, ops.SelfReference): - ref_op = op.table + op = op.table - alias = ctx.get_ref(op) + alias = ctx.get_ref(orig_op) - translator = ctx.compiler.translator_class(ref_op, ctx) + translator = ctx.compiler.translator_class(op, ctx) - if isinstance(ref_op, ops.DatabaseTable): - result = ref_op.source._get_sqla_table(ref_op.name, schema=ref_op.namespace) - elif isinstance(ref_op, ops.UnboundTable): + if isinstance(op, ops.DatabaseTable): + result = op.source._get_sqla_table(op.name, schema=op.namespace) + elif isinstance(op, ops.UnboundTable): # use SQLAlchemy's TableClause for unbound tables result = sa.Table( - ref_op.name, + op.name, sa.MetaData(), - *translator._schema_to_sqlalchemy_columns(ref_op.schema), + *translator._schema_to_sqlalchemy_columns(op.schema), quote=translator._quote_table_names, ) - elif isinstance(ref_op, ops.SQLQueryResult): - columns = translator._schema_to_sqlalchemy_columns(ref_op.schema) - result = sa.text(ref_op.query).columns(*columns) - elif isinstance(ref_op, ops.SQLStringView): - columns = translator._schema_to_sqlalchemy_columns(ref_op.schema) - result = sa.text(ref_op.query).columns(*columns).cte(ref_op.name) - elif isinstance(ref_op, ops.View): + elif isinstance(op, ops.SQLQueryResult): + columns = translator._schema_to_sqlalchemy_columns(op.schema) + result = sa.text(op.query).columns(*columns) + elif isinstance(op, ops.SQLStringView): + columns = translator._schema_to_sqlalchemy_columns(op.schema) + result = sa.text(op.query).columns(*columns).cte(op.name) + elif isinstance(op, ops.View): # TODO(kszucs): avoid converting to expression - child_expr = ref_op.child.to_expr() + child_expr = op.child.to_expr() definition = child_expr.compile() result = sa.Table( - ref_op.name, + op.name, sa.MetaData(), - *translator._schema_to_sqlalchemy_columns(ref_op.schema), + *translator._schema_to_sqlalchemy_columns(op.schema), quote=translator._quote_table_names, ) backend = child_expr._find_backend() backend._create_temp_view(view=result, definition=definition) - elif isinstance(ref_op, ops.InMemoryTable): - result = self._format_in_memory_table(op, ref_op, translator) - elif isinstance(ref_op, ops.DummyTable): + elif isinstance(op, ops.InMemoryTable): + result = self._format_in_memory_table(op, translator) + elif isinstance(op, ops.DummyTable): result = sa.select( *( translator.translate(value).label(name) - for name, value in zip(ref_op.schema.names, ref_op.values) + for name, value in zip(op.schema.names, op.values) ) ) else: # A subquery - if ctx.is_extracted(ref_op): + if ctx.is_extracted(op): # Was put elsewhere, e.g. WITH block, we just need to grab # its alias - alias = ctx.get_ref(op) + alias = ctx.get_ref(orig_op) # hack - if isinstance(op, ops.SelfReference): - table = ctx.get_ref(ref_op) + if isinstance(orig_op, ops.SelfReference): + table = ctx.get_ref(op) self_ref = alias if hasattr(alias, "name") else table.alias(alias) - ctx.set_ref(op, self_ref) + ctx.set_ref(orig_op, self_ref) return self_ref return alias - alias = ctx.get_ref(op) - result = ctx.get_compiled_expr(op) + alias = ctx.get_ref(orig_op) + result = ctx.get_compiled_expr(orig_op) result = alias if hasattr(alias, "name") else result.alias(alias) - ctx.set_ref(op, result) + ctx.set_ref(orig_op, result) return result - def _format_in_memory_table(self, op, ref_op, translator): - columns = translator._schema_to_sqlalchemy_columns(ref_op.schema) + def _format_in_memory_table(self, op, translator): + columns = translator._schema_to_sqlalchemy_columns(op.schema) if self.context.compiler.cheap_in_memory_tables: result = sa.Table( - ref_op.name, + op.name, sa.MetaData(), *columns, quote=translator._quote_table_names, @@ -167,8 +167,8 @@ def _format_in_memory_table(self, op, ref_op, translator): ) ).limit(0) elif self.context.compiler.support_values_syntax_in_select: - rows = list(ref_op.data.to_frame().itertuples(index=False)) - result = sa.values(*columns, name=ref_op.name).data(rows) + rows = list(op.data.to_frame().itertuples(index=False)) + result = sa.values(*columns, name=op.name).data(rows) else: raw_rows = ( sa.select( @@ -179,7 +179,7 @@ def _format_in_memory_table(self, op, ref_op, translator): ) for row in op.data.to_frame().itertuples(index=False) ) - result = sa.union_all(*raw_rows).alias(ref_op.name) + result = sa.union_all(*raw_rows).alias(op.name) return result diff --git a/ibis/backends/base/sql/compiler/query_builder.py b/ibis/backends/base/sql/compiler/query_builder.py index 7fcc9738117d..b989e22e4ca3 100644 --- a/ibis/backends/base/sql/compiler/query_builder.py +++ b/ibis/backends/base/sql/compiler/query_builder.py @@ -98,41 +98,41 @@ def _format_table(self, op): # TODO: This could probably go in a class and be significantly nicer ctx = self.context - ref_op = op + orig_op = op if isinstance(op, ops.SelfReference): - ref_op = op.table + op = op.table - if isinstance(ref_op, ops.InMemoryTable): - result = self._format_in_memory_table(ref_op) - elif isinstance(ref_op, ops.PhysicalTable): + if isinstance(op, ops.InMemoryTable): + result = self._format_in_memory_table(op) + elif isinstance(op, ops.PhysicalTable): # TODO(kszucs): add a mandatory `name` field to the base # PhyisicalTable instead of the child classes, this should prevent # this error scenario - if (name := ref_op.name) is None: + if (name := op.name) is None: raise com.RelationError(f"Table did not have a name: {op!r}") result = sg.table( name, - db=getattr(ref_op, "namespace", None), + db=getattr(op, "namespace", None), quoted=self.parent.translator_class._quote_identifiers, ).sql(dialect=self.parent.translator_class._dialect_name) else: # A subquery - if ctx.is_extracted(ref_op): + if ctx.is_extracted(op): # Was put elsewhere, e.g. WITH block, we just need to grab its # alias - alias = ctx.get_ref(op) + alias = ctx.get_ref(orig_op) # HACK: self-references have to be treated more carefully here - if isinstance(op, ops.SelfReference): - return f"{ctx.get_ref(ref_op)} {alias}" + if isinstance(orig_op, ops.SelfReference): + return f"{ctx.get_ref(op)} {alias}" else: return alias - subquery = ctx.get_compiled_expr(op) + subquery = ctx.get_compiled_expr(orig_op) result = f"(\n{util.indent(subquery, self.indent)}\n)" - result += f" {ctx.get_ref(op)}" + result += f" {ctx.get_ref(orig_op)}" return result diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index df7031ea0c0f..652f0ea3537e 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -814,6 +814,21 @@ def test_agg_memory_table(con): assert result == 3 +@pytest.mark.broken( + ["polars"], reason="join column renaming is currently incorrect on polars" +) +@pytest.mark.notimpl(["datafusion"]) +def test_self_join_memory_table(backend, con): + t = ibis.memtable({"x": [1, 2], "y": [2, 1], "z": ["a", "b"]}) + t_view = t.view() + expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right") + result = con.execute(expr).sort_values("x").reset_index(drop=True) + expected = pd.DataFrame( + {"x": [1, 2], "y": [2, 1], "z": ["a", "b"], "z_right": ["b", "a"]} + ) + backend.assert_frame_equal(result, expected) + + @pytest.mark.parametrize( "t", [ diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py index b484419273c4..a850a22bd463 100644 --- a/ibis/backends/trino/compiler.py +++ b/ibis/backends/trino/compiler.py @@ -49,7 +49,7 @@ def _rewrite_string_contains(op): class TrinoTableSetFormatter(_AlchemyTableSetFormatter): - def _format_in_memory_table(self, op, ref_op, translator): + def _format_in_memory_table(self, op, translator): if not op.data: return sa.select( *( @@ -64,10 +64,10 @@ def _format_in_memory_table(self, op, ref_op, translator): translator.translate(ops.Literal(col, dtype=type_)).label(name) for col, (name, type_) in zip(row, op_schema) ) - for row in ref_op.data.to_frame().itertuples(index=False) + for row in op.data.to_frame().itertuples(index=False) ] - columns = translator._schema_to_sqlalchemy_columns(ref_op.schema) - return sa.values(*columns, name=ref_op.name).data(rows) + columns = translator._schema_to_sqlalchemy_columns(op.schema) + return sa.values(*columns, name=op.name).data(rows) class TrinoSQLCompiler(AlchemyCompiler):