From ae8841b789ef2dc1f9ba1d286a03be3baefffce5 Mon Sep 17 00:00:00 2001 From: jernejfrank Date: Tue, 12 Nov 2024 20:33:44 +0800 Subject: [PATCH] Add async support for pipe_family Enables running pipe_input, pipe_output and mutate with asyncio. --- hamilton/function_modifiers/macros.py | 11 ++++++-- hamilton/node.py | 10 +++++++- tests/test_node.py | 36 +++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 3 deletions(-) diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index 7cc074985..dfdb28b53 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -608,7 +608,7 @@ def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]: def bind_function_args( self, current_param: Optional[str] ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """Binds function arguments, given current, chained parameeter + """Binds function arguments, given current, chained parameter :param current_param: Current, chained parameter. None, if we're not chaining. :return: A tuple of (upstream_inputs, literal_inputs) @@ -1302,10 +1302,17 @@ def transform_node( # We pick a reserved prefix that ovoids clashes with user defined functions / nodes original_node = node_.copy_with(name=f"{node_.name}.raw") + is_async = inspect.iscoroutinefunction(fn) # determine if its async + def __identity(foo: Any) -> Any: return foo - transforms = transforms + (step(__identity).named(fn.__name__),) + async def async_function(**kwargs): + return await __identity(**kwargs) + + fn_to_use = async_function if is_async else __identity + + transforms = transforms + (step(fn_to_use).named(fn.__name__),) nodes, _ = chain_transforms( target_arg=original_node.name, transforms=transforms, diff --git a/hamilton/node.py b/hamilton/node.py index 41062a7ae..5755b0ed7 100644 --- a/hamilton/node.py +++ b/hamilton/node.py @@ -360,15 +360,23 @@ def reassign_inputs( if input_values is None: input_values = {} + is_async = inspect.iscoroutinefunction(self.callable) # determine if its async + def new_callable(**kwargs) -> Any: reverse_input_names = {v: k for k, v in input_names.items()} kwargs = {**kwargs, **input_values} return self.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()}) + async def async_function(**kwargs): + return await new_callable(**kwargs) + + fn_to_use = async_function if is_async else new_callable + new_input_types = { input_names.get(k, k): v for k, v in self.input_types.items() if k not in input_values } - out = self.copy_with(callabl=new_callable, input_types=new_input_types) + # out = self.copy_with(callabl=new_callable, input_types=new_input_types) + out = self.copy_with(callabl=fn_to_use, input_types=new_input_types) return out def transform_output( diff --git a/tests/test_node.py b/tests/test_node.py index 7f2909943..085415fbb 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -137,3 +137,39 @@ def test_node_from_future_annotation_standard(): def test_node_from_future_annotation_collected(): collected = nodes_with_future_annotation.collected assert node.Node.from_fn(collected).node_role == node.NodeType.COLLECT + + +def test_reassign_inputs(): + def foo(a: int, b: str) -> int: + return a + len(b) + + node_ = Node.from_fn(foo) + + first_arg_node = node_.reassign_inputs(input_names={"a": "c"}) + new_first_arg = list(first_arg_node.input_types.keys()) + second_arg_node = node_.reassign_inputs(input_names={"b": "d"}) + new_second_arg = list(second_arg_node.input_types.keys()) + both_arg_node = node_.reassign_inputs(input_names={"a": "c", "b": "d"}) + new_both_arg = list(both_arg_node.input_types.keys()) + assert new_first_arg[0] == "c" + assert new_first_arg[1] == "b" + assert new_second_arg[0] == "a" + assert new_second_arg[1] == "d" + assert new_both_arg[0] == "c" + assert new_both_arg[1] == "d" + assert both_arg_node(**{"c": 2, "d": "abc"}) == 5 + + +@pytest.mark.asyncio +async def test_subdag_async(): + async def foo(a: int, b: str) -> int: + return a + len(b) + + node_ = Node.from_fn(foo) + + new_node = node_.reassign_inputs(input_names={"a": "c", "b": "d"}) + new_args = list(new_node.input_types.keys()) + assert new_args[0] == "c" + assert new_args[1] == "d" + assert inspect.iscoroutinefunction(new_node.callable) + assert await new_node(**{"c": 2, "d": "abc"}) == 5