Skip to content

Commit

Permalink
Add support for dataclasses and pydantic in UDF/TVF parameters/returns
Browse files Browse the repository at this point in the history
  • Loading branch information
kesmit13 committed Jan 23, 2025
1 parent 31e8124 commit 52db5af
Show file tree
Hide file tree
Showing 5 changed files with 300 additions and 120 deletions.
192 changes: 118 additions & 74 deletions singlestoredb/functions/decorator.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
import dataclasses
import datetime
import functools
import string
import inspect
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

from . import dtypes
from .dtypes import DataType
from .signature import simplify_dtype

try:
import pydantic
has_pydantic = True
except ImportError:
has_pydantic = False

python_type_map: Dict[Any, Callable[..., str]] = {
str: dtypes.TEXT,
Expand All @@ -33,88 +42,123 @@ def listify(x: Any) -> List[Any]:
return [x]


def process_annotation(annotation: Any) -> Tuple[Any, bool]:
    """
    Reduce a type annotation to a single type plus a nullability flag.

    The annotation is first passed through ``simplify_dtype``. When that
    yields a list, ``type(None)`` entries mark the value as nullable and
    are dropped; exactly one other type may remain. Any non-list result
    is returned unchanged and flagged as nullable.

    Parameters
    ----------
    annotation : Any
        The annotation to process

    Returns
    -------
    Tuple[Any, bool]
        The remaining type and whether it is nullable

    Raises
    ------
    ValueError
        If more than one type remains in the simplified list

    """
    simplified = simplify_dtype(annotation)

    # Anything that is not a list is passed through as nullable
    if not isinstance(simplified, list):
        return simplified, True

    none_type = type(None)
    nullable = none_type in simplified
    if nullable:
        simplified = [typ for typ in simplified if typ is not none_type]

    if len(simplified) > 1:
        raise ValueError(f'multiple types not supported: {annotation}')

    return simplified[0], nullable


def process_types(params: Any) -> Any:
    """
    Normalize a type specification into SQL type strings.

    Parameters
    ----------
    params : Any
        One of: ``None``; a list / tuple of type specs; a dict mapping
        names to type specs; a dataclass; a pydantic model class (only
        when pydantic could be imported); a key of ``python_type_map``;
        a callable returning a type string; or a type string

    Returns
    -------
    Tuple[Any, List[str]]
        The normalized specification and the field names it carries.
        Names come from dict keys, dataclass fields or pydantic model
        fields; every other form yields an empty list.

    Raises
    ------
    TypeError
        If a list / dict element does not resolve to a string, or if
        ``params`` matches none of the supported forms

    """

    def to_sql(spec: Any) -> Any:
        # Mapped Python types take precedence over generic callables
        if spec in python_type_map:
            return python_type_map[spec]()
        if callable(spec):
            return spec()
        return spec

    def check_strings(values: Any) -> None:
        # Validation runs only after every element has been converted
        for value in values:
            if not isinstance(value, str):
                raise TypeError(f'unrecognized type for parameter: {value}')

    def from_fields(pairs: Any) -> Any:
        # ``pairs`` yields (field name, annotation) for a structured type
        sql_types = []
        field_names = []
        for field_name, annotation in pairs:
            typ, nullable = process_annotation(annotation)
            sql_type = process_types(typ)[0]
            if not nullable:
                sql_type = sql_type.replace('NULL', 'NOT NULL')
            sql_types.append(sql_type)
            field_names.append(field_name)
        return sql_types, field_names

    if params is None:
        return params, []

    if isinstance(params, (list, tuple)):
        converted = [to_sql(spec) for spec in params]
        check_strings(converted)
        return converted, []

    if isinstance(params, dict):
        resolved = {key: to_sql(spec) for key, spec in params.items()}
        check_strings(resolved.values())
        return resolved, list(resolved)

    if dataclasses.is_dataclass(params):
        return from_fields(
            (fld.name, fld.type) for fld in dataclasses.fields(params)
        )

    if has_pydantic and inspect.isclass(params) \
            and issubclass(params, pydantic.BaseModel):
        return from_fields(
            (name, info.annotation)
            for name, info in params.model_fields.items()
        )

    if params in python_type_map:
        return python_type_map[params](), []

    if callable(params):
        return params(), []

    if isinstance(params, str):
        return params, []

    raise TypeError(f'unrecognized data type for args: {params}')


def _func(
func: Optional[Callable[..., Any]] = None,
*,
name: Optional[str] = None,
args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None,
returns: Optional[Union[str, List[DataType], List[type]]] = None,
args: Optional[
Union[
DataType,
List[DataType],
Dict[str, DataType],
'pydantic.BaseModel',
type,
]
] = None,
returns: Optional[
Union[
str,
List[DataType],
List[type],
'pydantic.BaseModel',
type,
]
] = None,
data_format: Optional[str] = None,
include_masks: bool = False,
function_type: str = 'udf',
output_fields: Optional[List[str]] = None,
) -> Callable[..., Any]:
"""Generic wrapper for UDF and TVF decorators."""
if args is None:
pass
elif isinstance(args, (list, tuple)):
args = list(args)
for i, item in enumerate(args):
if args[i] in python_type_map:
args[i] = python_type_map[args[i]]()
elif callable(item):
args[i] = item()
for item in args:
if not isinstance(item, str):
raise TypeError(f'unrecognized type for parameter: {item}')
elif isinstance(args, dict):
args = dict(args)
for k, v in list(args.items()):
if args[k] in python_type_map:
args[k] = python_type_map[args[k]]()
elif callable(v):
args[k] = v()
for item in args.values():
if not isinstance(item, str):
raise TypeError(f'unrecognized type for parameter: {item}')
elif args in python_type_map:
args = python_type_map[args]()
elif callable(args):
args = args()
elif isinstance(args, str):
args = args
else:
raise TypeError(f'unrecognized data type for args: {args}')

if returns is None:
pass
elif isinstance(returns, (list, tuple)):
returns = list(returns)
for i, item in enumerate(returns):
if item in python_type_map:
returns[i] = python_type_map[item]()
elif callable(item):
returns[i] = item()
for item in returns:
if not isinstance(item, str):
raise TypeError(f'unrecognized return type: {item}')
elif returns in python_type_map:
returns = python_type_map[returns]()
elif callable(returns):
returns = returns()
elif isinstance(returns, str):
returns = returns
else:
raise TypeError(f'unrecognized return type: {returns}')

if returns is None:
pass
elif isinstance(returns, list):
for item in returns:
if not isinstance(item, str):
raise TypeError(f'unrecognized return type: {item}')
elif not isinstance(returns, str):
raise TypeError(f'unrecognized return type: {returns}')

if not output_fields:
if isinstance(returns, list):
output_fields = []
for i, _ in enumerate(returns):
output_fields.append(string.ascii_letters[i])
else:
output_fields = [string.ascii_letters[0]]

if isinstance(returns, list) and len(output_fields) != len(returns):
args, _ = process_types(args)
returns, fields = process_types(returns)

if not output_fields and fields:
output_fields = fields

if isinstance(returns, list) \
and isinstance(output_fields, list) \
and len(output_fields) != len(returns):
raise ValueError(
'The number of output fields must match the number of return types',
)
Expand All @@ -133,7 +177,7 @@ def _func(
data_format=data_format,
include_masks=include_masks,
function_type=function_type,
output_fields=output_fields,
output_fields=output_fields or None,
).items() if v is not None
}

