diff --git a/ci/schema/snowflake.sql b/ci/schema/snowflake.sql index 291e36a29afc..93638382d33f 100644 --- a/ci/schema/snowflake.sql +++ b/ci/schema/snowflake.sql @@ -1,10 +1,10 @@ -CREATE OR REPLACE FILE FORMAT ibis_testing +CREATE OR REPLACE TEMP FILE FORMAT ibis_testing type = 'CSV' field_delimiter = ',' skip_header = 1 field_optionally_enclosed_by = '"'; -CREATE OR REPLACE STAGE ibis_testing file_format = ibis_testing; +CREATE OR REPLACE TEMP STAGE ibis_testing file_format = ibis_testing; CREATE OR REPLACE TABLE diamonds ( "carat" FLOAT, diff --git a/ibis/backends/base/sql/alchemy/__init__.py b/ibis/backends/base/sql/alchemy/__init__.py index 30f06a288516..ca21cadfc3d5 100644 --- a/ibis/backends/base/sql/alchemy/__init__.py +++ b/ibis/backends/base/sql/alchemy/__init__.py @@ -478,16 +478,8 @@ def table( Table expression """ if database is not None and database != self.current_database: - return self.database(database=database).table( - name=name, - database=database, - schema=schema, - ) - sqla_table = self._get_sqla_table( - name, - database=database, - schema=schema, - ) + return self.database(name=database).table(name=name, schema=schema) + sqla_table = self._get_sqla_table(name, database=database, schema=schema) return self._sqla_table_to_expr(sqla_table) def insert( diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 2c3ce3894dce..19d0a7a4099b 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -177,9 +177,22 @@ def connect(dbapi_connection, connection_record): def _get_sqla_table( self, name: str, schema: str | None = None, **_: Any ) -> sa.Table: - inspected = self.inspector.get_columns(name, schema) + with self.begin() as con: + cur_db, cur_schema = con.exec_driver_sql( + "SELECT CURRENT_DATABASE(), CURRENT_SCHEMA()" + ).fetchone() + if schema is not None: + con.exec_driver_sql(f"USE {schema}") + try: + inspected = self.inspector.get_columns( + name, + schema=schema.split(".", 2)[1] if schema is not None else schema, + ) + finally: + with self.begin() as con: + con.exec_driver_sql(f"USE {cur_db}.{cur_schema}") cols = [] - identifier = name if not schema else schema + "." + name + identifier = name if schema is None else schema + "." + name with self.begin() as con: cur = con.exec_driver_sql(f"DESCRIBE TABLE {identifier}").mappings() for colname, colinfo in zip(toolz.pluck("name", cur), inspected): diff --git a/ibis/backends/snowflake/tests/conftest.py b/ibis/backends/snowflake/tests/conftest.py index 14cf682f1a2f..bfd388b7b250 100644 --- a/ibis/backends/snowflake/tests/conftest.py +++ b/ibis/backends/snowflake/tests/conftest.py @@ -88,3 +88,8 @@ def connect(data_directory: Path) -> BaseBackend: if snowflake_url := os.environ.get("SNOWFLAKE_URL"): return ibis.connect(snowflake_url) # type: ignore pytest.skip("SNOWFLAKE_URL environment variable is not defined") + + +@pytest.fixture(scope="session") +def con(data_directory): + return TestConf.connect(data_directory) diff --git a/ibis/backends/snowflake/tests/test_client.py b/ibis/backends/snowflake/tests/test_client.py new file mode 100644 index 000000000000..56ddb79b7e76 --- /dev/null +++ b/ibis/backends/snowflake/tests/test_client.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pytest + +import ibis +from ibis.util import guid + + +@pytest.mark.parametrize( + ("db", "schema", "cleanup"), + [ + (f"tmp_db_{guid()}", f"tmp_schema_{guid()}", True), + ("ibis_testing", f"tmp_schema_{guid()}", False), + ], + ids=["temp", "perm"], +) +def test_cross_db_access(con, db, schema, cleanup): + schema = f"{db}.{schema}" + table = f"tmp_table_{guid()}" + con.raw_sql(f"CREATE DATABASE IF NOT EXISTS {db}") + try: + con.raw_sql(f"CREATE SCHEMA IF NOT EXISTS {schema}") + try: + con.raw_sql(f'CREATE TEMP TABLE {schema}.{table} ("x" INT)') + t = con.table(table, schema=f"{schema}") + assert t.schema() == ibis.schema(dict(x="int")) + assert t.execute().empty + finally: + if cleanup: + con.raw_sql(f"DROP SCHEMA {schema}") + finally: + if cleanup: + con.raw_sql(f"DROP DATABASE {db}")