From 4fdb7079a9fcdc1848829e0235bebe7b333cb6f1 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 26 Aug 2024 13:13:12 -0400 Subject: [PATCH] fix(oracle): implement current_catalog and current_database correctly (#9918) Fixes some hacks around `current_catalog` and `current_database` for the Oracle backend --- ibis/backends/__init__.py | 10 ---------- ibis/backends/datafusion/__init__.py | 8 -------- ibis/backends/oracle/__init__.py | 10 +++++++++- ibis/backends/tests/test_api.py | 6 +----- ibis/backends/tests/test_client.py | 23 ++++++----------------- 5 files changed, 16 insertions(+), 41 deletions(-) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index 5d6ff016ad20..6c8b4d83866e 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -614,11 +614,6 @@ def list_catalogs(self, like: str | None = None) -> list[str]: """ - @property - @abc.abstractmethod - def current_catalog(self) -> str: - """The current catalog in use.""" - class CanCreateCatalog(CanListCatalog): @abc.abstractmethod @@ -705,11 +700,6 @@ def list_databases( """ - @property - @abc.abstractmethod - def current_database(self) -> str: - """The current database in use.""" - class CanCreateDatabase(CanListDatabase): @abc.abstractmethod diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 05ec7d3bafbf..4b88e44865a9 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -235,14 +235,6 @@ def raw_sql(self, query: str | sge.Expression) -> Any: self._log(query) return self.con.sql(query) - @property - def current_catalog(self) -> str: - raise NotImplementedError() - - @property - def current_database(self) -> str: - raise NotImplementedError() - def list_catalogs(self, like: str | None = None) -> list[str]: code = ( sg.select(C.table_catalog) diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index 8e4c46fd51d6..31b636c718d0 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -192,8 +192,16 @@ def _from_url(self, url: ParseResult, **kwargs): return self @property - def current_database(self) -> str: + def current_catalog(self) -> str: with self._safe_raw_sql(sg.select(STAR).from_("global_name")) as cur: + [(catalog,)] = cur.fetchall() + return catalog + + @property + def current_database(self) -> str: + # databases correspond to users, other than that there's + # no notion of a database inside a catalog for oracle + with self._safe_raw_sql(sg.select("user").from_("dual")) as cur: [(database,)] = cur.fetchall() return database diff --git a/ibis/backends/tests/test_api.py b/ibis/backends/tests/test_api.py index 4b71bdb4ffee..ba1c65201e8e 100644 --- a/ibis/backends/tests/test_api.py +++ b/ibis/backends/tests/test_api.py @@ -25,6 +25,7 @@ def test_version(backend): "clickhouse", "sqlite", "dask", + "datafusion", "exasol", "pandas", "druid", @@ -37,11 +38,6 @@ def test_version(backend): reason="backend does not support catalogs", raises=AttributeError, ) -@pytest.mark.notimpl( - ["datafusion"], - raises=NotImplementedError, - reason="current_catalog isn't implemented", -) @pytest.mark.xfail_version(pyspark=["pyspark<3.4"]) def test_catalog_consistency(backend, con): catalogs = con.list_catalogs() diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 432f0be2836b..9dba682e11ed 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1753,24 +1753,13 @@ def test_cross_database_join(con_create_database, monkeypatch): ["impala", "pyspark", "trino"], reason="Default constraints are not supported" ) def test_insert_into_table_missing_columns(con, temp_table): - try: - db = getattr(con, "current_database", None) - except NotImplementedError: - db = None - - # UGH - if con.name == "oracle": - db = None - - try: - catalog = getattr(con, "current_catalog", None) - except NotImplementedError: - catalog = None + db = getattr(con, "current_database", None) - raw_ident = ".".join( - sg.to_identifier(i, quoted=True).sql("duckdb") - for i in filter(None, (catalog, db, temp_table)) - ) + raw_ident = sg.table( + temp_table, + db=db if db is None else sg.to_identifier(db, quoted=True), + quoted=True, + ).sql("duckdb") ct_sql = f'CREATE TABLE {raw_ident} ("a" INT DEFAULT 1, "b" INT)' sg_expr = sg.parse_one(ct_sql, read="duckdb")