Skip to content

Commit

Permalink
Add async support for pipe_family
Browse files Browse the repository at this point in the history
Enables running pipe_input, pipe_output and mutate with asyncio.
  • Loading branch information
jernejfrank committed Nov 12, 2024
1 parent 239d142 commit f7c7e04
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
11 changes: 9 additions & 2 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]:
def bind_function_args(
self, current_param: Optional[str]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Binds function arguments, given current, chained parameeter
"""Binds function arguments, given current, chained parameter
:param current_param: Current, chained parameter. None, if we're not chaining.
:return: A tuple of (upstream_inputs, literal_inputs)
Expand Down Expand Up @@ -1302,10 +1302,17 @@ def transform_node(
# We pick a reserved prefix that ovoids clashes with user defined functions / nodes
original_node = node_.copy_with(name=f"{node_.name}.raw")

is_async = inspect.iscoroutinefunction(fn) # determine if its async

def __identity(foo: Any) -> Any:
return foo

transforms = transforms + (step(__identity).named(fn.__name__),)
async def async_function(**kwargs):
return await __identity(**kwargs)

fn_to_use = async_function if is_async else __identity

transforms = transforms + (step(fn_to_use).named(fn.__name__),)
nodes, _ = chain_transforms(
target_arg=original_node.name,
transforms=transforms,
Expand Down
10 changes: 9 additions & 1 deletion hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,23 @@ def reassign_inputs(
if input_values is None:
input_values = {}

is_async = inspect.iscoroutinefunction(self.callable) # determine if its async

def new_callable(**kwargs) -> Any:
reverse_input_names = {v: k for k, v in input_names.items()}
kwargs = {**kwargs, **input_values}
return self.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()})

async def async_function(**kwargs):
return await new_callable(**kwargs)

fn_to_use = async_function if is_async else new_callable

new_input_types = {
input_names.get(k, k): v for k, v in self.input_types.items() if k not in input_values
}
out = self.copy_with(callabl=new_callable, input_types=new_input_types)
# out = self.copy_with(callabl=new_callable, input_types=new_input_types)
out = self.copy_with(callabl=fn_to_use, input_types=new_input_types)
return out

def transform_output(
Expand Down

0 comments on commit f7c7e04

Please sign in to comment.