diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 17dd35418748..af788901ec2b 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -786,10 +786,11 @@ def list_tables( Parameters ---------- - like : str, optional + like A pattern in Python's regex format. - database : str, optional - The database to list tables of, if not the current one. + database + The database from which to list tables. If not provided, the + current database is used. Returns ------- diff --git a/ibis/backends/base/sql/alchemy/__init__.py b/ibis/backends/base/sql/alchemy/__init__.py index 65e43d9764de..76586b89f28c 100644 --- a/ibis/backends/base/sql/alchemy/__init__.py +++ b/ibis/backends/base/sql/alchemy/__init__.py @@ -877,3 +877,83 @@ def drop_view( with self.begin() as con: con.execute(view) + + +class AlchemyCrossSchemaBackend(BaseAlchemyBackend): + """A SQLAlchemy backend that supports cross-schema queries. + + This backend differs from the default SQLAlchemy backend in that it + overrides `_get_sqla_table` to potentially switch schemas during table + reflection, if the table requested lives in a different schema than the + currently active one. + """ + + @property + @abc.abstractmethod + def use_stmt_prefix(self) -> str: + """The prefix to use for switching schemas. + + Common examples are `USE` and `USE SCHEMA`. + """ + + @contextlib.contextmanager + def _use_schema(self, ident: str, current_db: str, current_schema: str) -> None: + use_prefix = self.use_stmt_prefix + + try: + with self.begin() as c: + c.exec_driver_sql(f"{use_prefix} {ident}") + yield + finally: + with self.begin() as c: + c.exec_driver_sql( + f"{use_prefix} {self._quote(current_db)}.{self._quote(current_schema)}" + ) + + def _get_sqla_table( + self, + name: str, + schema: str | None = None, + database: str | None = None, + **_: Any, + ) -> sa.Table: + current_db = self.current_database + current_schema = self.current_schema + if schema is None: + schema = current_schema + *db, schema = schema.split(".") + db = "".join(db) or database or current_db + ident = ".".join(map(self._quote, filter(None, (db, schema)))) + + pairs = self._metadata(f"SELECT * FROM {ident}.{self._quote(name)}") + ibis_schema = ibis.schema(pairs) + + with self._use_schema(ident, current_db, current_schema): + result = self._table_from_schema(name, schema=ibis_schema) + result.schema = schema + return result + + def drop_table( + self, name: str, database: str | None = None, force: bool = False + ) -> None: + name = self._quote(name) + # TODO: handle database quoting + if database is not None: + name = f"{database}.{name}" + drop_stmt = "DROP TABLE" + (" IF EXISTS" * force) + f" {name}" + with self.begin() as con: + con.exec_driver_sql(drop_stmt) + + +@compiles(sa.Table, "snowflake") +def compile_table(element, compiler, **kw): + """Override compilation of leaf tables. + + The override is necessary because the dialect does not handle database + hierarchies and/or quoting properly. + """ + schema = element.schema + name = compiler.preparer.quote_identifier(element.name) + if schema is not None: + return f"{schema}.{name}" + return name diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 0c97b13d10ef..136e0d512e80 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -30,8 +30,8 @@ from ibis.backends.base.sql.alchemy import ( AlchemyCanCreateSchema, AlchemyCompiler, + AlchemyCrossSchemaBackend, AlchemyExprTranslator, - BaseAlchemyBackend, ) from ibis.backends.snowflake.converter import SnowflakePandasData from ibis.backends.snowflake.datatypes import SnowflakeType, parse @@ -108,11 +108,12 @@ class SnowflakeCompiler(AlchemyCompiler): } -class Backend(BaseAlchemyBackend, CanCreateDatabase, AlchemyCanCreateSchema): +class Backend(AlchemyCrossSchemaBackend, CanCreateDatabase, AlchemyCanCreateSchema): name = "snowflake" compiler = SnowflakeCompiler supports_create_or_replace = True supports_python_udfs = True + use_stmt_prefix = "USE SCHEMA" _latest_udf_python_version = (3, 10) @@ -400,43 +401,6 @@ def _make_batch_iter( for t in cur.cursor.fetch_arrow_batches() ) - @contextlib.contextmanager - def _use_schema(self, ident: str, fallback: str): - if ident == fallback: - yield - else: - try: - with self.begin() as c: - c.exec_driver_sql(f"USE SCHEMA {ident}") - yield - finally: - with self.begin() as c: - c.exec_driver_sql(f"USE SCHEMA {fallback}") - - def _get_sqla_table( - self, - name: str, - schema: str | None = None, - database: str | None = None, - **_: Any, - ) -> sa.Table: - current_db = self.current_database - current_schema = self.current_schema - if schema is None: - schema = current_schema - *db, schema = schema.split(".") - db = "".join(db) or database or current_db - ident = ".".join(map(self._quote, filter(None, (db, schema)))) - - pairs = self._metadata(f"SELECT * FROM {ident}.{self._quote(name)}") - ibis_schema = ibis.schema(pairs) - - fallback = f"{self._quote(current_db)}.{self._quote(current_schema)}" - with self._use_schema(ident, fallback=fallback): - result = self._table_from_schema(name, schema=ibis_schema) - result.schema = ident - return result - def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]: with self.begin() as con, con.connection.cursor() as cur: result = cur.describe(query) @@ -855,17 +819,3 @@ def compile_join(element, compiler, **kw): if element.right._is_lateral: return re.sub(r"^(.+) ON true$", r"\1", result, flags=re.IGNORECASE | re.DOTALL) return result - - -@compiles(sa.Table, "snowflake") -def compile_table(element, compiler, **kw): - """Override compilation of leaf tables. - - The override is necessary because snowflake-sqlalchemy does not handle - quoting databases and schemas correctly. - """ - schema = element.schema - name = compiler.preparer.quote_identifier(element.name) - if schema is not None: - return f"{schema}.{name}" - return name diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index 08c1135fdf88..9abc2770a6b7 100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -19,7 +19,10 @@ import ibis.expr.types as ir from ibis import util from ibis.backends.base import CanListDatabases -from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend +from ibis.backends.base.sql.alchemy import ( + AlchemyCanCreateSchema, + AlchemyCrossSchemaBackend, +) from ibis.backends.base.sql.alchemy.datatypes import ArrayType from ibis.backends.trino.compiler import TrinoSQLCompiler from ibis.backends.trino.datatypes import ROW, TrinoType, parse @@ -32,11 +35,12 @@ import ibis.expr.schema as sch -class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema, CanListDatabases): +class Backend(AlchemyCrossSchemaBackend, AlchemyCanCreateSchema, CanListDatabases): name = "trino" compiler = TrinoSQLCompiler supports_create_or_replace = False supports_temporary_tables = False + use_stmt_prefix = "USE" @cached_property def version(self) -> str: @@ -68,6 +72,19 @@ def list_schemas( def current_schema(self) -> str: return self._scalar_query(sa.select(sa.literal_column("current_schema"))) + def list_tables( + self, like: str | None = None, database: str | None = None + ) -> list[str]: + query = "SHOW TABLES" + + if database is not None: + query += f" IN {database}" + + with self.begin() as con: + tables = list(con.exec_driver_sql(query).scalars()) + + return self._filter_with_like(tables, like=like) + def do_connect( self, user: str = "user", @@ -330,3 +347,19 @@ def literal_compile(v): ) return self.table(orig_table_ref) + + def _table_from_schema( + self, + name: str, + schema: sch.Schema, + temp: bool = False, + database: str | None = None, + **kwargs: Any, + ) -> sa.Table: + return super()._table_from_schema( + name, + schema, + temp=temp, + trino_catalog=database or self.current_database, + **kwargs, + )