
Merge pull request #360 from golemfactory/az/async-task-generator
Asynchronous task iterator as argument to `Executor.submit()`
azawlocki authored May 13, 2021
2 parents 2a0d90a + ab8471d commit 5561d39
Showing 7 changed files with 250 additions and 48 deletions.
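
In short, `Executor.submit()` now accepts an asynchronous iterator of tasks in addition to a plain iterable, so new tasks can be generated while earlier results are still being computed. A minimal usage sketch (assuming a `worker` coroutine and a resolved `package`, as in the bundled requestor script below):

from typing import AsyncIterator

from yapapi import Executor, Task


async def tasks() -> AsyncIterator[Task]:
    # Tasks no longer need to be known up front; any async iterator will do.
    for i in range(5):
        yield Task(data=i)


async def compute(worker, package):
    async with Executor(package=package, budget=1.0) as executor:
        # The async generator is passed to submit() directly.
        async for task in executor.submit(worker, tasks()):
            print("task result:", task.result)
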
10 changes: 5 additions & 5 deletions .github/workflows/integration.yml
@@ -36,11 +36,11 @@ jobs:
poetry install -E integration-tests
- name: Disconnect Docker containers from default network
continue-on-error: true
run: |
docker network inspect docker_default
sudo apt-get install -y jq
docker network inspect docker_default | jq ".[0].Containers | map(.Name)[]" | tee /dev/stderr | xargs --max-args 1 -- docker network disconnect -f docker_default
continue-on-error: true
run: |
docker network inspect docker_default
sudo apt-get install -y jq
docker network inspect docker_default | jq ".[0].Containers | map(.Name)[]" | tee /dev/stderr | xargs --max-args 1 -- docker network disconnect -f docker_default
- name: Remove Docker containers
continue-on-error: true
34 changes: 34 additions & 0 deletions tests/executor/test_peekable_async_iterator.py
@@ -0,0 +1,34 @@
from typing import AsyncIterable

import pytest

from yapapi.executor._smartq import PeekableAsyncIterator


@pytest.mark.asyncio
@pytest.mark.parametrize(
"input",
[
[],
[1],
[1, 2],
[1, 2, 3],
[1, 2, 3, 4],
],
)
async def test_iterator(input):
async def iterator():
for item in input:
yield item

it = PeekableAsyncIterator(iterator())
assert (await it.has_next()) == bool(input)

output = []

async for item in it:
output.append(item)
assert (await it.has_next()) == (len(output) < len(input))

assert not await it.has_next()
assert input == output
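
The `PeekableAsyncIterator` implementation itself (imported from `yapapi.executor._smartq`) does not appear in the portion of the diff shown here. A minimal sketch consistent with what this test exercises — wrapping an async iterator and pre-fetching at most one item so that `has_next()` can be awaited — might look as follows; everything except the class name is illustrative, not the actual implementation:

from typing import AsyncIterator, List, TypeVar

T = TypeVar("T")


class PeekableAsyncIterator(AsyncIterator[T]):
    """Wrap an async iterator and allow checking whether another item is available."""

    def __init__(self, it: AsyncIterator[T]):
        self._it = it
        self._buffer: List[T] = []  # holds at most one pre-fetched item

    async def has_next(self) -> bool:
        if self._buffer:
            return True
        try:
            self._buffer.append(await self._it.__anext__())
        except StopAsyncIteration:
            return False
        return True

    def __aiter__(self) -> "PeekableAsyncIterator[T]":
        return self

    async def __anext__(self) -> T:
        if not await self.has_next():
            raise StopAsyncIteration
        return self._buffer.pop(0)
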
35 changes: 20 additions & 15 deletions tests/executor/test_smartq.py
@@ -6,10 +6,16 @@
import asyncio


async def async_iter(iterable):
for item in iterable:
yield item


@pytest.mark.asyncio
@pytest.mark.parametrize("length", [0, 1, 100])
async def test_smart_queue(length: int):
q = SmartQueue(range(length))

q = SmartQueue(async_iter(range(length)))

async def worker(i, queue):
print(f"worker {i} started")
@@ -46,40 +46,40 @@ async def stats():

@pytest.mark.asyncio
async def test_smart_queue_empty():
q: SmartQueue[None] = SmartQueue([])

q: SmartQueue = SmartQueue(async_iter([]))
with q.new_consumer() as c:
async for item in c:
async for _item in c:
assert False, "Expected empty list"
# done


@pytest.mark.asyncio
async def test_unassigned_items():
q = SmartQueue([1, 2, 3])
q = SmartQueue(async_iter([1, 2, 3]))
with q.new_consumer() as c:
async for handle in c:
assert q.has_new_items() == q.has_unassigned_items()
if not q.has_unassigned_items():
assert await q.has_new_items() == await q.has_unassigned_items()
if not await q.has_unassigned_items():
assert handle.data == 3
break
assert not q.has_unassigned_items()
assert not await q.has_unassigned_items()
await q.reschedule_all(c)
assert q.has_unassigned_items()
assert not q.has_new_items()
assert await q.has_unassigned_items()
assert not await q.has_new_items()


@pytest.mark.asyncio
async def test_smart_queue_retry(caplog):
loop = asyncio.get_event_loop()

caplog.set_level(logging.DEBUG)
q = SmartQueue([1, 2, 3])
q = SmartQueue(async_iter([1, 2, 3]))

async def invalid_worker(q):
print("w start")
with q.new_consumer() as c:
async for item in c:
print("item=", item.data)
print("item =", item.data)
print("w end")

try:
@@ -96,11 +102,10 @@ async def invalid_worker(q):
print("w start", q.stats())
with q.new_consumer() as c:
async for item in c:
print("item2=", item.data)
print("item2 =", item.data)
assert c.current_item == item.data
outputs.add(item.data)
await q.mark_done(item)
print("w end")

assert outputs == set([1, 2, 3])
# done
assert outputs == {1, 2, 3}
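
The updated tests reflect two API changes in `SmartQueue`: its input is now an async iterator rather than a plain iterable, and `has_new_items()` / `has_unassigned_items()` are now coroutines, presumably because checking for a new item may require pulling from the underlying async iterator. A small before/after sketch using the `async_iter` helper defined at the top of this test module:

from yapapi.executor._smartq import SmartQueue


async def example():
    # Before this change: q = SmartQueue([1, 2, 3]); q.has_unassigned_items()
    # After this change, the iterable must be wrapped and the checks awaited:
    q = SmartQueue(async_iter([1, 2, 3]))
    assert await q.has_unassigned_items()
    assert await q.has_new_items()
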
67 changes: 67 additions & 0 deletions tests/goth/test_async_task_generation/requestor.py
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""A requestor script for testing asynchronous generation of input tasks."""
import asyncio
from datetime import timedelta
import pathlib
import sys
from typing import AsyncGenerator

from yapapi import Executor, Task
from yapapi.log import enable_default_logger, log_event_repr
from yapapi.package import vm


async def main():

vm_package = await vm.repo(
image_hash="9a3b5d67b0b27746283cb5f287c13eab1beaa12d92a9f536b747c7ae",
min_mem_gib=0.5,
min_storage_gib=2.0,
)

async def worker(work_ctx, tasks):
async for task in tasks:
print("task data:", task.data, file=sys.stderr)
work_ctx.run("/bin/sleep", "1")
yield work_ctx.commit()
task.accept_result(result=task.data)

async with Executor(
budget=10.0,
package=vm_package,
max_workers=1,
subnet_tag="goth",
timeout=timedelta(minutes=6),
event_consumer=log_event_repr,
) as executor:

# We use an async task generator that yields tasks removed from
# an async queue. Each computed task will potentially spawn
# new tasks -- this is made possible by passing an async task
# generator as input to `executor.submit()`.

task_queue = asyncio.Queue()

# Seed the queue with the first task:
await task_queue.put(Task(data=3))

async def input_generator():
"""Task generator yields tasks removed from `queue`."""
while True:
task = await task_queue.get()
if task.data == 0:
break
yield task

async for task in executor.submit(worker, input_generator()):
print("task result:", task.result, file=sys.stderr)
for n in range(task.result):
await task_queue.put(Task(data=task.result - 1))

print("all done!", file=sys.stderr)


if __name__ == "__main__":
test_dir = pathlib.Path(__file__).parent.name
enable_default_logger(log_file=f"{test_dir}.log")
asyncio.run(main())
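
The queue-plus-sentinel pattern above is exactly what the new `submit()` signature enables: results of completed tasks are fed back into the task stream through `task_queue`, and a task whose data is 0 acts as a sentinel that terminates the input generator.
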
@@ -0,0 +1,50 @@
import logging
import os
from pathlib import Path

import pytest

import goth.configuration
from goth.runner import Runner
from goth.runner.log import configure_logging
from goth.runner.probe import RequestorProbe


logger = logging.getLogger("goth.test.async_task_generation")


@pytest.mark.asyncio
async def test_async_task_generation(project_dir: Path, log_dir: Path, config_overrides) -> None:
"""Run the `requestor.py` and make sure that it's standard output is as expected."""

configure_logging(log_dir)

# Override the default test configuration to create only one provider node
nodes = [
{"name": "requestor", "type": "Requestor"},
{"name": "provider-1", "type": "VM-Wasm-Provider", "use-proxy": True},
]
config_overrides.append(("nodes", nodes))
goth_config = goth.configuration.load_yaml(
project_dir / "tests" / "goth" / "assets" / "goth-config.yml",
config_overrides,
)

runner = Runner(base_log_dir=log_dir, compose_config=goth_config.compose_config)

async with runner(goth_config.containers):

requestor = runner.get_probes(probe_type=RequestorProbe)[0]

async with requestor.run_command_on_host(
str(Path(__file__).parent / "requestor.py"), env=os.environ
) as (_cmd_task, cmd_monitor):
# The requestor should print "task result: 3" once ...
await cmd_monitor.wait_for_pattern("task result: 3", timeout=60)
# ... then "task result: 2" three times ...
for _ in range(3):
await cmd_monitor.wait_for_pattern("task result: 2", timeout=10)
# ... and "task result: 1" six times.
for _ in range(6):
await cmd_monitor.wait_for_pattern("task result: 1", timeout=10)
await cmd_monitor.wait_for_pattern("all done!", timeout=10)
23 changes: 15 additions & 8 deletions yapapi/executor/__init__.py
@@ -206,7 +206,7 @@ def strategy(self) -> MarketStrategy:
async def submit(
self,
worker: Callable[[WorkContext, AsyncIterator[Task[D, R]]], AsyncGenerator[Work, None]],
data: Iterable[Task[D, R]],
data: Union[AsyncIterator[Task[D, R]], Iterable[Task[D, R]]],
) -> AsyncIterator[Task[D, R]]:
"""Submit a computation to be executed on providers.
@@ -388,7 +388,7 @@ async def _find_offers(self, state: "Executor.SubmissionState") -> None:
async def _submit(
self,
worker: Callable[[WorkContext, AsyncIterator[Task[D, R]]], AsyncGenerator[Work, None]],
data: Iterable[Task[D, R]],
data: Union[AsyncIterator[Task[D, R]], Iterable[Task[D, R]]],
services: Set[asyncio.Task],
workers: Set[asyncio.Task],
) -> AsyncGenerator[Task[D, R], None]:
@@ -415,7 +415,6 @@ async def _submit(

state = Executor.SubmissionState(builder, agreements_pool)

market_api = self._market_api
activity_api: rest.Activity = self._activity_api

done_queue: asyncio.Queue[Task[D, R]] = asyncio.Queue()
@@ -426,10 +425,15 @@ def on_task_done(task: Task[D, R], status: TaskStatus) -> None:
if status == TaskStatus.ACCEPTED:
done_queue.put_nowait(task)

def input_tasks() -> Iterable[Task[D, R]]:
for task in data:
task._add_callback(on_task_done)
yield task
async def input_tasks() -> AsyncIterator[Task[D, R]]:
if isinstance(data, AsyncIterator):
async for task in data:
task._add_callback(on_task_done)
yield task
else:
for task in data:
task._add_callback(on_task_done)
yield task

work_queue = SmartQueue(input_tasks())

@@ -607,7 +611,10 @@ async def worker_starter() -> None:
while True:
await asyncio.sleep(2)
await agreements_pool.cycle()
if len(workers) < self._conf.max_workers and work_queue.has_unassigned_items():
if (
len(workers) < self._conf.max_workers
and await work_queue.has_unassigned_items()
):
new_task = None
try:
new_task = await agreements_pool.use_agreement(
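
The runtime dispatch in `input_tasks()` works because `typing.AsyncIterator` is an alias of `collections.abc.AsyncIterator`, whose structural subclass hook only checks for `__aiter__` and `__anext__`; async generator objects therefore pass the `isinstance()` check, while plain iterables fall through to the synchronous branch. A quick illustration (not part of the diff):

from typing import AsyncIterator


async def agen():
    yield 1


# An async generator object satisfies the structural check ...
assert isinstance(agen(), AsyncIterator)
# ... while a plain list does not, so it would take the synchronous branch.
assert not isinstance([1, 2, 3], AsyncIterator)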