Skip to content

Commit

Permalink
Allows subdag to be async
Browse files Browse the repository at this point in the history
This just adds another path for the two functions we redefine. We need
to add tests but this works on a manual test for now.

See #903
  • Loading branch information
elijahbenizzy committed May 14, 2024
1 parent 4364d26 commit 86a8350
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
25 changes: 23 additions & 2 deletions hamilton/function_modifiers/recursive.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import sys
from types import ModuleType
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union
Expand Down Expand Up @@ -329,6 +330,9 @@ def add_namespace(

# Reassign sources
for node_ in nodes:
# This is not perfect -- we might get strangeness if its dynamically generated
# that said, it should work
is_async = inspect.iscoroutinefunction(node_.callable)
new_name = new_name_map[node_.name]
kwarg_mapping = {
(new_name_map[key] if key in new_name_map else key): key
Expand All @@ -347,12 +351,23 @@ def fn(
new_kwargs = {_kwarg_mapping[kwarg]: value for kwarg, value in kwargs.items()}
return _callabl(**new_kwargs)

async def async_fn(
_callabl=node_.callable,
_kwarg_mapping=dict(kwarg_mapping),
_new_name=new_name,
_new_name_map=dict(new_name_map),
**kwargs,
):
new_kwargs = {_kwarg_mapping[kwarg]: value for kwarg, value in kwargs.items()}
return await _callabl(**new_kwargs)

new_input_types = {
dep: node_.input_types[original_dep] for dep, original_dep in kwarg_mapping.items()
}
fn_to_use = async_fn if is_async else fn

new_nodes.append(
node_.copy_with(input_types=new_input_types, name=new_name, callabl=fn)
node_.copy_with(input_types=new_input_types, name=new_name, callabl=fn_to_use)
)
return new_nodes

Expand All @@ -362,6 +377,7 @@ def add_final_node(self, fn: Callable, node_name: str, namespace: str):
:param fn:
:return:
"""
is_async = inspect.iscoroutinefunction(fn) # determine if its async
node_ = node.Node.from_fn(fn)
namespaced_input_map = {
(assign_namespace(key, namespace) if key not in self.external_inputs else key): key
Expand All @@ -380,7 +396,12 @@ def new_function(**kwargs):
# Have to translate it back to use the kwargs the fn is expecting
return fn(**kwargs_without_namespace)

return node_.copy_with(name=node_name, input_types=new_input_types, callabl=new_function)
async def async_function(**kwargs):
return await new_function(**kwargs)

fn_to_use = async_function if is_async else new_function

return node_.copy_with(name=node_name, input_types=new_input_types, callabl=fn_to_use)

def _derive_namespace(self, fn: Callable) -> str:
"""Utility function to derive a namespace from a function.
Expand Down
2 changes: 1 addition & 1 deletion hamilton/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = (1, 61, 0)
VERSION = (1, 62, 0, "rc0")

0 comments on commit 86a8350

Please sign in to comment.