Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add allow_module_overrides to AsyncDriver #1217

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion hamilton/async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def __init__(
*modules,
result_builder: Optional[base.ResultMixin] = None,
adapters: typing.List[lifecycle.LifecycleAdapter] = None,
allow_module_overrides: bool = False,
):
"""Instantiates an asynchronous driver.

Expand All @@ -210,8 +211,12 @@ def __init__(

:param config: Config to build the graph
:param modules: Modules to crawl for fns/graph nodes
:param adapters: Adapters to use for lifecycle methods.
:param result_builder: Results mixin to compile the graph's final results. TBD whether this should be included in the long run.
:param adapters: Adapters to use for lifecycle methods.
:param allow_module_overrides: Optional. Same named functions get overridden by later modules.
The order of listing the modules is important, since later ones will overwrite the previous ones.
This is a global call affecting all imported modules.
See https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/module_overrides for more info.
"""
if adapters is None:
adapters = []
Expand Down Expand Up @@ -243,6 +248,7 @@ def __init__(
*sync_adapters,
*async_adapters, # note async adapters will not be called during synchronous execution -- this is for access later
],
allow_module_overrides=allow_module_overrides,
)
self.initialized = False

Expand Down Expand Up @@ -492,6 +498,7 @@ def build_without_init(self) -> AsyncDriver:
*self.modules,
adapters=self.adapters,
result_builder=specified_result_builder,
allow_module_overrides=self._allow_module_overrides,
)

async def build(self):
Expand Down
38 changes: 38 additions & 0 deletions tests/test_async_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,41 @@ async def post_graph_construct(self, **kwargs):
assert hook.ran

# builder.result_builder = result_builder


@pytest.mark.asyncio
async def test_async_builder_allow_module_overrides():
def foo() -> int:
return 1

mod1 = ad_hoc_utils.create_temporary_module(foo)

def foo() -> int:
return 2

mod2 = ad_hoc_utils.create_temporary_module(foo)

# Should raise without .allow_module_overrides()
with pytest.raises(ValueError) as e:
await async_driver.Builder().with_modules(mod1, mod2).build()

assert "Cannot define function foo more than once." in str(e.value)

# build_without_init should also raise
with pytest.raises(ValueError) as e:
async_driver.Builder().with_modules(mod1, mod2).build_without_init()

assert "Cannot define function foo more than once." in str(e.value)

# Should not raise with .allow_module_overrides()
dr = await async_driver.Builder().with_modules(mod1, mod2).allow_module_overrides().build()
assert (await dr.execute(final_vars=["foo"])) == {"foo": 2}

# Same with build_without_init
dr = (
async_driver.Builder()
.with_modules(mod1, mod2)
.allow_module_overrides()
.build_without_init()
)
assert (await dr.execute(final_vars=["foo"])) == {"foo": 2}