diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py index e17e19e4..9e6ef7ff 100644 --- a/singlestoredb/functions/decorator.py +++ b/singlestoredb/functions/decorator.py @@ -1,15 +1,24 @@ +import dataclasses import datetime import functools -import string +import inspect from typing import Any from typing import Callable from typing import Dict from typing import List from typing import Optional +from typing import Tuple from typing import Union from . import dtypes from .dtypes import DataType +from .signature import simplify_dtype + +try: + import pydantic + has_pydantic = True +except ImportError: + has_pydantic = False python_type_map: Dict[Any, Callable[..., str]] = { str: dtypes.TEXT, @@ -33,88 +42,123 @@ def listify(x: Any) -> List[Any]: return [x] +def process_annotation(annotation: Any) -> Tuple[Any, bool]: + types = simplify_dtype(annotation) + if isinstance(types, list): + nullable = False + if type(None) in types: + nullable = True + types = [x for x in types if x is not type(None)] + if len(types) > 1: + raise ValueError(f'multiple types not supported: {annotation}') + return types[0], nullable + return types, True + + +def process_types(params: Any) -> Any: + if params is None: + return params, [] + + elif isinstance(params, (list, tuple)): + params = list(params) + for i, item in enumerate(params): + if params[i] in python_type_map: + params[i] = python_type_map[params[i]]() + elif callable(item): + params[i] = item() + for item in params: + if not isinstance(item, str): + raise TypeError(f'unrecognized type for parameter: {item}') + return params, [] + + elif isinstance(params, dict): + names = [] + params = dict(params) + for k, v in list(params.items()): + names.append(k) + if params[k] in python_type_map: + params[k] = python_type_map[params[k]]() + elif callable(v): + params[k] = v() + for item in params.values(): + if not isinstance(item, str): + raise TypeError(f'unrecognized type for parameter: {item}') + return params, names + + elif dataclasses.is_dataclass(params): + names = [] + out = [] + for item in dataclasses.fields(params): + typ, nullable = process_annotation(item.type) + sql_type = process_types(typ)[0] + if not nullable: + sql_type = sql_type.replace('NULL', 'NOT NULL') + out.append(sql_type) + names.append(item.name) + return out, names + + elif has_pydantic and inspect.isclass(params) \ + and issubclass(params, pydantic.BaseModel): + names = [] + out = [] + for name, item in params.model_fields.items(): + typ, nullable = process_annotation(item.annotation) + sql_type = process_types(typ)[0] + if not nullable: + sql_type = sql_type.replace('NULL', 'NOT NULL') + out.append(sql_type) + names.append(name) + return out, names + + elif params in python_type_map: + return python_type_map[params](), [] + + elif callable(params): + return params(), [] + + elif isinstance(params, str): + return params, [] + + raise TypeError(f'unrecognized data type for args: {params}') + + def _func( func: Optional[Callable[..., Any]] = None, *, name: Optional[str] = None, - args: Optional[Union[DataType, List[DataType], Dict[str, DataType]]] = None, - returns: Optional[Union[str, List[DataType], List[type]]] = None, + args: Optional[ + Union[ + DataType, + List[DataType], + Dict[str, DataType], + 'pydantic.BaseModel', + type, + ] + ] = None, + returns: Optional[ + Union[ + str, + List[DataType], + List[type], + 'pydantic.BaseModel', + type, + ] + ] = None, data_format: Optional[str] = None, include_masks: bool = False, function_type: str = 'udf', output_fields: Optional[List[str]] = None, ) -> Callable[..., Any]: """Generic wrapper for UDF and TVF decorators.""" - if args is None: - pass - elif isinstance(args, (list, tuple)): - args = list(args) - for i, item in enumerate(args): - if args[i] in python_type_map: - args[i] = python_type_map[args[i]]() - elif callable(item): - args[i] = item() - for item in args: - if not isinstance(item, str): - raise TypeError(f'unrecognized type for parameter: {item}') - elif isinstance(args, dict): - args = dict(args) - for k, v in list(args.items()): - if args[k] in python_type_map: - args[k] = python_type_map[args[k]]() - elif callable(v): - args[k] = v() - for item in args.values(): - if not isinstance(item, str): - raise TypeError(f'unrecognized type for parameter: {item}') - elif args in python_type_map: - args = python_type_map[args]() - elif callable(args): - args = args() - elif isinstance(args, str): - args = args - else: - raise TypeError(f'unrecognized data type for args: {args}') - - if returns is None: - pass - elif isinstance(returns, (list, tuple)): - returns = list(returns) - for i, item in enumerate(returns): - if item in python_type_map: - returns[i] = python_type_map[item]() - elif callable(item): - returns[i] = item() - for item in returns: - if not isinstance(item, str): - raise TypeError(f'unrecognized return type: {item}') - elif returns in python_type_map: - returns = python_type_map[returns]() - elif callable(returns): - returns = returns() - elif isinstance(returns, str): - returns = returns - else: - raise TypeError(f'unrecognized return type: {returns}') - - if returns is None: - pass - elif isinstance(returns, list): - for item in returns: - if not isinstance(item, str): - raise TypeError(f'unrecognized return type: {item}') - elif not isinstance(returns, str): - raise TypeError(f'unrecognized return type: {returns}') - - if not output_fields: - if isinstance(returns, list): - output_fields = [] - for i, _ in enumerate(returns): - output_fields.append(string.ascii_letters[i]) - else: - output_fields = [string.ascii_letters[0]] - - if isinstance(returns, list) and len(output_fields) != len(returns): + args, _ = process_types(args) + returns, fields = process_types(returns) + + if not output_fields and fields: + output_fields = fields + + if isinstance(returns, list) \ + and isinstance(output_fields, list) \ + and len(output_fields) != len(returns): raise ValueError( 'The number of output fields must match the number of return types', ) @@ -133,7 +177,7 @@ def _func( data_format=data_format, include_masks=include_masks, function_type=function_type, - output_fields=output_fields, + output_fields=output_fields or None, ).items() if v is not None } diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py index 323160eb..f02dda43 100755 --- a/singlestoredb/functions/ext/asgi.py +++ b/singlestoredb/functions/ext/asgi.py @@ -24,6 +24,7 @@ """ import argparse import asyncio +import dataclasses import importlib.util import io import itertools @@ -136,6 +137,14 @@ def get_func_names(funcs: str) -> List[Tuple[str, str]]: return out +def as_tuple(x: Any) -> Any: + if hasattr(x, 'model_fields'): + return tuple(x.model_fields.values()) + if dataclasses.is_dataclass(x): + return dataclasses.astuple(x) + return x + + def make_func( name: str, func: Callable[..., Any], @@ -174,7 +183,7 @@ async def do_func( out_ids: List[int] = [] out = [] for i, res in zip(row_ids, func_map(func, rows)): - out.extend(res) + out.extend(as_tuple(res)) out_ids.extend([row_ids[i]] * (len(out)-len(out_ids))) return out_ids, out @@ -234,7 +243,7 @@ async def do_func( List[Tuple[Any]], ]: '''Call function on given rows of data.''' - return row_ids, list(zip(func_map(func, rows))) + return row_ids, [as_tuple(x) for x in zip(func_map(func, rows))] else: # Vector formats use the same function wrapper diff --git a/singlestoredb/functions/signature.py b/singlestoredb/functions/signature.py index 38df66d0..66e070d4 100644 --- a/singlestoredb/functions/signature.py +++ b/singlestoredb/functions/signature.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +import dataclasses import datetime import inspect import numbers @@ -22,6 +23,12 @@ except ImportError: has_numpy = False +try: + import pydantic + has_pydantic = True +except ImportError: + has_pydantic = False + from . import dtypes as dt from ..mysql.converters import escape_item # type: ignore @@ -243,6 +250,9 @@ def classify_dtype(dtype: Any) -> str: if isinstance(dtype, list): return '|'.join(classify_dtype(x) for x in dtype) + if isinstance(dtype, str): + return sql_to_dtype(dtype) + # Specific types if dtype is None or dtype is type(None): # noqa: E721 return 'null' @@ -253,6 +263,21 @@ def classify_dtype(dtype: Any) -> str: if dtype is bool: return 'bool' + if dataclasses.is_dataclass(dtype): + fields = dataclasses.fields(dtype) + item_dtypes = ','.join( + f'{classify_dtype(simplify_dtype(x.type))}' for x in fields + ) + return f'tuple[{item_dtypes}]' + + if has_pydantic and inspect.isclass(dtype) and issubclass(dtype, pydantic.BaseModel): + fields = dtype.model_fields.values() + item_dtypes = ','.join( + f'{classify_dtype(simplify_dtype(x.annotation))}' # type: ignore + for x in fields + ) + return f'tuple[{item_dtypes}]' + if not inspect.isclass(dtype): # Check for compound types origin = typing.get_origin(dtype) @@ -261,7 +286,7 @@ def classify_dtype(dtype: Any) -> str: if origin is Tuple: args = typing.get_args(dtype) item_dtypes = ','.join(classify_dtype(x) for x in args) - return f'tuple:{item_dtypes}' + return f'tuple[{item_dtypes}]' # Array types elif issubclass(origin, array_types): @@ -504,28 +529,55 @@ def get_signature(func: Callable[..., Any], name: Optional[str] = None) -> Dict[ sql = returns_overrides out_type = sql_to_dtype(sql) elif isinstance(returns_overrides, list): - sqls = [] - out_types = [] - for i, item in enumerate(returns_overrides): - if not isinstance(item, str): - raise TypeError(f'unrecognized type for return value: {item}') - if output_fields: - sqls.append(f'`{output_fields[i]}` {item}') - else: - sqls.append(f'{string.ascii_letters[i]} {item}') - out_types.append(sql_to_dtype(item)) - if function_type == 'tvf': - sql = 'TABLE({})'.format(', '.join(sqls)) - else: - sql = 'RECORD({})'.format(', '.join(sqls)) - out_type = 'tuple[{}]'.format(','.join(out_types)) + if not output_fields: + output_fields = [ + string.ascii_letters[i] for i in range(len(returns_overrides)) + ] + out_type = 'tuple[' + collapse_dtypes([ + classify_dtype(x) + for x in simplify_dtype(returns_overrides) + ]).replace('|', ',') + ']' + sql = dtype_to_sql( + out_type, function_type=function_type, field_names=output_fields, + ) + elif dataclasses.is_dataclass(returns_overrides): + out_type = collapse_dtypes([ + classify_dtype(x) + for x in simplify_dtype([x.type for x in returns_overrides.fields]) + ]) + sql = dtype_to_sql( + out_type, + function_type=function_type, + field_names=[x.name for x in returns_overrides.fields], + ) + elif has_pydantic and inspect.isclass(returns_overrides) \ + and issubclass(returns_overrides, pydantic.BaseModel): + out_type = collapse_dtypes([ + classify_dtype(x) + for x in simplify_dtype([x for x in returns_overrides.model_fields.values()]) + ]) + sql = dtype_to_sql( + out_type, + function_type=function_type, + field_names=[x for x in returns_overrides.model_fields.keys()], + ) elif returns_overrides is not None and not isinstance(returns_overrides, str): raise TypeError(f'unrecognized type for return value: {returns_overrides}') else: + if not output_fields: + if dataclasses.is_dataclass(signature.return_annotation): + output_fields = [ + x.name for x in dataclasses.fields(signature.return_annotation) + ] + elif has_pydantic and inspect.isclass(signature.return_annotation) \ + and issubclass(signature.return_annotation, pydantic.BaseModel): + output_fields = list(signature.return_annotation.model_fields.keys()) out_type = collapse_dtypes([ classify_dtype(x) for x in simplify_dtype(signature.return_annotation) ]) - sql = dtype_to_sql(out_type, function_type=function_type) + sql = dtype_to_sql( + out_type, function_type=function_type, field_names=output_fields, + ) out['returns'] = dict(dtype=out_type, sql=sql, default=None) copied_keys = ['database', 'environment', 'packages', 'resources', 'replace'] @@ -580,7 +632,12 @@ def sql_to_dtype(sql: str) -> str: return dtype -def dtype_to_sql(dtype: str, default: Any = None, function_type: str = 'udf') -> str: +def dtype_to_sql( + dtype: str, + default: Any = None, + field_names: Optional[List[str]] = None, + function_type: str = 'udf', +) -> str: """ Convert a collapsed dtype string to a SQL type. @@ -590,6 +647,8 @@ def dtype_to_sql(dtype: str, default: Any = None, function_type: str = 'udf') -> Simplified data type string default : Any, optional Default value + field_names : List[str], optional + Field names for tuple types Returns ------- @@ -621,17 +680,22 @@ def dtype_to_sql(dtype: str, default: Any = None, function_type: str = 'udf') -> dtypes = dtypes[:-1] item_dtypes = [] for i, item in enumerate(dtypes.split(',')): - name = string.ascii_letters[i] + if field_names: + name = field_names[i] + else: + name = string.ascii_letters[i] if '=' in item: name, item = item.split('=', 1) item_dtypes.append( - name + ' ' + dtype_to_sql(item, function_type=function_type), + f'`{name}` ' + dtype_to_sql(item, function_type=function_type), ) if function_type == 'udf': return f'RECORD({", ".join(item_dtypes)}){nullable}{default_clause}' else: - return f'TABLE({", ".join(item_dtypes)}){nullable}{default_clause}'\ - .replace(' NOT NULL', '') + return re.sub( + r' NOT NULL\s*$', r'', + f'TABLE({", ".join(item_dtypes)}){nullable}{default_clause}', + ) return f'{sql_type_map[dtype]}{nullable}{default_clause}' diff --git a/singlestoredb/tests/test_udf.py b/singlestoredb/tests/test_udf.py index 7e0baec2..aa21c478 100755 --- a/singlestoredb/tests/test_udf.py +++ b/singlestoredb/tests/test_udf.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # type: ignore """SingleStoreDB UDF testing.""" +import dataclasses import datetime import re import unittest @@ -11,9 +12,11 @@ from typing import Union import numpy as np +import pydantic from ..functions import dtypes as dt from ..functions import signature as sig +from ..functions import tvf from ..functions import udf @@ -28,7 +31,7 @@ def to_sql(x): out = sig.signature_to_sql(sig.get_signature(x)) out = re.sub(r'^CREATE EXTERNAL FUNCTION ', r'', out) out = re.sub(r' AS REMOTE SERVICE.+$', r'', out) - return out + return out.strip() class TestUDF(unittest.TestCase): @@ -99,27 +102,27 @@ def foo() -> Union[int, str]: ... # Tuple def foo() -> Tuple[int, float, str]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(a BIGINT NOT NULL, ' \ - 'b DOUBLE NOT NULL, ' \ - 'c TEXT NOT NULL) NOT NULL' + assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ + '`b` DOUBLE NOT NULL, ' \ + '`c` TEXT NOT NULL) NOT NULL' # Optional tuple def foo() -> Optional[Tuple[int, float, str]]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(a BIGINT NOT NULL, ' \ - 'b DOUBLE NOT NULL, ' \ - 'c TEXT NOT NULL) NULL' + assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ + '`b` DOUBLE NOT NULL, ' \ + '`c` TEXT NOT NULL) NULL' # Optional tuple with optional element def foo() -> Optional[Tuple[int, float, Optional[str]]]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(a BIGINT NOT NULL, ' \ - 'b DOUBLE NOT NULL, ' \ - 'c TEXT NULL) NULL' + assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ + '`b` DOUBLE NOT NULL, ' \ + '`c` TEXT NULL) NULL' # Optional tuple with optional union element def foo() -> Optional[Tuple[int, Optional[Union[float, int]], str]]: ... - assert to_sql(foo) == '`foo`() RETURNS RECORD(a BIGINT NOT NULL, ' \ - 'b DOUBLE NULL, ' \ - 'c TEXT NOT NULL) NULL' + assert to_sql(foo) == '`foo`() RETURNS RECORD(`a` BIGINT NOT NULL, ' \ + '`b` DOUBLE NULL, ' \ + '`c` TEXT NOT NULL) NULL' # Unknown type def foo() -> set: ... @@ -182,21 +185,21 @@ def foo(x: Union[int, str]) -> None: ... # Tuple def foo(x: Tuple[int, float, str]) -> None: ... - assert to_sql(foo) == '`foo`(`x` RECORD(a BIGINT NOT NULL, ' \ - 'b DOUBLE NOT NULL, ' \ - 'c TEXT NOT NULL) NOT NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` RECORD(`a` BIGINT NOT NULL, ' \ + '`b` DOUBLE NOT NULL, ' \ + '`c` TEXT NOT NULL) NOT NULL) RETURNS NULL' # Optional tuple with optional element def foo(x: Optional[Tuple[int, float, Optional[str]]]) -> None: ... - assert to_sql(foo) == '`foo`(`x` RECORD(a BIGINT NOT NULL, ' \ - 'b DOUBLE NOT NULL, ' \ - 'c TEXT NULL) NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` RECORD(`a` BIGINT NOT NULL, ' \ + '`b` DOUBLE NOT NULL, ' \ + '`c` TEXT NULL) NULL) RETURNS NULL' # Optional tuple with optional union element def foo(x: Optional[Tuple[int, Optional[Union[float, int]], str]]) -> None: ... - assert to_sql(foo) == '`foo`(`x` RECORD(a BIGINT NOT NULL, ' \ - 'b DOUBLE NULL, ' \ - 'c TEXT NOT NULL) NULL) RETURNS NULL' + assert to_sql(foo) == '`foo`(`x` RECORD(`a` BIGINT NOT NULL, ' \ + '`b` DOUBLE NULL, ' \ + '`c` TEXT NOT NULL) NULL) RETURNS NULL' # Unknown type def foo(x: set) -> None: ... @@ -402,6 +405,65 @@ def foo(x: int) -> int: ... assert to_sql(foo) == '`hello``_``world`(`x` BIGINT NOT NULL) ' \ 'RETURNS BIGINT NOT NULL' + @dataclasses.dataclass + class MyData: + one: Optional[int] + two: str + three: float + + @udf + def foo(x: int) -> MyData: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL) NOT NULL' + + @udf(returns=MyData) + def foo(x: int) -> Tuple[int, int, int]: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL) NOT NULL' + + @tvf + def foo(x: int) -> MyData: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL)' + + @tvf(returns=MyData) + def foo(x: int) -> Tuple[int, int, int]: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL)' + + class MyData(pydantic.BaseModel): + one: Optional[int] + two: str + three: float + + @udf + def foo(x: int) -> MyData: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL) NOT NULL' + + @udf(returns=MyData) + def foo(x: int) -> Tuple[int, int, int]: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS RECORD(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL) NOT NULL' + + @tvf + def foo(x: int) -> MyData: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL)' + + @tvf(returns=MyData) + def foo(x: int) -> Tuple[int, int, int]: ... + assert to_sql(foo) == '`foo`(`x` BIGINT NOT NULL) ' \ + 'RETURNS TABLE(`one` BIGINT NULL, `two` TEXT NOT NULL, ' \ + '`three` DOUBLE NOT NULL)' + def test_dtypes(self): assert dt.BOOL() == 'BOOL NULL' assert dt.BOOL(nullable=False) == 'BOOL NOT NULL' diff --git a/test-requirements.txt b/test-requirements.txt index e0ce25a4..e2ab7aaf 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -7,6 +7,7 @@ pandas parameterized polars pyarrow +pydantic pygeos ; python_version < '3.12' pytest pytest-cov