Skip to content

Commit

Permalink
Support use of DType instance in UDF annotation (#5200)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver authored and stanbrub committed Feb 28, 2024
1 parent 4698b9f commit 6e48391
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 0 deletions.
10 changes: 10 additions & 0 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def _parse_type_no_nested(annotation: Any, p_param: _ParsedParam, t: Union[type,
t = eval(t) if isinstance(t, str) else t

p_param.orig_types.add(t)

# if the annotation is a DH DType instance, we'll use its numpy type
if isinstance(t, dtypes.DType):
t = t.np_type

tc = _encode_param_type(t)
if "[" in tc:
p_param.has_array = True
Expand Down Expand Up @@ -157,6 +162,11 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation:
t = annotation.__args__[0]
elif annotation.__args__[0] == type(None): # noqa: E721
t = annotation.__args__[1]

# if the annotation is a DH DType instance, we'll use its numpy type
if isinstance(t, dtypes.DType):
t = t.np_type

component_char = _component_np_dtype_char(t)
if component_char:
pra.encoded_type = "[" + component_char
Expand Down
3 changes: 3 additions & 0 deletions py/server/deephaven/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,9 @@ def _component_np_dtype_char(t: type) -> Optional[str]:
component_type = None
if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence):
component_type = t.__args__[0]
# if the component type is a DType, get its numpy type
if isinstance(component_type, DType):
component_type = component_type.np_type

if not component_type:
component_type = _np_ndarray_component_type(t)
Expand Down
10 changes: 10 additions & 0 deletions py/server/tests/test_udf_return_java_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ def nbsin(x):
t3 = empty_table(10).update(["X3 = nbsin(i)"])
self.assertEqual(t3.columns[0].data_type, dtypes.double)

def test_java_instant_return(self):
from deephaven.time import to_j_instant

t = empty_table(10).update(["X1 = to_j_instant(`2021-01-01T00:00:00Z`)"])
self.assertEqual(t.columns[0].data_type, dtypes.Instant)

def udf() -> List[dtypes.Instant]:
return [to_j_instant("2021-01-01T00:00:00Z")]
t = empty_table(10).update(["X1 = udf()"])
self.assertEqual(t.columns[0].data_type, dtypes.instant_array)

if __name__ == '__main__':
unittest.main()

0 comments on commit 6e48391

Please sign in to comment.