Skip to content

Commit

Permalink
Add integration test for using AsyncGenerator for tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
azawlocki committed May 12, 2021
1 parent d52f8ee commit 3954819
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 114 deletions.
6 changes: 1 addition & 5 deletions examples/blender/blender.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,7 @@ async def worker(ctx: WorkContext, tasks):
num_tasks = 0
start_time = datetime.now()

async def task_iterator():
for frame in frames:
yield Task(data=frame)

async for task in executor.submit(worker, task_iterator()):
async for task in executor.submit(worker, [Task(data=frame) for frame in frames]):
num_tasks += 1
print(
f"{TEXT_COLOR_CYAN}"
Expand Down
69 changes: 0 additions & 69 deletions examples/interactive-requestor.py

This file was deleted.

37 changes: 0 additions & 37 deletions tests/executor/test_peekable_async_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,40 +32,3 @@ async def iterator():

assert not await it.has_next()
assert input == output


# @pytest.mark.asyncio
# async def test_double_iteration():
#
# class TestIterable(AsyncIterable[int]):
#
# def __aiter__(self):
# async def iterable():
# for item in [1, 2, 3]:
# yield item
# return iterable()
#
# test = TestIterable()
#
# peekable = PeekableAsyncIterator(test)
#
# assert await peekable.has_items()
#
# async for item in peekable:
# assert item == 1
# break
#
# assert await peekable.has_items()
#
# async for item in peekable:
# assert item == 2
# break
#
# assert await peekable.has_items()
#
# async for item in peekable:
# assert item == 3
# break
#
# assert not await peekable.has_items()
#
67 changes: 67 additions & 0 deletions tests/goth/test_async_task_generation/requestor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""A requestor script for testing asynchronous generation of input tasks."""
import asyncio
from datetime import timedelta
import pathlib
import sys
from typing import AsyncGenerator

from yapapi import Executor, Task
from yapapi.log import enable_default_logger, log_event_repr
from yapapi.package import vm


async def main():

vm_package = await vm.repo(
image_hash="9a3b5d67b0b27746283cb5f287c13eab1beaa12d92a9f536b747c7ae",
min_mem_gib=0.5,
min_storage_gib=2.0,
)

async def worker(work_ctx, tasks):
async for task in tasks:
print("task data:", task.data, file=sys.stderr)
work_ctx.run("/bin/sleep", "1")
yield work_ctx.commit()
task.accept_result(result=task.data)

async with Executor(
budget=10.0,
package=vm_package,
max_workers=1,
subnet_tag="goth",
timeout=timedelta(minutes=6),
event_consumer=log_event_repr,
) as executor:

# We use an async task generator that yields tasks removed from
# an async queue. Each computed task will potentially spawn
# new tasks -- this is made possible thanks to using async task
# generator as an input to `executor.submit()`.

task_queue = asyncio.Queue()

# Seed the queue with the first task:
await task_queue.put(Task(data=3))

async def input_generator() -> AsyncGenerator[Task]:
"""Task generator yields tasks removed from `queue`."""
while True:
task = await task_queue.get()
if task.data == 0:
break
yield task

async for task in executor.submit(worker, input_generator()):
print("task result:", task.result, file=sys.stderr)
for n in range(task.result):
await task.queue.put(Task(data=task.result - 1))

print("all done!", file=sys.stderr)


if __name__ == "__main__":
test_dir = pathlib.Path(__file__).parent.name
enable_default_logger(log_file=f"{test_dir}.log")
asyncio.run(main())
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging
import os
from pathlib import Path

import pytest

import goth.configuration
from goth.runner import Runner
from goth.runner.log import configure_logging
from goth.runner.probe import RequestorProbe


logger = logging.getLogger("goth.test.async_task_generation")


@pytest.mark.asyncio
async def test_async_task_generation(project_dir: Path, log_dir: Path, config_overrides) -> None:
"""Run the `requestor.py` and make sure that it's standard output is as expected."""

configure_logging(log_dir)

# Override the default test configuration to create only one provider node
nodes = [
{"name": "requestor", "type": "Requestor"},
{"name": "provider-1", "type": "VM-Wasm-Provider", "use-proxy": True},
]
config_overrides.append(("nodes", nodes))
goth_config = goth.configuration.load_yaml(
project_dir / "tests" / "goth" / "assets" / "goth-config.yml",
config_overrides,
)

runner = Runner(base_log_dir=log_dir, compose_config=goth_config.compose_config)

async with runner(goth_config.containers):

requestor = runner.get_probes(probe_type=RequestorProbe)[0]

async with requestor.run_command_on_host(
str(Path(__file__).parent / "requestor.py"), env=os.environ
) as (_cmd_task, cmd_monitor):
# The requestor should print "task result: 3" once ...
await cmd_monitor.wait_for_pattern("task result: 3", timeout=60)
# ... then "task result: 2" twice ...
for _ in range(3):
await cmd_monitor.wait_for_pattern("task result: 2", timeout=10)
# ... and "task result: 1" six times.
for _ in range(6):
await cmd_monitor.wait_for_pattern("task result: 1", timeout=10)
await cmd_monitor.wait_for_pattern("all done!", timeout=10)
1 change: 0 additions & 1 deletion yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ async def _submit(

state = Executor.SubmissionState(builder, agreements_pool)

# market_api = self._market_api
activity_api: rest.Activity = self._activity_api

done_queue: asyncio.Queue[Task[D, R]] = asyncio.Queue()
Expand Down
3 changes: 1 addition & 2 deletions yapapi/executor/_smartq.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ async def has_next(self) -> bool:
return True


class SmartQueue(Generic[Item], object):

class SmartQueue(Generic[Item]):
def __init__(self, items: AsyncIterator[Item]):
"""
:param items: the items to be iterated over
Expand Down

0 comments on commit 3954819

Please sign in to comment.