Skip to content

Commit

Permalink
feat(duckdb): expose loading extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Sep 7, 2023
1 parent 95f2f38 commit 2feecf7
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 25 deletions.
59 changes: 34 additions & 25 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import contextlib
import os
import warnings
from functools import partial
from pathlib import Path
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -34,7 +33,7 @@
from ibis.formats.pandas import PandasData

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, MutableMapping
from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence

import pandas as pd
import torch
Expand Down Expand Up @@ -173,6 +172,7 @@ def do_connect(
database: str | Path = ":memory:",
read_only: bool = False,
temp_directory: str | Path | None = None,
extensions: Sequence[str] | None = None,
**config: Any,
) -> None:
"""Create an Ibis client connected to a DuckDB database.
Expand All @@ -186,6 +186,8 @@ def do_connect(
temp_directory
Directory to use for spilling to disk. Only set by default for
in-memory connections.
extensions
A list of duckdb extensions to install/load upon connection.
config
DuckDB configuration parameters. See the [DuckDB configuration
documentation](https://duckdb.org/docs/sql/configuration) for
Expand Down Expand Up @@ -222,6 +224,8 @@ def do_connect(

@sa.event.listens_for(engine, "connect")
def configure_connection(dbapi_connection, connection_record):
if extensions is not None:
self._sa_load_extensions(dbapi_connection, extensions)
dbapi_connection.execute("SET TimeZone = 'UTC'")
# the progress bar in duckdb <0.8.0 causes kernel crashes in
# jupyterlab, fixed in https://github.com/duckdb/duckdb/pull/6831
Expand All @@ -237,31 +241,36 @@ def configure_connection(dbapi_connection, connection_record):

super().do_connect(engine)

def _load_extensions(self, extensions):
extension_name = sa.column("extension_name")
loaded = sa.column("loaded")
installed = sa.column("installed")
aliases = sa.column("aliases")
query = (
sa.select(extension_name)
.select_from(sa.func.duckdb_extensions())
.where(
sa.and_(
# extension isn't loaded or isn't installed
sa.not_(loaded & installed),
# extension is one that we're requesting, or an alias of it
sa.or_(
extension_name.in_(extensions),
*map(partial(sa.func.array_has, aliases), extensions),
),
)
)
@staticmethod
def _sa_load_extensions(dbapi_con, extensions):
    """Install and load the requested extensions that are not already active.

    Parameters
    ----------
    dbapi_con
        Connection object exposing ``sql``, ``install_extension`` and
        ``load_extension``.
    extensions
        Names of the extensions to make available. A single string is
        treated as one extension name, not as a sequence of characters.
        Duplicates are ignored and the remaining names are processed in the
        order they were given.
    """
    # A bare string is itself an iterable of characters; wrap it so that
    # "sqlite" is not turned into requests for "s", "q", "l", ...
    if isinstance(extensions, str):
        extensions = [extensions]

    # Names, plus aliases, of every extension that is both installed and loaded.
    query = """
WITH exts AS (
SELECT extension_name AS name, aliases FROM duckdb_extensions()
WHERE installed AND loaded
)
SELECT name FROM exts
UNION (SELECT UNNEST(aliases) AS name FROM exts)
"""
    active = {name for (name,) in dbapi_con.sql(query).fetchall()}

    # dict.fromkeys de-duplicates while keeping the caller's order, so the
    # install/load sequence is deterministic (plain set iteration is not).
    for extension in dict.fromkeys(extensions):
        if extension in active:
            continue
        dbapi_con.install_extension(extension)
        dbapi_con.load_extension(extension)

def _load_extensions(self, extensions):
with self.begin() as con:
c = con.connection
for extension in con.execute(query).scalars():
c.install_extension(extension)
c.load_extension(extension)
self._sa_load_extensions(con.connection, extensions)

def load_extension(self, extension: str) -> None:
    """Install and load one duckdb extension, given its name or path.

    Parameters
    ----------
    extension
        Name of the extension, or a path to it.
    """
    # Delegate to the multi-extension helper with a one-element list.
    requested = [extension]
    self._load_extensions(requested)

def create_schema(
self, name: str, database: str | None = None, force: bool = False
Expand Down
42 changes: 42 additions & 0 deletions ibis/backends/duckdb/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

import duckdb
import pytest
import sqlalchemy as sa

import ibis
from ibis.conftest import LINUX, SANDBOXED


@pytest.mark.xfail(
    LINUX and SANDBOXED,
    reason="nix on linux cannot download duckdb extensions or data due to sandboxing",
    raises=sa.exc.OperationalError,
)
def test_connect_extensions():
    """Extensions passed to ``connect`` are reported as loaded afterwards."""
    backend = ibis.duckdb.connect(extensions=["s3", "sqlite"])
    # "s3" is requested but "httpfs" is checked -- presumably "s3" is an alias
    # of the httpfs extension; confirm against duckdb_extensions().
    status_sql = """
SELECT loaded FROM duckdb_extensions()
WHERE extension_name = 'httpfs' OR extension_name = 'sqlite'
"""
    rows = backend.raw_sql(status_sql).fetchall()
    loaded_flags = [is_loaded for (is_loaded,) in rows]
    assert all(loaded_flags)


@pytest.mark.xfail(
    LINUX and SANDBOXED,
    reason="nix on linux cannot download duckdb extensions or data due to sandboxing",
    raises=duckdb.IOException,
)
def test_load_extension():
    """Extensions loaded one at a time on an open connection show as loaded."""
    backend = ibis.duckdb.connect()
    for requested in ("s3", "sqlite"):
        backend.load_extension(requested)
    # "s3" is requested but "httpfs" is checked -- presumably "s3" is an alias
    # of the httpfs extension; confirm against duckdb_extensions().
    status_sql = """
SELECT loaded FROM duckdb_extensions()
WHERE extension_name = 'httpfs' OR extension_name = 'sqlite'
"""
    rows = backend.raw_sql(status_sql).fetchall()
    loaded_flags = [is_loaded for (is_loaded,) in rows]
    assert all(loaded_flags)

0 comments on commit 2feecf7

Please sign in to comment.