Skip to content

Commit

Permalink
feat(trino): cross-schema table support
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 17, 2023
1 parent ae3e76e commit 9c7c65f
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 58 deletions.
7 changes: 4 additions & 3 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,10 +786,11 @@ def list_tables(
Parameters
----------
like : str, optional
like
A pattern in Python's regex format.
database : str, optional
The database to list tables of, if not the current one.
database
The database from which to list tables. If not provided, the
current database is used.
Returns
-------
Expand Down
80 changes: 80 additions & 0 deletions ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,3 +877,83 @@ def drop_view(

with self.begin() as con:
con.execute(view)


class AlchemyCrossSchemaBackend(BaseAlchemyBackend):
    """SQLAlchemy backend that can reference tables outside the active schema.

    Compared with the default SQLAlchemy backend, `_get_sqla_table` is
    overridden so that the table is built while the connection is temporarily
    pointed at the schema the requested table lives in; the previously active
    database and schema are switched back to afterwards.
    """

    @property
    @abc.abstractmethod
    def use_stmt_prefix(self) -> str:
        """Return the statement prefix used to switch schemas.

        Typical values are `USE` and `USE SCHEMA`.
        """

    # NOTE(review): annotated `-> None`, but this is a generator wrapped by
    # `contextlib.contextmanager`; an iterator/context-manager annotation
    # would describe it more accurately.
    @contextlib.contextmanager
    def _use_schema(self, ident: str, current_db: str, current_schema: str) -> None:
        """Temporarily switch the connection to `ident`.

        Parameters
        ----------
        ident
            Already-quoted identifier to switch to; it is interpolated into
            the statement verbatim.
        current_db
            Database to switch back to on exit; quoted with `self._quote`.
        current_schema
            Schema to switch back to on exit; quoted with `self._quote`.
        """
        use_prefix = self.use_stmt_prefix

        try:
            with self.begin() as con:
                con.exec_driver_sql(f"{use_prefix} {ident}")
            yield
        finally:
            # The switch back runs even if the first statement or the body of
            # the `with` block raised, because both are inside the `try`.
            with self.begin() as con:
                con.exec_driver_sql(
                    f"{use_prefix} {self._quote(current_db)}.{self._quote(current_schema)}"
                )

    def _get_sqla_table(
        self,
        name: str,
        schema: str | None = None,
        database: str | None = None,
        **_: Any,
    ) -> sa.Table:
        """Build the SQLAlchemy table for `name`, possibly in another schema.

        Parameters
        ----------
        name
            Unqualified table name.
        schema
            Schema name, optionally prefixed with a database as
            `database.schema`. Defaults to the current schema.
        database
            Database to use when `schema` does not carry one. Defaults to the
            current database.

        Returns
        -------
        sa.Table
            The table produced by `_table_from_schema`, with its `schema`
            attribute set to the schema name.
        """
        current_db = self.current_database
        current_schema = self.current_schema

        qualified = current_schema if schema is None else schema
        # Everything before the last dot is taken as the database; several
        # leading parts are concatenated without a separator.
        *db_parts, schema = qualified.split(".")
        db = "".join(db_parts) or database or current_db
        ident = ".".join(self._quote(part) for part in (db, schema) if part)

        # Column names and types come from describing a query on the fully
        # qualified table.
        columns = self._metadata(f"SELECT * FROM {ident}.{self._quote(name)}")
        ibis_schema = ibis.schema(columns)

        with self._use_schema(ident, current_db, current_schema):
            table = self._table_from_schema(name, schema=ibis_schema)
        # NOTE(review): only the bare schema name is stored here, not the
        # quoted `ident` -- confirm each dialect's table compilation restores
        # the database part and quoting as needed.
        table.schema = schema
        return table

    def drop_table(
        self, name: str, database: str | None = None, force: bool = False
    ) -> None:
        """Drop the table `name`.

        Parameters
        ----------
        name
            Table name; quoted with `self._quote`.
        database
            Optional qualifier prepended to the quoted name as-is (unquoted).
        force
            When true, `IF EXISTS` is added to the statement.
        """
        name = self._quote(name)
        # TODO: handle database quoting
        if database is not None:
            name = f"{database}.{name}"

        # Multiplying by the boolean yields the clause or an empty string.
        if_exists = " IF EXISTS" * force
        drop_stmt = "DROP TABLE" + if_exists + f" {name}"

        with self.begin() as con:
            con.exec_driver_sql(drop_stmt)


@compiles(sa.Table, "snowflake")
def compile_table(element, compiler, **kw):
    """Compile a leaf table reference for the snowflake dialect.

    The table name is quoted with the compiler's identifier preparer. When the
    table carries a schema, that schema string is prepended verbatim (no
    additional quoting is applied here), separated by a dot.

    This override exists because the dialect does not handle database
    hierarchies and/or quoting properly.
    """
    schema = element.schema
    name = compiler.preparer.quote_identifier(element.name)
    if schema is None:
        return name
    return f"{schema}.{name}"
56 changes: 3 additions & 53 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from ibis.backends.base.sql.alchemy import (
AlchemyCanCreateSchema,
AlchemyCompiler,
AlchemyCrossSchemaBackend,
AlchemyExprTranslator,
BaseAlchemyBackend,
)
from ibis.backends.snowflake.converter import SnowflakePandasData
from ibis.backends.snowflake.datatypes import SnowflakeType, parse
Expand Down Expand Up @@ -108,11 +108,12 @@ class SnowflakeCompiler(AlchemyCompiler):
}


class Backend(BaseAlchemyBackend, CanCreateDatabase, AlchemyCanCreateSchema):
class Backend(AlchemyCrossSchemaBackend, CanCreateDatabase, AlchemyCanCreateSchema):
name = "snowflake"
compiler = SnowflakeCompiler
supports_create_or_replace = True
supports_python_udfs = True
use_stmt_prefix = "USE SCHEMA"

_latest_udf_python_version = (3, 10)

Expand Down Expand Up @@ -400,43 +401,6 @@ def _make_batch_iter(
for t in cur.cursor.fetch_arrow_batches()
)

@contextlib.contextmanager
def _use_schema(self, ident: str, fallback: str):
    """Temporarily run with `ident` as the active schema.

    Parameters
    ----------
    ident
        Identifier to switch to; interpolated into `USE SCHEMA` verbatim.
    fallback
        Identifier to switch back to on exit; also interpolated verbatim.
    """
    if ident == fallback:
        # Target and fallback are the same string: no statements are issued.
        yield
        return

    try:
        with self.begin() as con:
            con.exec_driver_sql(f"USE SCHEMA {ident}")
        yield
    finally:
        # Runs whether or not the switch or the managed body raised.
        with self.begin() as con:
            con.exec_driver_sql(f"USE SCHEMA {fallback}")

def _get_sqla_table(
    self,
    name: str,
    schema: str | None = None,
    database: str | None = None,
    **_: Any,
) -> sa.Table:
    """Build the SQLAlchemy table for `name`, possibly in another schema.

    Parameters
    ----------
    name
        Unqualified table name.
    schema
        Schema name, optionally prefixed with a database as
        `database.schema`. Defaults to the current schema.
    database
        Database to use when `schema` does not carry one. Defaults to the
        current database.

    Returns
    -------
    sa.Table
        The table produced by `_table_from_schema`, with its `schema`
        attribute set to the quoted `database.schema` identifier.
    """
    current_db = self.current_database
    current_schema = self.current_schema

    qualified = current_schema if schema is None else schema
    # Everything before the last dot is taken as the database; several
    # leading parts are concatenated without a separator.
    *db_parts, schema_name = qualified.split(".")
    db = "".join(db_parts) or database or current_db
    ident = ".".join(self._quote(part) for part in (db, schema_name) if part)

    # Column names and types come from describing a query on the fully
    # qualified table.
    columns = self._metadata(f"SELECT * FROM {ident}.{self._quote(name)}")
    ibis_schema = ibis.schema(columns)

    fallback = f"{self._quote(current_db)}.{self._quote(current_schema)}"
    with self._use_schema(ident, fallback=fallback):
        table = self._table_from_schema(name, schema=ibis_schema)
    table.schema = ident
    return table

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
with self.begin() as con, con.connection.cursor() as cur:
result = cur.describe(query)
Expand Down Expand Up @@ -855,17 +819,3 @@ def compile_join(element, compiler, **kw):
if element.right._is_lateral:
return re.sub(r"^(.+) ON true$", r"\1", result, flags=re.IGNORECASE | re.DOTALL)
return result


@compiles(sa.Table, "snowflake")
def compile_table(element, compiler, **kw):
    """Compile a leaf table reference for the snowflake dialect.

    The table name is quoted with the compiler's identifier preparer. When the
    table carries a schema, that schema string is prepended verbatim (no
    additional quoting is applied here), separated by a dot.

    This override exists because snowflake-sqlalchemy does not handle quoting
    databases and schemas correctly.
    """
    schema = element.schema
    name = compiler.preparer.quote_identifier(element.name)
    if schema is None:
        return name
    return f"{schema}.{name}"
37 changes: 35 additions & 2 deletions ibis/backends/trino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base import CanListDatabases
from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend
from ibis.backends.base.sql.alchemy import (
AlchemyCanCreateSchema,
AlchemyCrossSchemaBackend,
)
from ibis.backends.base.sql.alchemy.datatypes import ArrayType
from ibis.backends.trino.compiler import TrinoSQLCompiler
from ibis.backends.trino.datatypes import ROW, TrinoType, parse
Expand All @@ -32,11 +35,12 @@
import ibis.expr.schema as sch


class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema, CanListDatabases):
class Backend(AlchemyCrossSchemaBackend, AlchemyCanCreateSchema, CanListDatabases):
name = "trino"
compiler = TrinoSQLCompiler
supports_create_or_replace = False
supports_temporary_tables = False
use_stmt_prefix = "USE"

