Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More helpful error when UDF arg type check failed #5175

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 43 additions & 34 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@


@dataclass
class _ParsedParamAnnotation:
class _ParsedParam:
name: Union[str, int] = field(init=True)
orig_types: set[type] = field(default_factory=set)
encoded_types: set[str] = field(default_factory=set)
none_allowed: bool = False
Expand All @@ -54,7 +55,7 @@ class _ParsedReturnAnnotation:
@dataclass
class _ParsedSignature:
fn: Callable = None
params: List[_ParsedParamAnnotation] = field(default_factory=list)
params: List[_ParsedParam] = field(default_factory=list)
ret_annotation: _ParsedReturnAnnotation = None

@property
Expand Down Expand Up @@ -93,22 +94,22 @@ def _encode_param_type(t: type) -> str:
return tc


def _parse_param_annotation(annotation: Any) -> _ParsedParamAnnotation:
def _parse_param(name: str, annotation: Any) -> _ParsedParam:
""" Parse a parameter annotation in a function's signature """
p_annotation = _ParsedParamAnnotation()
p_param = _ParsedParam(name)

if annotation is inspect._empty:
p_annotation.encoded_types.add("O")
p_annotation.none_allowed = True
p_param.encoded_types.add("O")
p_param.none_allowed = True
elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union:
for t in annotation.__args__:
_parse_type_no_nested(annotation, p_annotation, t)
_parse_type_no_nested(annotation, p_param, t)
else:
_parse_type_no_nested(annotation, p_annotation, annotation)
return p_annotation
_parse_type_no_nested(annotation, p_param, annotation)
return p_param


def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation, t: Union[type, str]) -> None:
def _parse_type_no_nested(annotation: Any, p_param: _ParsedParam, t: Union[type, str]) -> None:
""" Parse a specific type (top level or nested in a top-level Union annotation) without handling nested types
(e.g. a nested Union). The result is stored in the given _ParsedParam object.
"""
Expand All @@ -117,25 +118,25 @@ def _parse_type_no_nested(annotation: Any, p_annotation: _ParsedParamAnnotation,
# annotation is already a type, and we can remove this line.
t = eval(t) if isinstance(t, str) else t

p_annotation.orig_types.add(t)
p_param.orig_types.add(t)
tc = _encode_param_type(t)
if "[" in tc:
p_annotation.has_array = True
p_param.has_array = True
if tc in {"N", "O"}:
p_annotation.none_allowed = True
p_param.none_allowed = True
if tc in _NUMPY_INT_TYPE_CODES:
if p_annotation.int_char and p_annotation.int_char != tc:
if p_param.int_char and p_param.int_char != tc:
raise DHError(message=f"multiple integer types in annotation: {annotation}, "
f"types: {p_annotation.int_char}, {tc}. this is not supported because it is not "
f"types: {p_param.int_char}, {tc}. this is not supported because it is not "
f"clear which Deephaven null value to use when checking for nulls in the argument")
p_annotation.int_char = tc
p_param.int_char = tc
if tc in _NUMPY_FLOATING_TYPE_CODES:
if p_annotation.floating_char and p_annotation.floating_char != tc:
if p_param.floating_char and p_param.floating_char != tc:
raise DHError(message=f"multiple floating types in annotation: {annotation}, "
f"types: {p_annotation.floating_char}, {tc}. this is not supported because it is not "
f"types: {p_param.floating_char}, {tc}. this is not supported because it is not "
f"clear which Deephaven null value to use when checking for nulls in the argument")
p_annotation.floating_char = tc
p_annotation.encoded_types.add(tc)
p_param.floating_char = tc
p_param.encoded_types.add(tc)


def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation:
Expand Down Expand Up @@ -182,8 +183,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
p_sig.ret_annotation.encoded_type = rt_char

if isinstance(fn, numba.np.ufunc.dufunc.DUFunc):
for p in params:
pa = _ParsedParamAnnotation()
for i, p in enumerate(params):
pa = _ParsedParam(i + 1)
pa.encoded_types.add(p)
if p in _NUMPY_INT_TYPE_CODES:
pa.int_char = p
Expand All @@ -198,8 +199,8 @@ def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufun
input_decl = re.sub("[()]", "", input_decl).split(",")
output_decl = re.sub("[()]", "", output_decl)

for p, d in zip(params, input_decl):
pa = _ParsedParamAnnotation()
for i, (p, d) in enumerate(zip(params, input_decl)):
pa = _ParsedParam(i + 1)
if d:
pa.encoded_types.add("[" + p)
pa.has_array = True
Expand All @@ -225,9 +226,10 @@ def _parse_np_ufunc_signature(fn: numpy.ufunc) -> _ParsedSignature:
# them in the future (https://github.com/deephaven/deephaven-core/issues/4762)
p_sig = _ParsedSignature(fn)
if fn.nin > 0:
pa = _ParsedParamAnnotation()
pa.encoded_types.add("O")
p_sig.params = [pa] * fn.nin
for i in range(fn.nin):
pa = _ParsedParam(i + 1)
pa.encoded_types.add("O")
p_sig.params.append(pa)
p_sig.ret_annotation = _ParsedReturnAnnotation()
p_sig.ret_annotation.encoded_type = "O"
return p_sig
Expand All @@ -249,7 +251,7 @@ def _parse_signature(fn: Callable) -> _ParsedSignature:
else:
sig = inspect.signature(fn)
for n, p in sig.parameters.items():
p_sig.params.append(_parse_param_annotation(p.annotation))
p_sig.params.append(_parse_param(n, p.annotation))

p_sig.ret_annotation = _parse_return_annotation(sig.return_annotation)
return p_sig
Expand All @@ -263,11 +265,11 @@ def _is_from_np_type(param_types: set[type], np_type_char: str) -> bool:
return False


def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
def _convert_arg(param: _ParsedParam, arg: Any) -> Any:
""" Convert a single argument to the type specified by the annotation """
if arg is None:
if not param.none_allowed:
raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.orig_types}")
else:
return None

