This repository has been archived by the owner on Jul 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds tests for async graphadapter + Driver
- Loading branch information
1 parent
0f8ddeb
commit ca4b36f
Showing
6 changed files
with
108 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pytest-asyncio |
Empty file.
28 changes: 28 additions & 0 deletions
28
graph_adapter_tests/h_async/resources/simple_async_module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import asyncio | ||
from typing import Dict | ||
|
||
import hamilton.function_modifiers | ||
|
||
|
||
async def simple_async_func(external_input: int) -> int: | ||
await asyncio.sleep(.01) | ||
return external_input + 1 | ||
|
||
|
||
async def async_func_with_param(simple_async_func: int, external_input: int) -> int: | ||
await asyncio.sleep(.01) | ||
return simple_async_func + external_input + 1 | ||
|
||
|
||
def simple_non_async_func(simple_async_func: int, async_func_with_param: int) -> int: | ||
return simple_async_func + async_func_with_param + 1 | ||
|
||
|
||
async def another_async_func(simple_non_async_func: int) -> int: | ||
await asyncio.sleep(.01) | ||
return simple_non_async_func + 1 | ||
|
||
|
||
@hamilton.function_modifiers.extract_fields(dict(result_1=int, result_2=int)) | ||
def non_async_func_with_decorator(async_func_with_param: int, another_async_func: int) -> Dict[str, int]: | ||
return {'result_1': another_async_func + 1, 'result_2': async_func_with_param + 1} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import asyncio | ||
import pdb | ||
|
||
import pytest | ||
|
||
from hamilton.experimental import h_async | ||
from .resources import simple_async_module | ||
|
||
|
||
async def async_identity(n: int) -> int: | ||
await asyncio.sleep(.01) | ||
return n | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_await_dict_of_coroutines(): | ||
tasks = {n: async_identity(n) for n in range(0, 10)} | ||
results = await h_async.await_dict_of_tasks(tasks) | ||
assert results == {n: await async_identity(n) for n in range(0, 10)} | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_await_dict_of_tasks(): | ||
tasks = {n: asyncio.create_task(async_identity(n)) for n in range(0, 10)} | ||
results = await h_async.await_dict_of_tasks(tasks) | ||
assert results == {n: await async_identity(n) for n in range(0, 10)} | ||
|
||
|
||
# The following are not parameterized as we need to use the event loop -- fixtures will complicate this | ||
@pytest.mark.asyncio | ||
async def test_process_value_raw(): | ||
assert await h_async.process_value(1) == 1 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_process_value_coroutine(): | ||
assert await h_async.process_value(async_identity(1)) == 1 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_process_value_task(): | ||
assert await h_async.process_value(asyncio.create_task(async_identity(1))) == 1 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_driver_end_to_end(): | ||
dr = h_async.AsyncDriver({}, simple_async_module) | ||
all_vars = [var.name for var in dr.list_available_variables()] | ||
result = await dr.raw_execute(final_vars=all_vars, inputs={'external_input': 1}) | ||
assert result == {'another_async_func': 8, | ||
'async_func_with_param': 4, | ||
'external_input': 1, | ||
'non_async_func_with_decorator': {'result_1': 9, 'result_2': 5}, | ||
'result_1': 9, | ||
'result_2': 5, | ||
'simple_async_func': 2, | ||
'simple_non_async_func': 7} |