diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 7e2dbd46b1e3..f5419eaea469 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -864,6 +864,10 @@ def connect(self, *args, **kwargs) -> BaseBackend: new_backend.reconnect() return new_backend + @abc.abstractmethod + def disconnect(self) -> None: + """Close the connection to the backend.""" + @staticmethod def _convert_kwargs(kwargs: MutableMapping) -> None: """Manipulate keyword arguments to `.connect` method.""" diff --git a/ibis/backends/base/sqlglot/__init__.py b/ibis/backends/base/sqlglot/__init__.py index 9380ba8c8916..b4c1ef71b869 100644 --- a/ibis/backends/base/sqlglot/__init__.py +++ b/ibis/backends/base/sqlglot/__init__.py @@ -385,3 +385,8 @@ def truncate_table( ).sql(self.dialect) with self._safe_raw_sql(f"TRUNCATE TABLE {ident}"): pass + + def disconnect(self): + # This is part of the Python DB-API specification so should work for + # _most_ sqlglot backends + self.con.close() diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 43fe70af5c8a..a7d8e20d520e 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -447,6 +447,9 @@ def do_connect( self.partition_column = partition_column + def disconnect(self) -> None: + self.client.close() + def _parse_project_and_dataset(self, dataset) -> tuple[str, str]: if not dataset and not self.dataset: raise ValueError("Unable to determine BigQuery dataset.") diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index 9e6cde739f1b..43f490a07f4d 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -446,7 +446,7 @@ def raw_sql( self._log(query) return self.con.query(query, external_data=external_data, **kwargs) - def close(self) -> None: + def disconnect(self) -> None: """Close ClickHouse connection.""" self.con.close() diff --git a/ibis/backends/dask/__init__.py b/ibis/backends/dask/__init__.py index df1de12c2359..67ee86b89dea 100644 --- a/ibis/backends/dask/__init__.py +++ b/ibis/backends/dask/__init__.py @@ -61,6 +61,9 @@ def do_connect( ) super().do_connect(dictionary) + def disconnect(self) -> None: + pass + @property def version(self): return dask.__version__ diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index 44946ac3abf5..ee251cee9f39 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -94,6 +94,9 @@ def do_connect( for name, path in config.items(): self.register(path, table_name=name) + def disconnect(self) -> None: + pass + @contextlib.contextmanager def _safe_raw_sql(self, sql: sge.Statement) -> Any: yield self.raw_sql(sql).collect() diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index 1d631ddca52d..7e1af1bbeb1c 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -67,6 +67,9 @@ def do_connect(self, table_env: TableEnvironment) -> None: """ self._table_env = table_env + def disconnect(self) -> None: + pass + def raw_sql(self, query: str) -> TableResult: return self._table_env.execute_sql(query) diff --git a/ibis/backends/pandas/__init__.py b/ibis/backends/pandas/__init__.py index 4d3cddb665c0..fa89b26af606 100644 --- a/ibis/backends/pandas/__init__.py +++ b/ibis/backends/pandas/__init__.py @@ -53,6 +53,9 @@ def do_connect( self.dictionary = dictionary or {} self.schemas: MutableMapping[str, sch.Schema] = {} + def disconnect(self) -> None: + pass + def from_dataframe( self, df: pd.DataFrame, diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py index b03331fe80d5..694494275964 100644 --- a/ibis/backends/polars/__init__.py +++ b/ibis/backends/polars/__init__.py @@ -59,6 +59,9 @@ def do_connect( for name, table in (tables or {}).items(): self._add_table(name, table) + def disconnect(self) -> None: + pass + @property def version(self) -> str: return pl.__version__ diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index e77839fa939f..3507636e1cc2 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -158,6 +158,9 @@ def do_connect(self, session: SparkSession) -> None: self._session.conf.set("spark.sql.session.timeZone", "UTC") self._session.conf.set("spark.sql.mapKeyDedupPolicy", "LAST_WIN") + def disconnect(self) -> None: + self._session.stop() + def _metadata(self, query: str): cursor = self.raw_sql(query) struct_dtype = PySparkType.to_ibis(cursor.query.schema) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 62ffb72c0ccf..1ebc504be679 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -1460,3 +1460,28 @@ def test_list_databases_schemas(con_create_database_schema): con_create_database_schema.drop_schema(schema, database=database) finally: con_create_database_schema.drop_database(database) + + +@pytest.mark.notyet( + ["pandas", "dask", "polars", "datafusion"], + reason="this is a no-op for in-memory backends", +) +@pytest.mark.notyet( + ["trino", "clickhouse", "impala", "bigquery", "flink"], + reason="Backend client does not conform to DB-API, subsequent op does not raise", +) +@pytest.mark.skip() +def test_close_connection(con): + if con.name == "pyspark": + # It would be great if there were a simple way to say "give me a new + # spark context" but I haven't found it. + pytest.skip("Closing spark context breaks subsequent tests") + new_con = getattr(ibis, con.name).connect(*con._con_args, **con._con_kwargs) + + # Run any command that hits the backend + _ = new_con.list_tables() + new_con.disconnect() + + # DB-API states that subsequent execution attempt should raise + with pytest.raises(Exception): # noqa:B017 + new_con.list_tables() diff --git a/ibis/tests/expr/mocks.py b/ibis/tests/expr/mocks.py index 7a83adfe7060..8c6347428b21 100644 --- a/ibis/tests/expr/mocks.py +++ b/ibis/tests/expr/mocks.py @@ -39,6 +39,9 @@ def __init__(self): def do_connect(self): pass + def disconnect(self): + pass + def table(self, name, **kwargs): schema = self.get_schema(name) node = ops.DatabaseTable(source=self, name=name, schema=schema)