Expand All @@ -277,12 +279,17 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
# if it matches one of the encoded types, convert it
if encoded_type in param.encoded_types:
dtype = dtypes.from_np_dtype(np_dtype)
return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False)
try:
return _j_array_to_numpy_array(dtype, arg, conv_null=True, type_promotion=False)
except Exception as e:
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
f" {param.encoded_types}"
f"\n{str(e)}") from e
# if the annotation is missing, or it is a generic object type, return the arg
elif "O" in param.encoded_types:
return arg
else:
raise TypeError(f"Argument {arg} is not compatible with annotation {param.encoded_types}")
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation {param.encoded_types}")
else: # if the arg is not a Java array
specific_types = param.encoded_types - {"N", "O"} # remove NoneType and object type
if specific_types:
Expand All @@ -300,7 +307,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
if param.none_allowed:
return None
else:
raise DHError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
raise DHError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
f" {param.orig_types}")
else:
# return a numpy integer instance only if the annotation is a numpy type
if _is_from_np_type(param.orig_types, param.int_char):
Expand Down Expand Up @@ -332,7 +340,8 @@ def _convert_arg(param: _ParsedParamAnnotation, arg: Any) -> Any:
if "O" in param.encoded_types:
return arg
else:
raise TypeError(f"Argument {arg} is not compatible with annotation {param.orig_types}")
raise TypeError(f"Argument {param.name!r}: {arg} is not compatible with annotation"
f" {param.orig_types}")
else: # if no annotation or generic object, return arg
return arg

Expand Down
14 changes: 13 additions & 1 deletion py/server/tests/test_numba_guvectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from numba import guvectorize, int64, int32

from deephaven import empty_table, dtypes
from deephaven import empty_table, dtypes, DHError
from tests.testbase import BaseTestCase

a = np.arange(5, dtype=np.int64)
Expand Down Expand Up @@ -89,6 +89,18 @@ def g(x, res):
t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y)")
self.assertEqual(t.columns[2].data_type, dtypes.long_array)

def test_type_mismatch_error(self):
    """Check that a type-mismatched call reports which argument was at fault.

    The guvectorize'd function below is declared for int64 arrays, while the
    grouped column passed to it is built from ``(double)ii``. The query is
    expected to raise DHError with "Argument 1" in its message.
    """
    # "(m)->()" layout: a 1-D input vector is reduced into a single output slot
    @guvectorize([(int64[:], int64[:])], "(m)->()", nopython=True)
    def g(x, res):
        res[0] = 0
        for elem in x:
            res[0] += elem

    with self.assertRaises(DHError) as ctx:
        empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)")

    # the message must identify the offending parameter (positional for numba functions)
    raised_message = str(ctx.exception)
    self.assertIn("Argument 1", raised_message)


if __name__ == '__main__':
unittest.main()
8 changes: 4 additions & 4 deletions py/server/tests/test_udf_numpy_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool:
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
with self.assertRaises(DHError) as cm:
t2 = t.update(["X1 = f3(null, Y )"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool:
return bool(len(p1)) if p1 is not None else False
Expand All @@ -352,7 +352,7 @@ def f1(p1: str, p2=None) -> bool:
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? `deephaven`: null"])
with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f1(Y)"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

def f11(p1: Union[str, None], p2=None) -> bool:
return p1 is None
Expand All @@ -366,7 +366,7 @@ def f2(p1: np.datetime64, p2=None) -> bool:
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? now() : null"])
with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f2(Y)"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

def f21(p1: Union[np.datetime64, None], p2=None) -> bool:
return p1 is None
Expand All @@ -380,7 +380,7 @@ def f3(p1: np.bool_, p2=None) -> bool:
t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : null"])
with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f3(Y)"])
self.assertRegex(str(cm.exception), "Argument None is not compatible with annotation")
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")

t = empty_table(10).update(["X = i % 3", "Y = i % 2 == 0? true : false"])
t1 = t.update(["X1 = f3(Y)"])
Expand Down
Loading