Skip to content

Commit

Permalink
Add async support for pipe_family
Browse files Browse the repository at this point in the history
Enables running pipe_input, pipe_output and mutate with asyncio.
  • Loading branch information
jernejfrank committed Nov 13, 2024
1 parent 18a12c0 commit ae8841b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
11 changes: 9 additions & 2 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]:
def bind_function_args(
self, current_param: Optional[str]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Binds function arguments, given current, chained parameeter
"""Binds function arguments, given current, chained parameter
:param current_param: Current, chained parameter. None, if we're not chaining.
:return: A tuple of (upstream_inputs, literal_inputs)
Expand Down Expand Up @@ -1302,10 +1302,17 @@ def transform_node(
# We pick a reserved prefix that ovoids clashes with user defined functions / nodes
original_node = node_.copy_with(name=f"{node_.name}.raw")

is_async = inspect.iscoroutinefunction(fn) # determine if its async

def __identity(foo: Any) -> Any:
return foo

transforms = transforms + (step(__identity).named(fn.__name__),)
async def async_function(**kwargs):
return await __identity(**kwargs)

fn_to_use = async_function if is_async else __identity

transforms = transforms + (step(fn_to_use).named(fn.__name__),)
nodes, _ = chain_transforms(
target_arg=original_node.name,
transforms=transforms,
Expand Down
10 changes: 9 additions & 1 deletion hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,23 @@ def reassign_inputs(
if input_values is None:
input_values = {}

is_async = inspect.iscoroutinefunction(self.callable) # determine if its async

def new_callable(**kwargs) -> Any:
reverse_input_names = {v: k for k, v in input_names.items()}
kwargs = {**kwargs, **input_values}
return self.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()})

async def async_function(**kwargs):
return await new_callable(**kwargs)

fn_to_use = async_function if is_async else new_callable

new_input_types = {
input_names.get(k, k): v for k, v in self.input_types.items() if k not in input_values
}
out = self.copy_with(callabl=new_callable, input_types=new_input_types)
# out = self.copy_with(callabl=new_callable, input_types=new_input_types)
out = self.copy_with(callabl=fn_to_use, input_types=new_input_types)
return out

def transform_output(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,39 @@ def test_node_from_future_annotation_standard():
def test_node_from_future_annotation_collected():
collected = nodes_with_future_annotation.collected
assert node.Node.from_fn(collected).node_role == node.NodeType.COLLECT


def test_reassign_inputs():
def foo(a: int, b: str) -> int:
return a + len(b)

node_ = Node.from_fn(foo)

first_arg_node = node_.reassign_inputs(input_names={"a": "c"})
new_first_arg = list(first_arg_node.input_types.keys())
second_arg_node = node_.reassign_inputs(input_names={"b": "d"})
new_second_arg = list(second_arg_node.input_types.keys())
both_arg_node = node_.reassign_inputs(input_names={"a": "c", "b": "d"})
new_both_arg = list(both_arg_node.input_types.keys())
assert new_first_arg[0] == "c"
assert new_first_arg[1] == "b"
assert new_second_arg[0] == "a"
assert new_second_arg[1] == "d"
assert new_both_arg[0] == "c"
assert new_both_arg[1] == "d"
assert both_arg_node(**{"c": 2, "d": "abc"}) == 5


@pytest.mark.asyncio
async def test_subdag_async():
async def foo(a: int, b: str) -> int:
return a + len(b)

node_ = Node.from_fn(foo)

new_node = node_.reassign_inputs(input_names={"a": "c", "b": "d"})
new_args = list(new_node.input_types.keys())
assert new_args[0] == "c"
assert new_args[1] == "d"
assert inspect.iscoroutinefunction(new_node.callable)
assert await new_node(**{"c": 2, "d": "abc"}) == 5

0 comments on commit ae8841b

Please sign in to comment.