diff --git a/hamilton/function_modifiers/recursive.py b/hamilton/function_modifiers/recursive.py index e05c13469..7f9504e31 100644 --- a/hamilton/function_modifiers/recursive.py +++ b/hamilton/function_modifiers/recursive.py @@ -1,3 +1,4 @@ +import inspect import sys from types import ModuleType from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union @@ -329,6 +330,9 @@ def add_namespace( # Reassign sources for node_ in nodes: + # This is not perfect -- we might get strangeness if its dynamically generated + # that said, it should work + is_async = inspect.iscoroutinefunction(node_.callable) new_name = new_name_map[node_.name] kwarg_mapping = { (new_name_map[key] if key in new_name_map else key): key @@ -347,12 +351,23 @@ def fn( new_kwargs = {_kwarg_mapping[kwarg]: value for kwarg, value in kwargs.items()} return _callabl(**new_kwargs) + async def async_fn( + _callabl=node_.callable, + _kwarg_mapping=dict(kwarg_mapping), + _new_name=new_name, + _new_name_map=dict(new_name_map), + **kwargs, + ): + new_kwargs = {_kwarg_mapping[kwarg]: value for kwarg, value in kwargs.items()} + return await _callabl(**new_kwargs) + new_input_types = { dep: node_.input_types[original_dep] for dep, original_dep in kwarg_mapping.items() } + fn_to_use = async_fn if is_async else fn new_nodes.append( - node_.copy_with(input_types=new_input_types, name=new_name, callabl=fn) + node_.copy_with(input_types=new_input_types, name=new_name, callabl=fn_to_use) ) return new_nodes @@ -362,6 +377,7 @@ def add_final_node(self, fn: Callable, node_name: str, namespace: str): :param fn: :return: """ + is_async = inspect.iscoroutinefunction(fn) # determine if its async node_ = node.Node.from_fn(fn) namespaced_input_map = { (assign_namespace(key, namespace) if key not in self.external_inputs else key): key @@ -380,7 +396,12 @@ def new_function(**kwargs): # Have to translate it back to use the kwargs the fn is expecting return fn(**kwargs_without_namespace) - return node_.copy_with(name=node_name, input_types=new_input_types, callabl=new_function) + async def async_function(**kwargs): + return await new_function(**kwargs) + + fn_to_use = async_function if is_async else new_function + + return node_.copy_with(name=node_name, input_types=new_input_types, callabl=fn_to_use) def _derive_namespace(self, fn: Callable) -> str: """Utility function to derive a namespace from a function. diff --git a/hamilton/version.py b/hamilton/version.py index a0b0d57e3..9bb8b401a 100644 --- a/hamilton/version.py +++ b/hamilton/version.py @@ -1 +1 @@ -VERSION = (1, 61, 0) +VERSION = (1, 62, 0, "rc0")