diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 0a023e4d433c..a749db306eed 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -21,6 +21,7 @@ from sqlalchemy.ext.compiler import compiles import ibis +import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir @@ -505,12 +506,23 @@ def _get_temp_view_definition( yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}" def create_database(self, name: str, force: bool = False) -> None: + current_database = self.current_database name = self._quote(name) if_not_exists = "IF NOT EXISTS " * force with self.begin() as con: con.exec_driver_sql(f"CREATE DATABASE {if_not_exists}{name}") + # Snowflake automatically switches to the new database after creating + # it per + # https://docs.snowflake.com/en/sql-reference/sql/create-database#general-usage-notes + # so we switch back to the original database + con.exec_driver_sql(f"USE DATABASE {self._quote(current_database)}") def drop_database(self, name: str, force: bool = False) -> None: + current_database = self.current_database + if name == current_database: + raise com.UnsupportedOperationError( + "Dropping the current database is not supported because its behavior is undefined" + ) name = self._quote(name) if_exists = "IF EXISTS " * force with self.begin() as con: @@ -521,12 +533,28 @@ def create_schema( ) -> None: name = ".".join(map(self._quote, filter(None, [database, name]))) if_not_exists = "IF NOT EXISTS " * force + current_database = self.current_database + current_schema = self.current_schema with self.begin() as con: con.exec_driver_sql(f"CREATE SCHEMA {if_not_exists}{name}") + # Snowflake automatically switches to the new schema after creating + # it per + # https://docs.snowflake.com/en/sql-reference/sql/create-schema#usage-notes + # so we switch back to the original schema + con.exec_driver_sql( + f"USE SCHEMA {self._quote(current_database)}.{self._quote(current_schema)}" + ) def drop_schema( self, name: str, database: str | None = None, force: bool = False ) -> None: + if self.current_schema == name and ( + database is None or self.current_database == database + ): + raise com.UnsupportedOperationError( + "Dropping the current schema is not supported because its behavior is undefined" + ) + name = ".".join(map(self._quote, filter(None, [database, name]))) if_exists = "IF EXISTS " * force with self.begin() as con: diff --git a/ibis/backends/snowflake/tests/test_client.py b/ibis/backends/snowflake/tests/test_client.py index 0d358a04e4b5..b6a466c228e8 100644 --- a/ibis/backends/snowflake/tests/test_client.py +++ b/ibis/backends/snowflake/tests/test_client.py @@ -6,6 +6,7 @@ import pytest import ibis +import ibis.common.exceptions as com from ibis.backends.snowflake.tests.conftest import _get_url from ibis.util import gen_name @@ -81,3 +82,81 @@ def test_timestamp_tz_column(simple_con): ).mutate(ts=lambda t: t.ts.to_timestamp("YYYY-MM-DD HH24-MI-SS")) expr = t.ts assert expr.execute().empty + + +def test_create_schema(simple_con): + schema = gen_name("test_create_schema") + + cur_schema = simple_con.current_schema + cur_db = simple_con.current_database + + simple_con.create_schema(schema) + + assert simple_con.current_schema == cur_schema + assert simple_con.current_database == cur_db + + simple_con.drop_schema(schema) + + assert simple_con.current_schema == cur_schema + assert simple_con.current_database == cur_db + + +def test_create_database(simple_con): + database = gen_name("test_create_database") + cur_db = simple_con.current_database + + simple_con.create_database(database) + assert simple_con.current_database == cur_db + + simple_con.drop_database(database) + assert simple_con.current_database == cur_db + + +@pytest.fixture(scope="session") +def db_con(): + return ibis.connect(_get_url()) + + +@pytest.fixture(scope="session") +def schema_con(): + return ibis.connect(_get_url()) + + +def test_drop_current_db_not_allowed(db_con): + database = gen_name("test_create_database") + cur_db = db_con.current_database + + db_con.create_database(database) + + assert db_con.current_database == cur_db + + with db_con.begin() as c: + c.exec_driver_sql(f'USE DATABASE "{database}"') + + with pytest.raises(com.UnsupportedOperationError, match="behavior is undefined"): + db_con.drop_database(database) + + with db_con.begin() as c: + c.exec_driver_sql(f"USE DATABASE {cur_db}") + + db_con.drop_database(database) + + +def test_drop_current_schema_not_allowed(schema_con): + schema = gen_name("test_create_schema") + cur_schema = schema_con.current_schema + + schema_con.create_schema(schema) + + assert schema_con.current_schema == cur_schema + + with schema_con.begin() as c: + c.exec_driver_sql(f'USE SCHEMA "{schema}"') + + with pytest.raises(com.UnsupportedOperationError, match="behavior is undefined"): + schema_con.drop_schema(schema) + + with schema_con.begin() as c: + c.exec_driver_sql(f"USE SCHEMA {cur_schema}") + + schema_con.drop_schema(schema)