Skip to content

Commit

Permalink
fix: support self joins on memtables
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Oct 12, 2023
1 parent f2ae7cc commit f24e355
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 54 deletions.
74 changes: 37 additions & 37 deletions ibis/backends/base/sql/alchemy/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,79 +82,79 @@ def _get_join_type(self, op):

def _format_table(self, op):
ctx = self.context
ref_op = op

orig_op = op
if isinstance(op, ops.SelfReference):
ref_op = op.table
op = op.table

alias = ctx.get_ref(op)
alias = ctx.get_ref(orig_op)

translator = ctx.compiler.translator_class(ref_op, ctx)
translator = ctx.compiler.translator_class(op, ctx)

if isinstance(ref_op, ops.DatabaseTable):
result = ref_op.source._get_sqla_table(ref_op.name, schema=ref_op.namespace)
elif isinstance(ref_op, ops.UnboundTable):
if isinstance(op, ops.DatabaseTable):
result = op.source._get_sqla_table(op.name, schema=op.namespace)
elif isinstance(op, ops.UnboundTable):
# use SQLAlchemy's TableClause for unbound tables
result = sa.Table(
ref_op.name,
op.name,
sa.MetaData(),
*translator._schema_to_sqlalchemy_columns(ref_op.schema),
*translator._schema_to_sqlalchemy_columns(op.schema),
quote=translator._quote_table_names,
)
elif isinstance(ref_op, ops.SQLQueryResult):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
result = sa.text(ref_op.query).columns(*columns)
elif isinstance(ref_op, ops.SQLStringView):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
result = sa.text(ref_op.query).columns(*columns).cte(ref_op.name)
elif isinstance(ref_op, ops.View):
elif isinstance(op, ops.SQLQueryResult):
columns = translator._schema_to_sqlalchemy_columns(op.schema)
result = sa.text(op.query).columns(*columns)
elif isinstance(op, ops.SQLStringView):
columns = translator._schema_to_sqlalchemy_columns(op.schema)
result = sa.text(op.query).columns(*columns).cte(op.name)
elif isinstance(op, ops.View):
# TODO(kszucs): avoid converting to expression
child_expr = ref_op.child.to_expr()
child_expr = op.child.to_expr()
definition = child_expr.compile()
result = sa.Table(
ref_op.name,
op.name,
sa.MetaData(),
*translator._schema_to_sqlalchemy_columns(ref_op.schema),
*translator._schema_to_sqlalchemy_columns(op.schema),
quote=translator._quote_table_names,
)
backend = child_expr._find_backend()
backend._create_temp_view(view=result, definition=definition)
elif isinstance(ref_op, ops.InMemoryTable):
result = self._format_in_memory_table(op, ref_op, translator)
elif isinstance(ref_op, ops.DummyTable):
elif isinstance(op, ops.InMemoryTable):
result = self._format_in_memory_table(op, translator)
elif isinstance(op, ops.DummyTable):
result = sa.select(
*(
translator.translate(value).label(name)
for name, value in zip(ref_op.schema.names, ref_op.values)
for name, value in zip(op.schema.names, op.values)
)
)
else:
# A subquery
if ctx.is_extracted(ref_op):
if ctx.is_extracted(op):
# Was put elsewhere, e.g. WITH block, we just need to grab
# its alias
alias = ctx.get_ref(op)
alias = ctx.get_ref(orig_op)

# hack
if isinstance(op, ops.SelfReference):
table = ctx.get_ref(ref_op)
if isinstance(orig_op, ops.SelfReference):
table = ctx.get_ref(op)
self_ref = alias if hasattr(alias, "name") else table.alias(alias)
ctx.set_ref(op, self_ref)
ctx.set_ref(orig_op, self_ref)
return self_ref
return alias

alias = ctx.get_ref(op)
result = ctx.get_compiled_expr(op)
alias = ctx.get_ref(orig_op)
result = ctx.get_compiled_expr(orig_op)

result = alias if hasattr(alias, "name") else result.alias(alias)
ctx.set_ref(op, result)
ctx.set_ref(orig_op, result)
return result

