Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove branching in UDF wrappers via code gen #5487

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 93 additions & 75 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,6 @@ def _parse_signature(fn: Callable) -> _ParsedSignature:
p_sig.ret_annotation = _parse_return_annotation(t)
return p_sig


def _udf_parser(fn: Callable):
"""A decorator that acts as a transparent translator for Python UDFs used in Deephaven query formulas between
Python and Java. This decorator is intended for internal use by the Deephaven query engine and should not be used by
Expand All @@ -528,90 +527,31 @@ def _udf_parser(fn: Callable):
ret_dtype = dtypes.from_np_dtype(np.dtype(ret_np_char if ret_np_char != "X" else "O"))

@wraps(fn)
def _udf_decorator(encoded_arg_types: str, for_vectorization: bool):
def _udf_decorator(encoded_arg_types: str, for_vectorization: bool) -> Callable:
"""The actual decorator that wraps the Python UDF and converts the arguments and return values.
It is called by the query engine with the runtime argument types to create a wrapper that can efficiently
convert the arguments and return values based on the provided argument types and the parsed parameters of the
UDF.
"""
arg_conv_needed = p_sig.prepare_auto_arg_conv(encoded_arg_types)
p_sig.ret_annotation.setup_return_converter()
ret_converter = p_sig.ret_annotation.ret_converter
nonlocal ret_dtype # used in converting array-type return values, bring it into the local scope before exec()

if not for_vectorization:
if not arg_conv_needed and p_sig.ret_annotation.encoded_type == "O":
return fn

def _wrapper(*args, **kwargs):
if arg_conv_needed:
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
for param, arg in zip(p_sig.params, args)]

# if the number of arguments is more than the number of parameters, treat the last parameter as a
# vararg and use its arg_converter to convert the rest of the arguments
if len(args) > len(p_sig.params):
arg_converter = p_sig.params[-1].arg_converter
converted_args.extend([arg_converter(arg) if arg_converter else arg
for arg in args[len(converted_args):]])
else:
converted_args = args
# kwargs are not converted because they are not used in the UDFs
ret = fn(*converted_args, **kwargs)
if return_array:
return dtypes.array(ret_dtype, ret)
else:
return p_sig.ret_annotation.ret_converter(ret) if p_sig.ret_annotation.ret_converter else ret

return _wrapper
else: # for vectorization
def _vectorization_wrapper(*args):
if len(args) != len(p_sig.params) + 2:
raise ValueError(
f"The number of arguments doesn't match the function ({p_sig.fn.__name__}) signature. "
f"{len(args) - 2}, {p_sig.encoded}")
if args[0] <= 0:
raise ValueError(
f"The chunk size argument must be a positive integer for vectorized function ("
f"{p_sig.fn.__name__}). {args[0]}")

chunk_size = args[0]
chunk_result = args[1]
if args[2:]:
vectorized_args = zip(*args[2:])
for i in range(chunk_size):
scalar_args = next(vectorized_args)
if arg_conv_needed:
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
for param, arg in zip(p_sig.params, scalar_args)]

# if the number of arguments is more than the number of parameters, treat the last parameter
# as a vararg and use its arg_converter to convert the rest of the arguments
if len(args) > len(p_sig.params):
arg_converter = p_sig.params[-1].arg_converter
converted_args.extend([arg_converter(arg) if arg_converter else arg
for arg in scalar_args[len(converted_args):]])
else:
converted_args = scalar_args
if not for_vectorization and not arg_conv_needed and not ret_converter and not return_array:
# no wrapper needed
return fn

ret = fn(*converted_args)
if return_array:
chunk_result[i] = dtypes.array(ret_dtype, ret)
else:
chunk_result[i] = p_sig.ret_annotation.ret_converter(
ret) if p_sig.ret_annotation.ret_converter else ret
else:
for i in range(chunk_size):
ret = fn()
if return_array:
chunk_result[i] = dtypes.array(ret_dtype, ret)
else:
chunk_result[i] = p_sig.ret_annotation.ret_converter(
ret) if p_sig.ret_annotation.ret_converter else ret
return chunk_result
_wrapper_str = _gen_wrapper_code(p_sig, for_vectorization, arg_conv_needed, return_array)
scope = {**globals(), **locals()}
exec(_wrapper_str, scope)
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved

if for_vectorization and test_vectorization:
global vectorized_count
vectorized_count += 1

return scope["_wrapper"]

if test_vectorization:
global vectorized_count
vectorized_count += 1
return _vectorization_wrapper

_udf_decorator.j_name = ret_dtype.j_name
real_ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(ret_dtype, dtypes.PyObject) if return_array else ret_dtype
Expand All @@ -625,3 +565,81 @@ def _vectorization_wrapper(*args):
_udf_decorator.signature = p_sig.encoded

return _udf_decorator

# region Wrapper Code Generation
# The constants below are Python source fragments. _gen_wrapper_code() stitches them together into the
# text of a function named `_wrapper`, which _udf_decorator then exec()s in a scope built from globals()
# plus its own locals. The free names used inside the fragments (p_sig, fn, ret_converter, ret_dtype,
# dtypes) are resolved from that scope when the generated wrapper runs.
#
# Every fragment is written at column 0; _gen_wrapper_code() indents it to the required depth. A fragment
# that opens with a newline therefore produces a leading whitespace-only line in the generated source.
INDENT_STR = " " * 4

# for non-vectorize-able UDFs
WRAPPER_HEADER = """def _wrapper(*args):"""
# Convert each positional argument with the converter of the matching parameter, then call the UDF.
ARG_CONV = """
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
                  for param, arg in zip(p_sig.params, args)]

