Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add motherduck support for duckdb plugin #2680

Merged
merged 17 commits into from
Sep 3, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
DuckDBQuery
"""

from .task import DuckDBQuery
from .task import DuckDBProvider, DuckDBQuery
60 changes: 54 additions & 6 deletions plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from enum import Enum
from functools import partial
from typing import Dict, List, NamedTuple, Optional, Union

from flytekit import PythonInstanceTask, lazy_module
from flytekit import PythonInstanceTask, Secret, current_context, lazy_module
from flytekit.extend import Interface
from flytekit.types.structured.structured_dataset import StructuredDataset

Expand All @@ -10,6 +12,26 @@
pa = lazy_module("pyarrow")


def connect_local(hosted_secret: Optional[Secret]):
    """
    Open a connection to a local DuckDB database.

    Args:
        hosted_secret: Not used for local connections. It is accepted so that
            this function can be called with the same single argument as the
            other provider connect functions.

    Returns:
        The connection object returned by ``duckdb.connect``.
    """
    in_memory_target = ":memory:"
    return duckdb.connect(in_memory_target)


def connect_motherduck(hosted_secret: Secret):
    """
    Open a connection to MotherDuck.

    The MotherDuck token is read from the current execution context's secrets
    using the group, key and group version carried by ``hosted_secret``.

    Args:
        hosted_secret: Secret describing where the MotherDuck token is stored.

    Returns:
        The connection object returned by ``duckdb.connect``.
    """
    secrets_manager = current_context().secrets
    token = secrets_manager.get(
        group=hosted_secret.group,
        key=hosted_secret.key,
        group_version=hosted_secret.group_version,
    )
    # The token is handed to DuckDB through its connection config.
    connection_config = {"motherduck_token": token}
    return duckdb.connect("md:", config=connection_config)


class DuckDBProvider(Enum):
    """
    Supported DuckDB backends.

    Each member's value is a callable that takes the (optional) secret and
    returns an open DuckDB connection; callers invoke it as
    ``provider.value(secret)``.
    """

    # The connect functions are wrapped in ``partial`` because a plain function
    # assigned in an Enum body is treated as a method rather than a member.
    # NOTE(review): newer Python releases are reported to emit a FutureWarning
    # for ``functools.partial`` enum values and recommend ``enum.member()``
    # (available from Python 3.11) — confirm against the supported versions.
    LOCAL = partial(connect_local)
    MOTHERDUCK = partial(connect_motherduck)


class QueryOutput(NamedTuple):
counter: int = -1
output: Optional[str] = None
Expand All @@ -21,8 +43,9 @@ class DuckDBQuery(PythonInstanceTask):
def __init__(
self,
name: str,
query: Union[str, List[str]],
query: Optional[Union[str, List[str]]] = None,
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved
inputs: Optional[Dict[str, Union[StructuredDataset, list]]] = None,
provider: DuckDBProvider = DuckDBProvider.LOCAL,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using an enum restricts someone from passing a callable. Let's say someone wants to add a new keyword to duckdb.connect for motherduck:

def custom_connect_motherduck(token: str):
    """Connect to MotherDuck."""
    return duckdb.connect("md:", config={"motherduck_token": token, "another_config": "hello"})

They should be able to pass it into provider:

DuckDBQuery(..., provider=custom_connect_motherduck)

If you prefer an enum, I think we should also open it up to callables:

    provider: Union[DuckDBProvider, Callable]

In the above callable API, I propose making the input just be a string. To pass in the secret the user API becomes:

secret = Secret(key="my-key", group="my-group")

DuckDBQuery(..., connect_secret=secret)

then in `DuckDBQuery._connect_to_duckdb`:

def _connect_to_duckdb(self):
    connect_token = None
    if self._connect_secret:
        connect_token = current_context().secrets.get(
            group=self._connect_secret.group,
            key=self._connect_secret.key,
            group_version=self._connect_secret.group_version,
        )
    if isinstance(self.provider, DuckDBProvider):
        return self.provider.value(connect_token)
    else:  # callable
        return self.provider(connect_token)

This way the callable does not need to be responsible for handling current_context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do!! Thank you.

I will also use secret_requests from PythonAutoContainerTask and check for it in the constructor rather than adding a new secret argument.

**kwargs,
):
"""
Expand All @@ -32,8 +55,13 @@ def __init__(
name: Name of the task
query: DuckDB query to execute
inputs: The query parameters to be used while executing the query
provider: DuckDB provider (e.g., LOCAL, MOTHERDUCK, ANOTHERPRODUCT)
"""
self._query = query
self._provider = provider
secret_requests: Optional[list[Secret]] = kwargs.get("secret_requests", None)
self._hosted_secret = secret_requests[0] if secret_requests else None

outputs = {"result": StructuredDataset}

super(DuckDBQuery, self).__init__(
Expand All @@ -44,6 +72,22 @@ def __init__(
**kwargs,
)

def _connect_to_duckdb(self):
    """
    Open a DuckDB connection using the configured provider.

    Returns:
        The connection object produced by the provider's connect function.

    Raises:
        ValueError: If the configured provider is not a ``DuckDBProvider``
            member, or if a provider other than ``LOCAL`` is selected and no
            secret was supplied.
    """
    provider = self._provider

    # Use isinstance rather than ``provider not in DuckDBProvider``: before
    # Python 3.12 an Enum containment check raises TypeError for operands that
    # are not Enum members, and from 3.12 it also accepts raw member *values*,
    # which would then fail on ``.value`` below.
    if not isinstance(provider, DuckDBProvider):
        raise ValueError(f"Unknown DuckDB provider: {self._provider}")

    # Every provider except the local one needs credentials.
    if provider is not DuckDBProvider.LOCAL and self._hosted_secret is None:
        raise ValueError(f"A secret is required for the {self._provider} provider.")

    # The member value is the provider's connect function; it receives the
    # secret (which may be None for the local provider).
    return provider.value(self._hosted_secret)

def _execute_query(
self, con: duckdb.DuckDBPyConnection, params: list, query: str, counter: int, multiple_params: bool
):
Expand Down Expand Up @@ -76,14 +120,15 @@ def _execute_query(

def execute(self, **kwargs) -> StructuredDataset:
# TODO: Enable iterative download after adding the functionality to structured dataset code.

# create an in-memory database that's non-persistent
con = duckdb.connect(":memory:")
con = self._connect_to_duckdb()

params = None
for key in self.python_interface.inputs.keys():
val = kwargs.get(key)
if isinstance(val, StructuredDataset):
if key == "query" and val is not None:
# Execution query takes priority
self._query = val
elif isinstance(val, StructuredDataset):
# register structured dataset
con.register(key, val.open(pa.Table).all())
elif isinstance(val, (pd.DataFrame, pa.Table)):
Expand All @@ -98,6 +143,9 @@ def execute(self, **kwargs) -> StructuredDataset:
else:
raise ValueError(f"Expected inputs of type StructuredDataset, str or list, received {type(val)}")

if self._query is None:
raise ValueError("A query must be specified when defining or executing a DuckDBQuery.")

final_query = self._query
query_output = QueryOutput()
# set flag to indicate the presence of params for multiple queries
Expand Down
Loading