Skip to content

Commit

Permalink
refactor(sql): remove temporary table creation when using inline sql (#…
Browse files Browse the repository at this point in the history
…8149)

This PR fixes a long-standing annoyance with our `.sql` methods.
Previously they left behind a bunch of temporary tables or views
(depending on the backend); after this PR the various `.sql` methods
replace that hack with plain CTEs.

---------

Co-authored-by: Krisztián Szűcs <szucs.krisztian@gmail.com>
  • Loading branch information
cpcloud and kszucs committed Feb 12, 2024
1 parent 4d24502 commit ea428ba
Show file tree
Hide file tree
Showing 35 changed files with 354 additions and 253 deletions.
35 changes: 14 additions & 21 deletions ibis/backends/base/sqlglot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,20 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
"""Return an ibis Schema from a backend-specific SQL string."""
return sch.Schema.from_tuples(self._metadata(query))

def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema:
    """Return the schema of `query` with `table` exposed to it as CTE `name`.

    The table expression is compiled to sqlglot, any ``WITH`` clause it
    carries is moved onto the parsed user query, and the compiled table
    itself is then attached as a CTE called `name`. The resulting SQL
    string is handed to ``_get_schema_using_query``.
    """
    sql_compiler = self.compiler
    sql_dialect = sql_compiler.dialect

    table_select = self._to_sqlglot(table)
    statement = sg.parse_one(query, read=sql_dialect)

    # hoist the table's own CTEs onto the outer statement; note that this
    # assignment replaces whatever "with" arg the parsed query already had
    statement.args["with"] = table_select.args.pop("with", [])

    view_ident = sg.to_identifier(name, quoted=sql_compiler.quoted)
    statement = statement.with_(view_ident, as_=table_select, dialect=sql_dialect)

    return self._get_schema_using_query(statement.sql(sql_dialect))