# if the number of arguments is more than the number of parameters, treat the last parameter as a
# vararg and use its arg_converter to convert the rest of the arguments
if len(args) > len(p_sig.params):
    arg_converter = p_sig.params[-1].arg_converter
    converted_args.extend([arg_converter(arg) if arg_converter else arg for arg in args[len(converted_args):]])
ret = fn(*converted_args)"""
# Call the UDF with the arguments untouched.
NO_ARG_CONV = """ret = fn(*args)"""
ARRAY_RET = """return dtypes.array(ret_dtype, ret)"""
SCALAR_RET = """return ret_converter(ret) if ret_converter else ret"""

# for vectorize-able UDFs
# args[0] is the chunk size, args[1] is the buffer that receives one result per row, and args[2:] are
# the per-parameter argument chunks.
V_WRAPPER_HEADER = """
def _wrapper(*args):
    chunk_size = args[0]
    chunk_result = args[1]"""
V_ARRAY_RET = """chunk_result[i] = dtypes.array(ret_dtype, ret)"""
V_SCALAR_RET = """chunk_result[i] = ret_converter(ret) if ret_converter else ret"""
# Turn the column-wise chunks into an iterator of per-row argument tuples.
V_ZIP_CHUNK_ARGS = """vectorized_args = zip(*args[2:])"""
# Per-row variant of ARG_CONV. The vararg check compares the row's own argument count (scalar_args)
# against the parameter count; `args` also holds the chunk size and the result buffer, so testing
# len(args) would make the branch run on every row only to extend by an empty list.
V_ARG_CONV = """
scalar_args = next(vectorized_args)
converted_args = [param.arg_converter(arg) if param.arg_converter else arg
                  for param, arg in zip(p_sig.params, scalar_args)]
# if the number of arguments is more than the number of parameters, treat the last
# parameter as a vararg and use its arg_converter to convert the rest of the arguments
if len(scalar_args) > len(p_sig.params):
    arg_converter = p_sig.params[-1].arg_converter
    converted_args.extend([arg_converter(arg) if arg_converter else arg
                          for arg in scalar_args[len(converted_args):]])
ret = fn(*converted_args)"""
V_NO_ARG_CONV = """ret = fn(*next(vectorized_args))"""
V_NO_ARG = """ret = fn()"""
V_LOOP_CHUNK = """for i in range(chunk_size):"""
V_RET_CHUNK = "return chunk_result"


def _gen_wrapper_code(p_sig: _ParsedSignature, for_vectorization: bool, arg_conv_needed: bool, return_array: bool) -> str:
    """Assemble the source text of the `_wrapper` function for a UDF.

    The text is built from the module-level code fragments (WRAPPER_HEADER, ARG_CONV, V_LOOP_CHUNK, ...),
    selecting only the fragments that apply so the generated wrapper contains no per-call branching on
    these flags.

    Args:
        p_sig: the parsed signature of the UDF; only the number of its params is consulted here, to pick
            the zero-argument form of the vectorized wrapper
        for_vectorization: True to generate the chunk-oriented wrapper, False for the per-call wrapper
        arg_conv_needed: True to include the argument conversion fragment before the UDF call
        return_array: True to emit the array-return fragment, False for the scalar-return fragment

    Returns:
        the Python source code that defines `_wrapper`
    """

    def _indented(snippet: str, depth: int) -> str:
        # prefix every line of a (possibly multi-line) fragment with `depth` levels of indentation
        pad = INDENT_STR * depth
        return pad + snippet.replace("\n", "\n" + pad)

    if not for_vectorization:
        call_snippet = ARG_CONV if arg_conv_needed else NO_ARG_CONV
        result_snippet = ARRAY_RET if return_array else SCALAR_RET
        # the two trailing empty entries yield the two newlines that terminate the generated source
        return "\n".join([WRAPPER_HEADER, _indented(call_snippet, 1), _indented(result_snippet, 1), "", ""])

    level_one = INDENT_STR
    level_two = INDENT_STR * 2
    store_snippet = V_ARRAY_RET if return_array else V_SCALAR_RET

    if len(p_sig.params) > 0:
        call_snippet = V_ARG_CONV if arg_conv_needed else V_NO_ARG_CONV
        pieces = [
            V_WRAPPER_HEADER,
            level_one + V_ZIP_CHUNK_ARGS,
            level_one + V_LOOP_CHUNK,
            _indented(call_snippet, 2),
            level_two + store_snippet,
            "",  # the loop body ends with a newline in this form
        ]
    else:
        # nothing to zip or convert: call the UDF with no arguments once per row
        pieces = [
            V_WRAPPER_HEADER,
            level_one + V_LOOP_CHUNK,
            level_two + V_NO_ARG,
            level_two + store_snippet,
        ]

    return "\n".join(pieces) + "\n" + level_one + V_RET_CHUNK + "\n"
# endregion

Loading