@cached_property
def version(self) -> str:
Expand Down Expand Up @@ -68,6 +72,19 @@ def list_schemas(
def current_schema(self) -> str:
    """Return the scalar result of selecting `current_schema`."""
    schema_column = sa.literal_column("current_schema")
    return self._scalar_query(sa.select(schema_column))

def list_tables(
    self, like: str | None = None, database: str | None = None
) -> list[str]:
    """List table names, optionally from a specific database.

    Parameters
    ----------
    like
        Pattern forwarded to `_filter_with_like` to narrow the result.
    database
        Appended to `SHOW TABLES` after `IN`, exactly as given (unquoted).
        When omitted, a plain `SHOW TABLES` is issued.

    Returns
    -------
    list[str]
        The table names after filtering.
    """
    pieces = ["SHOW TABLES"]
    if database is not None:
        pieces.append(f" IN {database}")
    query = "".join(pieces)

    with self.begin() as con:
        # `.scalars()` takes the first column of each row.
        tables = list(con.exec_driver_sql(query).scalars())

    return self._filter_with_like(tables, like=like)

def do_connect(
self,
user: str = "user",
Expand Down Expand Up @@ -330,3 +347,19 @@ def literal_compile(v):
)

return self.table(orig_table_ref)

def _table_from_schema(
    self,
    name: str,
    schema: sch.Schema,
    temp: bool = False,
    database: str | None = None,
    **kwargs: Any,
) -> sa.Table:
    """Delegate to the parent implementation, adding a `trino_catalog`.

    Parameters
    ----------
    name
        Table name, forwarded unchanged.
    schema
        Ibis schema of the table, forwarded unchanged.
    temp
        Forwarded unchanged as the `temp` keyword.
    database
        Passed on as `trino_catalog`; the current database is used when this
        is `None` or empty.
    **kwargs
        Any further keyword arguments, forwarded unchanged.
    """
    catalog = database or self.current_database
    return super()._table_from_schema(
        name, schema, temp=temp, trino_catalog=catalog, **kwargs
    )

0 comments on commit 9c7c65f

Please sign in to comment.