Skip to content

Commit

Permalink
feat(sql): use temp views where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Mar 31, 2023
1 parent eec3706 commit 5b9d8c0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 36 deletions.
58 changes: 24 additions & 34 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def _register_failure(self):
f"please call one of {msg} directly"
)

def _compile_temp_view(self, table_name, source):
raw_source = source.compile(
dialect=self.con.dialect, compile_kwargs=dict(literal_binds=True)
)
return f'CREATE OR REPLACE TEMPORARY VIEW "{table_name}" AS {raw_source}'

@util.experimental
def read_json(
self,
Expand All @@ -263,7 +269,6 @@ def read_json(
Table
An ibis table expression
"""
import sqlalchemy_views as sav
from packaging.version import parse as vparse

if (version := vparse(self.version)) < vparse("0.7.0"):
Expand All @@ -273,18 +278,15 @@ def read_json(
if not table_name:
table_name = f"ibis_read_json_{next(json_n)}"

view = sav.CreateView(
sa.table(table_name),
sa.select(sa.literal_column("*")).select_from(
sa.func.read_json_auto(
sa.func.list_value(*normalize_filenames(source_list)),
_format_kwargs(kwargs),
)
),
or_replace=True,
source = sa.select(sa.literal_column("*")).select_from(
sa.func.read_json_auto(
sa.func.list_value(*normalize_filenames(source_list)),
_format_kwargs(kwargs),
)
)
view = self._compile_temp_view(table_name, source)
with self.begin() as con:
con.execute(view)
con.exec_driver_sql(view)

return self.table(table_name)

Expand Down Expand Up @@ -313,8 +315,6 @@ def read_csv(
ir.Table
The just-registered table
"""
import sqlalchemy_views as sav

source_list = normalize_filenames(source_list)

if not table_name:
Expand All @@ -329,9 +329,10 @@ def read_csv(
source = sa.select(sa.literal_column("*")).select_from(
sa.func.read_csv(sa.func.list_value(*source_list), _format_kwargs(kwargs))
)
view = sav.CreateView(sa.table(table_name), source, or_replace=True)

view = self._compile_temp_view(table_name, source)
with self.begin() as con:
con.execute(view)
con.exec_driver_sql(view)
return self.table(table_name)

def read_parquet(
Expand Down Expand Up @@ -377,13 +378,8 @@ def read_parquet(
return self.table(table_name)

def _read_parquet_duckdb_native(
self,
source_list: str | Iterable[str],
table_name: str,
**kwargs: Any,
self, source_list: str | Iterable[str], table_name: str, **kwargs: Any
) -> None:
import sqlalchemy_views as sav

if any(
source.startswith(("http://", "https://", "s3://"))
for source in source_list
Expand All @@ -395,15 +391,12 @@ def _read_parquet_duckdb_native(
sa.func.list_value(*source_list), _format_kwargs(kwargs)
)
)
view = sav.CreateView(sa.table(table_name), source, or_replace=True)
view = self._compile_temp_view(table_name, source)
with self.begin() as con:
con.execute(view)
con.exec_driver_sql(view)

def _read_parquet_pyarrow_dataset(
self,
source_list: str | Iterable[str],
table_name: str,
**kwargs: Any,
self, source_list: str | Iterable[str], table_name: str, **kwargs: Any
) -> None:
import pyarrow.dataset as ds

Expand Down Expand Up @@ -496,8 +489,6 @@ def read_postgres(self, uri, table_name: str | None = None, schema: str = "publi
ir.Table
The just-registered table.
"""
import sqlalchemy_views as sav

if table_name is None:
raise ValueError(
"`table_name` is required when registering a postgres table"
Expand All @@ -506,9 +497,9 @@ def read_postgres(self, uri, table_name: str | None = None, schema: str = "publi
source = sa.select(sa.literal_column("*")).select_from(
sa.func.postgres_scan_pushdown(uri, schema, table_name)
)
view = sav.CreateView(sa.table(table_name), source, or_replace=True)
view = self._compile_temp_view(table_name, source)
with self.begin() as con:
con.execute(view)
con.exec_driver_sql(view)

return self.table(table_name)

Expand Down Expand Up @@ -540,17 +531,16 @@ def read_sqlite(self, path: str | Path, table_name: str | None = None) -> ir.Tab
3 0.29 Premium I VS2 62.4 58.0 334 4.20 4.23 2.63
4 0.31 Good J SI2 63.3 58.0 335 4.34 4.35 2.75
"""
import sqlalchemy_views as sav

if table_name is None:
raise ValueError("`table_name` is required when registering a sqlite table")
self._load_extensions(["sqlite"])
source = sa.select(sa.literal_column("*")).select_from(
sa.func.sqlite_scan(str(path), table_name)
)
view = sav.CreateView(sa.table(table_name), source, or_replace=True)
view = self._compile_temp_view(table_name, source)
with self.begin() as con:
con.execute(view)
con.exec_driver_sql(view)

return self.table(table_name)

Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,4 +259,4 @@ def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
yield f"CREATE OR REPLACE VIEW {name} AS {definition}"
yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}"
2 changes: 1 addition & 1 deletion ibis/backends/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
) -> str:
yield f"DROP VIEW IF EXISTS {name}"
yield f"CREATE VIEW {name} AS {definition}"
yield f"CREATE TEMPORARY VIEW {name} AS {definition}"

def _get_compiled_statement(self, view: sa.Table, definition: sa.sql.Selectable):
return super()._get_compiled_statement(
Expand Down

0 comments on commit 5b9d8c0

Please sign in to comment.