diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 117565cf91be..94f8793307f9 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -6,7 +6,6 @@ import contextlib import os import warnings -from functools import partial from pathlib import Path from typing import ( TYPE_CHECKING, @@ -34,7 +33,7 @@ from ibis.formats.pandas import PandasData if TYPE_CHECKING: - from collections.abc import Iterable, Iterator, Mapping, MutableMapping + from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence import pandas as pd import torch @@ -173,6 +172,7 @@ def do_connect( database: str | Path = ":memory:", read_only: bool = False, temp_directory: str | Path | None = None, + extensions: Sequence[str] | None = None, **config: Any, ) -> None: """Create an Ibis client connected to a DuckDB database. @@ -186,6 +186,8 @@ def do_connect( temp_directory Directory to use for spilling to disk. Only set by default for in-memory connections. + extensions + A list of duckdb extensions to install/load upon connection. config DuckDB configuration parameters. See the [DuckDB configuration documentation](https://duckdb.org/docs/sql/configuration) for @@ -222,6 +224,8 @@ def do_connect( @sa.event.listens_for(engine, "connect") def configure_connection(dbapi_connection, connection_record): + if extensions is not None: + self._sa_load_extensions(dbapi_connection, extensions) dbapi_connection.execute("SET TimeZone = 'UTC'") # the progress bar in duckdb <0.8.0 causes kernel crashes in # jupyterlab, fixed in https://github.com/duckdb/duckdb/pull/6831 @@ -237,31 +241,36 @@ def configure_connection(dbapi_connection, connection_record): super().do_connect(engine) - def _load_extensions(self, extensions): - extension_name = sa.column("extension_name") - loaded = sa.column("loaded") - installed = sa.column("installed") - aliases = sa.column("aliases") - query = ( - sa.select(extension_name) - .select_from(sa.func.duckdb_extensions()) - .where( - sa.and_( - # extension isn't loaded or isn't installed - sa.not_(loaded & installed), - # extension is one that we're requesting, or an alias of it - sa.or_( - extension_name.in_(extensions), - *map(partial(sa.func.array_has, aliases), extensions), - ), - ) - ) + @staticmethod + def _sa_load_extensions(dbapi_con, extensions): + query = """ + WITH exts AS ( + SELECT extension_name AS name, aliases FROM duckdb_extensions() + WHERE installed AND loaded ) + SELECT name FROM exts + UNION (SELECT UNNEST(aliases) AS name FROM exts) + """ + installed = (name for (name,) in dbapi_con.sql(query).fetchall()) + # Install and load all other extensions + todo = set(extensions).difference(installed) + for extension in todo: + dbapi_con.install_extension(extension) + dbapi_con.load_extension(extension) + + def _load_extensions(self, extensions): with self.begin() as con: - c = con.connection - for extension in con.execute(query).scalars(): - c.install_extension(extension) - c.load_extension(extension) + self._sa_load_extensions(con.connection, extensions) + + def load_extension(self, extension: str) -> None: + """Install and load a duckdb extension by name or path. + + Parameters + ---------- + extension + The extension name or path. + """ + self._load_extensions([extension]) def create_schema( self, name: str, database: str | None = None, force: bool = False diff --git a/ibis/backends/duckdb/tests/test_client.py b/ibis/backends/duckdb/tests/test_client.py new file mode 100644 index 000000000000..0af8d655e48e --- /dev/null +++ b/ibis/backends/duckdb/tests/test_client.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import duckdb +import pytest +import sqlalchemy as sa + +import ibis +from ibis.conftest import LINUX, SANDBOXED + + +@pytest.mark.xfail( + LINUX and SANDBOXED, + reason="nix on linux cannot download duckdb extensions or data due to sandboxing", + raises=sa.exc.OperationalError, +) +def test_connect_extensions(): + con = ibis.duckdb.connect(extensions=["s3", "sqlite"]) + results = con.raw_sql( + """ + SELECT loaded FROM duckdb_extensions() + WHERE extension_name = 'httpfs' OR extension_name = 'sqlite' + """ + ).fetchall() + assert all(loaded for (loaded,) in results) + + +@pytest.mark.xfail( + LINUX and SANDBOXED, + reason="nix on linux cannot download duckdb extensions or data due to sandboxing", + raises=duckdb.IOException, +) +def test_load_extension(): + con = ibis.duckdb.connect() + con.load_extension("s3") + con.load_extension("sqlite") + results = con.raw_sql( + """ + SELECT loaded FROM duckdb_extensions() + WHERE extension_name = 'httpfs' OR extension_name = 'sqlite' + """ + ).fetchall() + assert all(loaded for (loaded,) in results)