def create_view(
self,
name: str,
Expand Down Expand Up @@ -195,27 +209,6 @@ def drop_view(
with self._safe_raw_sql(src):
pass

def _get_temp_view_definition(self, name: str, definition: str) -> str:
    """Build the statement that (re)creates temporary view `name`.

    Returns a sqlglot ``Create`` expression (not a rendered string, despite
    the annotation) of kind ``VIEW`` with ``replace=True`` and a
    ``TemporaryProperty``, whose body is `definition`.
    """
    view_ident = sg.to_identifier(name, quoted=self.compiler.quoted)
    temporary = sge.Properties(expressions=[sge.TemporaryProperty()])
    return sge.Create(
        this=view_ident,
        kind="VIEW",
        expression=definition,
        replace=True,
        properties=temporary,
    )

def _create_temp_view(self, table_name, source):
    """Create temporary view `table_name` from `source` and track it.

    Raises
    ------
    ValueError
        If `table_name` is not one of the views tracked in
        ``self._temp_views`` but is reported by ``self.list_tables()``.
    """
    already_tracked = table_name in self._temp_views
    # only consult the backend when the name is not one of our own views
    if not already_tracked and table_name in self.list_tables():
        raise ValueError(
            f"{table_name} already exists as a non-temporary table or view"
        )

    create_stmt = self._get_temp_view_definition(table_name, source)
    with self._safe_raw_sql(create_stmt):
        pass

    self._temp_views.add(table_name)
    self._register_temp_view_cleanup(table_name)

def _register_temp_view_cleanup(self, name: str) -> None:
"""Register a clean up function for a temporary view.
Expand Down
58 changes: 33 additions & 25 deletions ibis/backends/base/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,28 +261,36 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
op, ctes = sqlize(op)

aliases = {}
alias_counter = itertools.count()
counter = itertools.count()

def fn(node, _, **kwargs):
result = self.visit_node(node, **kwargs)

if node is op:
return result
elif isinstance(node, ops.Relation):
aliases[node] = alias = f"t{next(alias_counter)}"
alias = sg.to_identifier(alias, quoted=self.quoted)
try:
return result.subquery(alias)
except AttributeError:
return result.as_(alias, quoted=self.quoted)
else:
# if it's not a relation then we don't need to do anything special
if node is op or not isinstance(node, ops.Relation):
return result

# alias ops.Views to their explicitly assigned name otherwise generate
alias = node.name if isinstance(node, ops.View) else f"t{next(counter)}"
aliases[node] = alias

alias = sg.to_identifier(alias, quoted=self.quoted)
try:
return result.subquery(alias)
except AttributeError:
return result.as_(alias, quoted=self.quoted)

# apply translate rules in topological order
results = op.map(fn)

# get the root node as a sqlglot select statement
out = results[op]
out = out.this if isinstance(out, sge.Subquery) else out
if isinstance(out, sge.Table):
out = sg.select(STAR).from_(out)
elif isinstance(out, sge.Subquery):
out = out.this

# add cte definitions to the select statement
for cte in ctes:
alias = sg.to_identifier(aliases[cte], quoted=self.quoted)
out = out.with_(alias, as_=results[cte].this, dialect=self.dialect)
Expand Down Expand Up @@ -1222,27 +1230,27 @@ def visit_FillNa(self, op, *, parent, replacements):
}
return sg.select(*self._cleanup_names(exprs)).from_(parent)

@visit_node.register(ops.View)
def visit_View(self, op, *, child, name: str):
    """Compile a View by creating a temporary view and referencing it by name."""
    # TODO: find a way to do this without creating a temporary view
    backend = op.child.to_expr()._find_backend()
    select_all = sg.select(STAR).from_(child)
    backend._create_temp_view(table_name=name, source=select_all)
    return sg.table(name, quoted=self.quoted)

@visit_node.register(CTE)
def visit_CTE(self, op, *, parent):
    """Compile a CTE reference to a table named after the parent's alias or name."""
    cte_name = parent.alias_or_name
    return sg.table(cte_name, quoted=self.quoted)

@visit_node.register(ops.View)
def visit_View(self, op, *, child, name: str):
    """Compile a View to its child expression aliased as `name`.

    A bare table child is first wrapped in ``SELECT *`` so that it can be
    turned into a subquery.
    """
    relation = sg.select(STAR).from_(child) if isinstance(child, sge.Table) else child

    try:
        aliased = relation.subquery(name)
    except AttributeError:
        # expression has no ``subquery`` method: fall back to a plain alias
        aliased = relation.as_(name)
    return aliased

@visit_node.register(ops.SQLStringView)
def visit_SQLStringView(self, op, *, query: str, name: str, child):
table = sg.table(name, quoted=self.quoted)
return (
sg.select(STAR).from_(table).with_(table, as_=query, dialect=self.dialect)
)
def visit_SQLStringView(self, op, *, query: str, child, schema):
return sg.parse_one(query, read=self.dialect)

@visit_node.register(ops.SQLQueryResult)
def visit_SQLQueryResult(self, op, *, query, schema, source):
return sg.parse_one(query, read=self.dialect).subquery()
return sg.parse_one(query, dialect=self.dialect).subquery()

@visit_node.register(ops.JoinTable)
def visit_JoinTable(self, op, *, parent, index):
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/base/sqlglot/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def extract_ctes(node):

g = Graph.from_bfs(node, filter=(ops.Relation, ops.Subquery, ops.JoinLink))
for node, dependents in g.invert().items():
if len(dependents) > 1 and isinstance(node, cte_types):
if isinstance(node, ops.View) or (
len(dependents) > 1 and isinstance(node, cte_types)
):
result.append(node)

return result
Expand Down
18 changes: 0 additions & 18 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import ast
import atexit
import contextlib
import glob
from contextlib import closing
Expand Down Expand Up @@ -166,7 +165,6 @@ def do_connect(
compress=compression,
**kwargs,
)
self._temp_views = set()

@property
def version(self) -> str:
Expand Down Expand Up @@ -726,19 +724,3 @@ def create_view(
with self._safe_raw_sql(src, external_tables=external_tables):
pass
return self.table(name, database=database)

def _get_temp_view_definition(self, name: str, definition: str) -> str:
    """Build the statement that creates or replaces view `name`.

    Returns a sqlglot ``Create`` expression of kind ``VIEW`` with
    ``replace=True``; no temporary property is attached in this variant.
    """
    view_ident = sg.to_identifier(name, quoted=self.compiler.quoted)
    return sge.Create(
        this=view_ident, kind="VIEW", expression=definition, replace=True
    )

def _register_temp_view_cleanup(self, name: str) -> None:
    """Register an ``atexit`` hook that drops view `name` and untracks it."""
    drop_stmt = sge.Drop(this=sg.table(name), kind="VIEW", exists=True)

    def _drop_view(backend, view_name: str, statement):
        # run the DROP first, then forget the view
        backend.raw_sql(statement)
        backend._temp_views.discard(view_name)

    atexit.register(_drop_view, self, view_name=name, statement=drop_stmt)
2 changes: 0 additions & 2 deletions ibis/backends/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ def do_connect(
for name, path in config.items():
self.register(path, table_name=name)

self._temp_views = set()

@contextlib.contextmanager
def _safe_raw_sql(self, sql: sge.Statement) -> Any:
yield self.raw_sql(sql)
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/druid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def do_connect(self, **kwargs: Any) -> None:
"""Create an Ibis client using the passed connection parameters."""
header = kwargs.pop("header", True)
self.con = pydruid.db.connect(**kwargs, header=header)
self._temp_views = set()

@contextlib.contextmanager
def _safe_raw_sql(self, query, *args, **kwargs):
Expand Down
13 changes: 13 additions & 0 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,3 +1534,16 @@ def insert(
table_name,
obj if isinstance(obj, pd.DataFrame) else pd.DataFrame(obj),
)

def _get_temp_view_definition(self, name: str, definition: str) -> str:
    """Build the statement that (re)creates temporary view `name`.

    Returns a sqlglot ``Create`` expression (not a rendered string, despite
    the annotation) of kind ``VIEW`` with ``replace=True`` and a
    ``TemporaryProperty``, whose body is `definition`.
    """
    temp_props = sge.Properties(expressions=[sge.TemporaryProperty()])
    target = sg.to_identifier(name, quoted=self.compiler.quoted)
    return sge.Create(
        this=target,
        kind="VIEW",
        expression=definition,
        replace=True,
        properties=temp_props,
    )

def _create_temp_view(self, table_name, source):
    """Execute the temporary-view definition for `table_name` built from `source`."""
    statement = self._get_temp_view_definition(table_name, source)
    with self._safe_raw_sql(statement):
        pass
3 changes: 1 addition & 2 deletions ibis/backends/exasol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from ibis.backends.base import BaseBackend

# strip trailing encodings e.g., UTF8
_VARCHAR_REGEX = re.compile(r"^(VARCHAR(?:\(\d+\)))?(?:\s+.+)?$")
_VARCHAR_REGEX = re.compile(r"^((VAR)?CHAR(?:\(\d+\)))?(?:\s+.+)?$")


class Backend(SQLGlotBackend):
Expand Down Expand Up @@ -90,7 +90,6 @@ def do_connect(
quote_ident=True,
**kwargs,
)
self._temp_views = set()

def _from_url(self, url: str, **kwargs) -> BaseBackend:
"""Construct an ibis backend from a SQLAlchemy-conforming URL."""
Expand Down
17 changes: 0 additions & 17 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import atexit
import contextlib
import datetime
import struct
Expand Down Expand Up @@ -92,7 +91,6 @@ def do_connect(
cur.execute("SET DATEFIRST 1")

self.con = con
self._temp_views = set()

def get_schema(
self, name: str, schema: str | None = None, database: str | None = None
Expand Down Expand Up @@ -244,13 +242,6 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
con.commit()
return cursor

def _get_temp_view_definition(self, name: str, definition) -> str:
    """Build a sqlglot ``Create`` expression of kind ``OR ALTER VIEW`` for `name`."""
    view_ident = sg.to_identifier(name, quoted=self.compiler.quoted)
    return sge.Create(kind="OR ALTER VIEW", this=view_ident, expression=definition)

def create_database(self, name: str, force: bool = False) -> None:
name = self._quote(name)
create_stmt = (
Expand Down Expand Up @@ -462,14 +453,6 @@ def create_table(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
).to_expr()

def _register_temp_view_cleanup(self, name: str) -> None:
    """Register an ``atexit`` hook that drops view `name` and untracks it."""
    drop_stmt = sge.Drop(this=sg.table(name), kind="VIEW", exists=True)

    def _drop_view(backend, view_name: str, statement):
        # run the DROP first, then forget the view
        backend.raw_sql(statement)
        backend._temp_views.discard(view_name)

    atexit.register(_drop_view, self, view_name=name, statement=drop_stmt)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
Expand Down
4 changes: 0 additions & 4 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,6 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
return self.f.dateadd(self.v.s, arg / 1_000, "1970-01-01 00:00:00")
raise com.UnsupportedOperationError(f"{unit!r} unit is not supported!")

@visit_node.register(ops.SQLStringView)
def visit_SQLStringView(self, op, *, query: str, name: str, child):
    """Parse `query` in the compiler's dialect and wrap it as subquery `name`."""
    parsed = sg.parse_one(query, read=self.dialect)
    return parsed.subquery(name)

def visit_NonNullLiteral(self, op, *, value, dtype):
if dtype.is_decimal():
return self.cast(str(value.normalize()), dtype)
Expand Down
18 changes: 0 additions & 18 deletions ibis/backends/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import atexit
import contextlib
import re
import warnings
Expand Down Expand Up @@ -174,7 +173,6 @@ def do_connect(
warnings.warn(f"Unable to set session timezone to UTC: {e}")

self.con = con
self._temp_views = set()

@property
def current_database(self) -> str:
Expand Down Expand Up @@ -222,14 +220,6 @@ def get_schema(

return sch.Schema(fields)

def _get_temp_view_definition(self, name: str, definition: str) -> str:
    """Build the statement that creates or replaces view `name`.

    Returns a sqlglot ``Create`` expression of kind ``VIEW`` with
    ``replace=True``; no temporary property is attached in this variant.
    """
    view_ident = sg.to_identifier(name, quoted=self.compiler.quoted)
    return sge.Create(
        kind="VIEW", replace=True, this=view_ident, expression=definition
    )

def create_database(self, name: str, force: bool = False) -> None:
sql = sge.Create(kind="DATABASE", exist=force, this=sg.to_identifier(name)).sql(
self.name
Expand Down Expand Up @@ -509,11 +499,3 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
raise
df = MySQLPandasData.convert_table(df, schema)
return df

def _register_temp_view_cleanup(self, name: str) -> None:
    """Register an ``atexit`` hook that drops view `name` and untracks it."""
    drop_stmt = sge.Drop(this=sg.table(name), kind="VIEW", exists=True)

    def _drop_view(backend, view_name: str, statement):
        # run the DROP first, then forget the view
        backend.raw_sql(statement)
        backend._temp_views.discard(view_name)

    atexit.register(_drop_view, self, view_name=name, statement=drop_stmt)
9 changes: 0 additions & 9 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,6 @@ def do_connect(
with self.begin() as cur:
cur.execute("SET TIMEZONE = UTC")

self._temp_views = set()

def list_tables(
self,
like: str | None = None,
Expand Down Expand Up @@ -552,13 +550,6 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
with self._safe_raw_sql(drop_stmt):
pass

def _get_temp_view_definition(self, name: str, definition):
    """Return ``DROP VIEW ...; <create>`` text for temporary view `name`.

    The drop statement (``exists=True``, ``cascade=True``) is rendered with
    ``self.name`` as the dialect; the parent's create definition is
    interpolated through its string form.
    """
    drop_expr = sge.Drop(kind="VIEW", exists=True, this=sg.table(name), cascade=True)
    drop_sql = drop_expr.sql(self.name)
    create_expr = super()._get_temp_view_definition(name, definition)
    return f"{drop_sql}; {create_expr}"

def create_schema(
self, name: str, database: str | None = None, force: bool = False
) -> None:
Expand Down
9 changes: 0 additions & 9 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import atexit
import contextlib
import os
from pathlib import Path
Expand Down Expand Up @@ -156,7 +155,6 @@ def do_connect(self, session: SparkSession) -> None:
# https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html#timestamp-with-time-zone-semantics
self._session.conf.set("spark.sql.session.timeZone", "UTC")
self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN")
self._temp_views = set()

def _metadata(self, query: str):
cursor = self.raw_sql(query)
Expand Down Expand Up @@ -235,13 +233,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
df = self._session.createDataFrame(data=op.data.to_frame(), schema=schema)
df.createOrReplaceTempView(op.name)

def _register_temp_view_cleanup(self, name: str) -> None:
    """Register an ``atexit`` hook that drops temp view `name` from the session catalog."""

    def _drop_view(backend, view_name: str):
        # remove the view from the session catalog, then stop tracking it
        backend._session.catalog.dropTempView(view_name)
        backend._temp_views.discard(view_name)

    atexit.register(_drop_view, self, view_name=name)

def _fetch_from_cursor(self, cursor, schema):
    """Materialize the cursor's query as pandas and convert it to `schema`."""
    # toPandas() blocks until finished
    frame = cursor.query.toPandas()
    return PySparkPandasData.convert_table(frame, schema)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
WITH foo AS (
SELECT
*
FROM `ibis-gbq`.ibis_gbq_testing.test_bigquery_temp_mem_t_for_cte AS t0
)
SELECT
COUNT(*) AS `x`
FROM `foo`
Loading

0 comments on commit ea428ba

Please sign in to comment.