Expand Down
13 changes: 11 additions & 2 deletions singlestoredb/functions/ext/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""
import argparse
import asyncio
import dataclasses
import importlib.util
import io
import itertools
Expand Down Expand Up @@ -136,6 +137,14 @@ def get_func_names(funcs: str) -> List[Tuple[str, str]]:
return out


def as_tuple(x: Any) -> Any:
    '''
    Convert a structured function result into a plain tuple.

    Parameters
    ----------
    x : Any
        A function result. Objects exposing ``model_fields`` (the
        attribute checked to detect pydantic models) and dataclass
        instances are flattened; anything else is passed through.

    Returns
    -------
    Any
        A tuple of the field values in declaration order for objects
        with ``model_fields``, the ``dataclasses.astuple`` result for
        dataclass instances, or ``x`` itself otherwise

    '''
    if hasattr(x, 'model_fields'):
        # ``model_fields`` maps field names to field *descriptors*; the
        # actual data lives in the attributes of the same names, so
        # read those rather than returning the descriptors themselves.
        return tuple(getattr(x, name) for name in x.model_fields)
    if dataclasses.is_dataclass(x):
        return dataclasses.astuple(x)
    return x


def make_func(
name: str,
func: Callable[..., Any],
Expand Down Expand Up @@ -174,7 +183,7 @@ async def do_func(
out_ids: List[int] = []
out = []
for i, res in zip(row_ids, func_map(func, rows)):
out.extend(res)
out.extend(as_tuple(res))
out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
return out_ids, out

Expand Down Expand Up @@ -234,7 +243,7 @@ async def do_func(
List[Tuple[Any]],
]:
'''Call function on given rows of data.'''
return row_ids, list(zip(func_map(func, rows)))
return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))]

else:
# Vector formats use the same function wrapper
Expand Down
Loading

0 comments on commit 52db5af

Please sign in to comment.