Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async support for pipe_family #1223

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]:
def bind_function_args(
self, current_param: Optional[str]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Binds function arguments, given current, chained parameeter
"""Binds function arguments, given current, chained parameter
:param current_param: Current, chained parameter. None, if we're not chaining.
:return: A tuple of (upstream_inputs, literal_inputs)
Expand Down Expand Up @@ -1302,10 +1302,17 @@ def transform_node(
# We pick a reserved prefix that avoids clashes with user defined functions / nodes
original_node = node_.copy_with(name=f"{node_.name}.raw")

is_async = inspect.iscoroutinefunction(fn) # determine if its async
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic to check if a function is async using inspect.iscoroutinefunction is already present in multiple places in the codebase. Consider reusing existing implementations to avoid redundancy.

  • logic to determine if a function is async (expanders.py)
  • logic to determine if a function is async (recursive.py)
  • logic to determine if a function is async (base.py)
  • logic to determine if a function is async (node.py)


def __identity(foo: Any) -> Any:
return foo

transforms = transforms + (step(__identity).named(fn.__name__),)
async def async_function(**kwargs):
return await __identity(**kwargs)

fn_to_use = async_function if is_async else __identity

transforms = transforms + (step(fn_to_use).named(fn.__name__),)
nodes, _ = chain_transforms(
target_arg=original_node.name,
transforms=transforms,
Expand Down
10 changes: 9 additions & 1 deletion hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,23 @@ def reassign_inputs(
if input_values is None:
input_values = {}

is_async = inspect.iscoroutinefunction(self.callable) # determine if its async

def new_callable(**kwargs) -> Any:
reverse_input_names = {v: k for k, v in input_names.items()}
kwargs = {**kwargs, **input_values}
return self.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()})

async def async_function(**kwargs):
return await new_callable(**kwargs)

fn_to_use = async_function if is_async else new_callable

new_input_types = {
input_names.get(k, k): v for k, v in self.input_types.items() if k not in input_values
}
out = self.copy_with(callabl=new_callable, input_types=new_input_types)
# out = self.copy_with(callabl=new_callable, input_types=new_input_types)
out = self.copy_with(callabl=fn_to_use, input_types=new_input_types)
return out

def transform_output(
Expand Down
36 changes: 36 additions & 0 deletions tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,39 @@ def test_node_from_future_annotation_standard():
def test_node_from_future_annotation_collected():
    """A node built from the ``collected`` function must carry the COLLECT role."""
    collected_node = node.Node.from_fn(nodes_with_future_annotation.collected)
    assert collected_node.node_role == node.NodeType.COLLECT


def test_reassign_inputs():
    """Check that ``reassign_inputs`` renames input keys and keeps the node callable.

    Three remappings of the same two-argument node are exercised: only the
    first argument, only the second argument, and both at once.
    """

    def foo(a: int, b: str) -> int:
        return a + len(b)

    base_node = Node.from_fn(foo)

    # Build every remapped variant up front; each call is independent of the others.
    renamed_first = base_node.reassign_inputs(input_names={"a": "c"})
    renamed_second = base_node.reassign_inputs(input_names={"b": "d"})
    renamed_both = base_node.reassign_inputs(input_names={"a": "c", "b": "d"})

    # Names that were not remapped stay as they were; positions are unchanged.
    assert list(renamed_first.input_types)[:2] == ["c", "b"]
    assert list(renamed_second.input_types)[:2] == ["a", "d"]
    assert list(renamed_both.input_types)[:2] == ["c", "d"]

    # The remapped node is invoked with the new names: 2 + len("abc") == 5.
    assert renamed_both(c=2, d="abc") == 5


@pytest.mark.asyncio
async def test_subdag_async():
    """Check that remapping the inputs of an async node keeps it a coroutine function.

    The remapped node must expose the new input names, still be detected by
    ``inspect.iscoroutinefunction``, and produce the expected value when awaited.
    """

    async def foo(a: int, b: str) -> int:
        return a + len(b)

    remapped = Node.from_fn(foo).reassign_inputs(input_names={"a": "c", "b": "d"})

    assert list(remapped.input_types)[:2] == ["c", "d"]
    assert inspect.iscoroutinefunction(remapped.callable)

    # Awaiting the call with the renamed arguments: 2 + len("abc") == 5.
    outcome = await remapped(c=2, d="abc")
    assert outcome == 5