Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add motherduck support for duckdb plugin #2680

Merged
merged 17 commits into from
Sep 3, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
DuckDBQuery
"""

from .task import DuckDBQuery
from .task import DuckDBProvider, DuckDBQuery
92 changes: 85 additions & 7 deletions plugins/flytekit-duckdb/flytekitplugins/duckdb/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
from typing import Dict, List, NamedTuple, Optional, Union
from enum import Enum
from functools import partial
from typing import Callable, Dict, List, NamedTuple, Optional, Union

from flytekit import PythonInstanceTask, lazy_module
from flytekit import PythonInstanceTask, Secret, current_context, lazy_module
from flytekit.extend import Interface
from flytekit.types.structured.structured_dataset import StructuredDataset

Expand All @@ -10,6 +12,25 @@
pa = lazy_module("pyarrow")


class MissingSecretError(ValueError):
pass


def connect_local(token: Optional[str]):
"""Connect to local DuckDB."""
return duckdb.connect(":memory:")


def connect_motherduck(token: str):
"""Connect to MotherDuck."""
return duckdb.connect("md:", config={"motherduck_token": token})


class DuckDBProvider(Enum):
LOCAL = partial(connect_local)
MOTHERDUCK = partial(connect_motherduck)


class QueryOutput(NamedTuple):
counter: int = -1
output: Optional[str] = None
Expand All @@ -21,19 +42,53 @@ class DuckDBQuery(PythonInstanceTask):
def __init__(
self,
name: str,
query: Union[str, List[str]],
query: Optional[Union[str, List[str]]] = None,
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved
inputs: Optional[Dict[str, Union[StructuredDataset, list]]] = None,
provider: Union[DuckDBProvider, Callable] = DuckDBProvider.LOCAL,
**kwargs,
):
"""
This method initializes the DuckDBQuery.

Note that the provider can be one of the default providers listed in DuckDBProvider or a custom callable like the following:

def custom_connect_motherduck(token: str):
return duckdb.connect("md:", config={"motherduck_token": token, "another_config": "hello"})

DuckDBQuery(..., provider=custom_connect_motherduck)

Also note that a query can be provided at runtime if query=None is provided.

duckdb_query = DuckDBQuery(
samhita-alla marked this conversation as resolved.
Show resolved Hide resolved
name="my_duckdb_query",
inputs=kwtypes(query=str)
)

@workflow
def wf(user_query: str) -> pd.DataFrame:
return duckdb_query(query=user_query)

Args:
name: Name of the task
query: DuckDB query to execute
inputs: The query parameters to be used while executing the query
provider: DuckDB provider
"""
self._query = query
self._provider = provider
secret_requests: Optional[list[Secret]] = kwargs.get("secret_requests", None)
self._connect_secret = None
if secret_requests:
assert len(secret_requests) == 1, "Only one secret can be used for a DuckDBQuery task."
self._connect_secret = secret_requests[0]

if (
self._connect_secret is None
and isinstance(self._provider, DuckDBProvider)
and self._provider != DuckDBProvider.LOCAL
):
raise MissingSecretError(f"A secret_requests must be provided for the {self._provider.name} provider.")

outputs = {"result": StructuredDataset}

super(DuckDBQuery, self).__init__(
Expand All @@ -44,6 +99,25 @@ def __init__(
**kwargs,
)

def _connect_to_duckdb(self):
"""
Handles the connection to DuckDB based on the provider.

Returns:
A DuckDB connection object.
"""
connect_token = None
if self._connect_secret:
connect_token = current_context().secrets.get(
group=self._connect_secret.group,
key=self._connect_secret.key,
group_version=self._connect_secret.group_version,
)
if isinstance(self._provider, DuckDBProvider):
thomasjpfan marked this conversation as resolved.
Show resolved Hide resolved
return self._provider.value(connect_token)
else: # callable
return self._provider(connect_token)
dansola marked this conversation as resolved.
Show resolved Hide resolved

def _execute_query(
self, con: duckdb.DuckDBPyConnection, params: list, query: str, counter: int, multiple_params: bool
):
Expand Down Expand Up @@ -76,14 +150,15 @@ def _execute_query(

def execute(self, **kwargs) -> StructuredDataset:
# TODO: Enable iterative download after adding the functionality to structured dataset code.

# create an in-memory database that's non-persistent
con = duckdb.connect(":memory:")
con = self._connect_to_duckdb()

params = None
for key in self.python_interface.inputs.keys():
val = kwargs.get(key)
if isinstance(val, StructuredDataset):
if key == "query" and val is not None:
# Execution query takes priority
self._query = val
elif isinstance(val, StructuredDataset):
# register structured dataset
con.register(key, val.open(pa.Table).all())
elif isinstance(val, (pd.DataFrame, pa.Table)):
Expand All @@ -98,6 +173,9 @@ def execute(self, **kwargs) -> StructuredDataset:
else:
raise ValueError(f"Expected inputs of type StructuredDataset, str or list, received {type(val)}")

if self._query is None:
raise ValueError("A query must be specified when defining or executing a DuckDBQuery.")

final_query = self._query
query_output = QueryOutput()
# set flag to indicate the presence of params for multiple queries
Expand Down
31 changes: 28 additions & 3 deletions plugins/flytekit-duckdb/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
from typing import List

import pytest
import pandas as pd
import pyarrow as pa
from flytekitplugins.duckdb import DuckDBQuery
from flytekitplugins.duckdb import DuckDBQuery, DuckDBProvider
from flytekitplugins.duckdb.task import MissingSecretError
from typing_extensions import Annotated

from flytekit import kwtypes, task, workflow
from flytekit import kwtypes, task, workflow, Secret
from flytekit.types.structured.structured_dataset import StructuredDataset


Expand Down Expand Up @@ -146,3 +147,27 @@ def params_wf(params: str) -> pa.Table:
return duckdb_params_query(params=params)

assert isinstance(params_wf(params=json.dumps([[[500], [300], [2]]])), pa.Table)


def test_motherduck_no_token():
with pytest.raises(MissingSecretError, match="A secret_requests must be provided for the MOTHERDUCK provider."):
duckdb_params_query = DuckDBQuery(
name="motherduck_query",
query="SELECT SUM(a) FROM sometable",
provider=DuckDBProvider.MOTHERDUCK,
)


def test_runtime_query():
runtime_duckdb_query = DuckDBQuery(
name="runtime_query", inputs=kwtypes(mydf=pd.DataFrame, query=str)
)

@workflow
def pandas_wf(mydf: pd.DataFrame, query: str) -> pd.DataFrame:
return runtime_duckdb_query(mydf=df, query=query)

df = pd.DataFrame({"a": [1, 2, 3]})
query = "SELECT SUM(a) FROM mydf"
assert isinstance(pandas_wf(mydf=df, query=query), pd.DataFrame)
assert pandas_wf(mydf=df, query=query).iloc[0, 0] == 6
Loading