-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #360 from golemfactory/az/async-task-generator
Asynchronous task iterator as argument to `Executor.submit()`
- Loading branch information
Showing
7 changed files
with
250 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from typing import AsyncIterable | ||
|
||
import pytest | ||
|
||
from yapapi.executor._smartq import PeekableAsyncIterator | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize( | ||
"input", | ||
[ | ||
[], | ||
[1], | ||
[1, 2], | ||
[1, 2, 3], | ||
[1, 2, 3, 4], | ||
], | ||
) | ||
async def test_iterator(input): | ||
async def iterator(): | ||
for item in input: | ||
yield item | ||
|
||
it = PeekableAsyncIterator(iterator()) | ||
assert (await it.has_next()) == bool(input) | ||
|
||
output = [] | ||
|
||
async for item in it: | ||
output.append(item) | ||
assert (await it.has_next()) == (len(output) < len(input)) | ||
|
||
assert not await it.has_next() | ||
assert input == output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#!/usr/bin/env python3 | ||
"""A requestor script for testing asynchronous generation of input tasks.""" | ||
import asyncio | ||
from datetime import timedelta | ||
import pathlib | ||
import sys | ||
from typing import AsyncGenerator | ||
|
||
from yapapi import Executor, Task | ||
from yapapi.log import enable_default_logger, log_event_repr | ||
from yapapi.package import vm | ||
|
||
|
||
async def main(): | ||
|
||
vm_package = await vm.repo( | ||
image_hash="9a3b5d67b0b27746283cb5f287c13eab1beaa12d92a9f536b747c7ae", | ||
min_mem_gib=0.5, | ||
min_storage_gib=2.0, | ||
) | ||
|
||
async def worker(work_ctx, tasks): | ||
async for task in tasks: | ||
print("task data:", task.data, file=sys.stderr) | ||
work_ctx.run("/bin/sleep", "1") | ||
yield work_ctx.commit() | ||
task.accept_result(result=task.data) | ||
|
||
async with Executor( | ||
budget=10.0, | ||
package=vm_package, | ||
max_workers=1, | ||
subnet_tag="goth", | ||
timeout=timedelta(minutes=6), | ||
event_consumer=log_event_repr, | ||
) as executor: | ||
|
||
# We use an async task generator that yields tasks removed from | ||
# an async queue. Each computed task will potentially spawn | ||
# new tasks -- this is made possible thanks to using async task | ||
# generator as an input to `executor.submit()`. | ||
|
||
task_queue = asyncio.Queue() | ||
|
||
# Seed the queue with the first task: | ||
await task_queue.put(Task(data=3)) | ||
|
||
async def input_generator(): | ||
"""Task generator yields tasks removed from `queue`.""" | ||
while True: | ||
task = await task_queue.get() | ||
if task.data == 0: | ||
break | ||
yield task | ||
|
||
async for task in executor.submit(worker, input_generator()): | ||
print("task result:", task.result, file=sys.stderr) | ||
for n in range(task.result): | ||
await task_queue.put(Task(data=task.result - 1)) | ||
|
||
print("all done!", file=sys.stderr) | ||
|
||
|
||
if __name__ == "__main__": | ||
test_dir = pathlib.Path(__file__).parent.name | ||
enable_default_logger(log_file=f"{test_dir}.log") | ||
asyncio.run(main()) |
50 changes: 50 additions & 0 deletions
50
tests/goth/test_async_task_generation/test_async_task_generation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
import goth.configuration | ||
from goth.runner import Runner | ||
from goth.runner.log import configure_logging | ||
from goth.runner.probe import RequestorProbe | ||
|
||
|
||
logger = logging.getLogger("goth.test.async_task_generation") | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_async_task_generation(project_dir: Path, log_dir: Path, config_overrides) -> None: | ||
"""Run the `requestor.py` and make sure that it's standard output is as expected.""" | ||
|
||
configure_logging(log_dir) | ||
|
||
# Override the default test configuration to create only one provider node | ||
nodes = [ | ||
{"name": "requestor", "type": "Requestor"}, | ||
{"name": "provider-1", "type": "VM-Wasm-Provider", "use-proxy": True}, | ||
] | ||
config_overrides.append(("nodes", nodes)) | ||
goth_config = goth.configuration.load_yaml( | ||
project_dir / "tests" / "goth" / "assets" / "goth-config.yml", | ||
config_overrides, | ||
) | ||
|
||
runner = Runner(base_log_dir=log_dir, compose_config=goth_config.compose_config) | ||
|
||
async with runner(goth_config.containers): | ||
|
||
requestor = runner.get_probes(probe_type=RequestorProbe)[0] | ||
|
||
async with requestor.run_command_on_host( | ||
str(Path(__file__).parent / "requestor.py"), env=os.environ | ||
) as (_cmd_task, cmd_monitor): | ||
# The requestor should print "task result: 3" once ... | ||
await cmd_monitor.wait_for_pattern("task result: 3", timeout=60) | ||
# ... then "task result: 2" twice ... | ||
for _ in range(3): | ||
await cmd_monitor.wait_for_pattern("task result: 2", timeout=10) | ||
# ... and "task result: 1" six times. | ||
for _ in range(6): | ||
await cmd_monitor.wait_for_pattern("task result: 1", timeout=10) | ||
await cmd_monitor.wait_for_pattern("all done!", timeout=10) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.