def _format_in_memory_table(self, op, ref_op, translator):
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
def _format_in_memory_table(self, op, translator):
columns = translator._schema_to_sqlalchemy_columns(op.schema)
if self.context.compiler.cheap_in_memory_tables:
result = sa.Table(
ref_op.name,
op.name,
sa.MetaData(),
*columns,
quote=translator._quote_table_names,
Expand All @@ -167,8 +167,8 @@ def _format_in_memory_table(self, op, ref_op, translator):
)
).limit(0)
elif self.context.compiler.support_values_syntax_in_select:
rows = list(ref_op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=ref_op.name).data(rows)
rows = list(op.data.to_frame().itertuples(index=False))
result = sa.values(*columns, name=op.name).data(rows)
else:
raw_rows = (
sa.select(
Expand All @@ -179,7 +179,7 @@ def _format_in_memory_table(self, op, ref_op, translator):
)
for row in op.data.to_frame().itertuples(index=False)
)
result = sa.union_all(*raw_rows).alias(ref_op.name)
result = sa.union_all(*raw_rows).alias(op.name)
return result


Expand Down
26 changes: 13 additions & 13 deletions ibis/backends/base/sql/compiler/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,41 +98,41 @@ def _format_table(self, op):
# TODO: This could probably go in a class and be significantly nicer
ctx = self.context

ref_op = op
orig_op = op
if isinstance(op, ops.SelfReference):
ref_op = op.table
op = op.table

if isinstance(ref_op, ops.InMemoryTable):
result = self._format_in_memory_table(ref_op)
elif isinstance(ref_op, ops.PhysicalTable):
if isinstance(op, ops.InMemoryTable):
result = self._format_in_memory_table(op)
elif isinstance(op, ops.PhysicalTable):
# TODO(kszucs): add a mandatory `name` field to the base
# PhyisicalTable instead of the child classes, this should prevent
# this error scenario
if (name := ref_op.name) is None:
if (name := op.name) is None:
raise com.RelationError(f"Table did not have a name: {op!r}")

result = sg.table(
name,
db=getattr(ref_op, "namespace", None),
db=getattr(op, "namespace", None),
quoted=self.parent.translator_class._quote_identifiers,
).sql(dialect=self.parent.translator_class._dialect_name)
else:
# A subquery
if ctx.is_extracted(ref_op):
if ctx.is_extracted(op):
# Was put elsewhere, e.g. WITH block, we just need to grab its
# alias
alias = ctx.get_ref(op)
alias = ctx.get_ref(orig_op)

# HACK: self-references have to be treated more carefully here
if isinstance(op, ops.SelfReference):
return f"{ctx.get_ref(ref_op)} {alias}"
if isinstance(orig_op, ops.SelfReference):
return f"{ctx.get_ref(op)} {alias}"
else:
return alias

subquery = ctx.get_compiled_expr(op)
subquery = ctx.get_compiled_expr(orig_op)
result = f"(\n{util.indent(subquery, self.indent)}\n)"

result += f" {ctx.get_ref(op)}"
result += f" {ctx.get_ref(orig_op)}"

return result

Expand Down
15 changes: 15 additions & 0 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,21 @@ def test_agg_memory_table(con):
assert result == 3


@pytest.mark.broken(
["polars"], reason="join column renaming is currently incorrect on polars"
)
@pytest.mark.notimpl(["datafusion"])
def test_self_join_memory_table(backend, con):
t = ibis.memtable({"x": [1, 2], "y": [2, 1], "z": ["a", "b"]})
t_view = t.view()
expr = t.join(t_view, t.x == t_view.y).select("x", "y", "z", "z_right")
result = con.execute(expr).sort_values("x").reset_index(drop=True)
expected = pd.DataFrame(
{"x": [1, 2], "y": [2, 1], "z": ["a", "b"], "z_right": ["b", "a"]}
)
backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"t",
[
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _rewrite_string_contains(op):


class TrinoTableSetFormatter(_AlchemyTableSetFormatter):
def _format_in_memory_table(self, op, ref_op, translator):
def _format_in_memory_table(self, op, translator):
if not op.data:
return sa.select(
*(
Expand All @@ -64,10 +64,10 @@ def _format_in_memory_table(self, op, ref_op, translator):
translator.translate(ops.Literal(col, dtype=type_)).label(name)
for col, (name, type_) in zip(row, op_schema)
)
for row in ref_op.data.to_frame().itertuples(index=False)
for row in op.data.to_frame().itertuples(index=False)
]
columns = translator._schema_to_sqlalchemy_columns(ref_op.schema)
return sa.values(*columns, name=ref_op.name).data(rows)
columns = translator._schema_to_sqlalchemy_columns(op.schema)
return sa.values(*columns, name=op.name).data(rows)


class TrinoSQLCompiler(AlchemyCompiler):
Expand Down

0 comments on commit f24e355

Please sign in to comment.