Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Asynchronous task iterator as argument to Executor.submit() #360

Merged
merged 3 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ jobs:
poetry install -E integration-tests

- name: Disconnect Docker containers from default network
continue-on-error: true
run: |
docker network inspect docker_default
sudo apt-get install -y jq
docker network inspect docker_default | jq ".[0].Containers | map(.Name)[]" | tee /dev/stderr | xargs --max-args 1 -- docker network disconnect -f docker_default
continue-on-error: true
run: |
docker network inspect docker_default
sudo apt-get install -y jq
docker network inspect docker_default | jq ".[0].Containers | map(.Name)[]" | tee /dev/stderr | xargs --max-args 1 -- docker network disconnect -f docker_default

- name: Remove Docker containers
continue-on-error: true
Expand Down
34 changes: 34 additions & 0 deletions tests/executor/test_peekable_async_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import AsyncIterable

import pytest

from yapapi.executor._smartq import PeekableAsyncIterator


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "input",
    [
        [],
        [1],
        [1, 2],
        [1, 2, 3],
        [1, 2, 3, 4],
    ],
)
async def test_iterator(input):
    """Check iteration and `has_next()` of `PeekableAsyncIterator`.

    The list `input` is wrapped in an async generator and fed to
    `PeekableAsyncIterator`. The test asserts that `has_next()` agrees with
    the number of items still to be produced at every step, and that the
    items come out unchanged and in order.
    """

    async def iterator():
        # Turn the plain list into an async generator.
        for item in input:
            yield item

    it = PeekableAsyncIterator(iterator())
    # Before anything is consumed, `has_next()` must say whether the input is non-empty.
    assert (await it.has_next()) == bool(input)

    output = []

    async for item in it:
        output.append(item)
        # After each item, `has_next()` must be true only while items remain.
        assert (await it.has_next()) == (len(output) < len(input))

    # The iterator is exhausted and produced exactly the input sequence.
    assert not await it.has_next()
    assert input == output
35 changes: 20 additions & 15 deletions tests/executor/test_smartq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
import asyncio


async def async_iter(iterable):
    """Yield the items of a regular `iterable` from an async generator.

    Test helper: lets a plain list or range be passed where an asynchronous
    iterator is expected.
    """
    for item in iterable:
        yield item


@pytest.mark.asyncio
@pytest.mark.parametrize("length", [0, 1, 100])
async def test_smart_queue(length: int):
q = SmartQueue(range(length))

q = SmartQueue(async_iter(range(length)))

async def worker(i, queue):
print(f"worker {i} started")
Expand Down Expand Up @@ -46,40 +52,40 @@ async def stats():

@pytest.mark.asyncio
async def test_smart_queue_empty():
q: SmartQueue[None] = SmartQueue([])

q: SmartQueue = SmartQueue(async_iter([]))
with q.new_consumer() as c:
async for item in c:
async for _item in c:
assert False, "Expected empty list"
# done


@pytest.mark.asyncio
async def test_unassigned_items():
q = SmartQueue([1, 2, 3])
q = SmartQueue(async_iter([1, 2, 3]))
with q.new_consumer() as c:
async for handle in c:
assert q.has_new_items() == q.has_unassigned_items()
if not q.has_unassigned_items():
assert await q.has_new_items() == await q.has_unassigned_items()
if not await q.has_unassigned_items():
assert handle.data == 3
break
assert not q.has_unassigned_items()
assert not await q.has_unassigned_items()
await q.reschedule_all(c)
assert q.has_unassigned_items()
assert not q.has_new_items()
assert await q.has_unassigned_items()
assert not await q.has_new_items()


@pytest.mark.asyncio
async def test_smart_queue_retry(caplog):
loop = asyncio.get_event_loop()

caplog.set_level(logging.DEBUG)
q = SmartQueue([1, 2, 3])
q = SmartQueue(async_iter([1, 2, 3]))

async def invalid_worker(q):
print("w start")
with q.new_consumer() as c:
async for item in c:
print("item=", item.data)
print("item =", item.data)
print("w end")

try:
Expand All @@ -96,11 +102,10 @@ async def invalid_worker(q):
print("w start", q.stats())
with q.new_consumer() as c:
async for item in c:
print("item2=", item.data)
print("item2 =", item.data)
assert c.current_item == item.data
outputs.add(item.data)
await q.mark_done(item)
print("w end")

assert outputs == set([1, 2, 3])
# done
assert outputs == {1, 2, 3}
67 changes: 67 additions & 0 deletions tests/goth/test_async_task_generation/requestor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""A requestor script for testing asynchronous generation of input tasks."""
import asyncio
from datetime import timedelta
import pathlib
import sys
from typing import AsyncGenerator

from yapapi import Executor, Task
from yapapi.log import enable_default_logger, log_event_repr
from yapapi.package import vm


async def main():
    """Run a computation whose input tasks are generated asynchronously.

    A single seed task with `data=3` is put on an `asyncio.Queue`. Every
    computed task with result `r` puts `r` new tasks with `data=r - 1` on the
    queue. The input generator stops at the first task whose data is 0.
    Progress is printed to standard error.
    """

    vm_package = await vm.repo(
        image_hash="9a3b5d67b0b27746283cb5f287c13eab1beaa12d92a9f536b747c7ae",
        min_mem_gib=0.5,
        min_storage_gib=2.0,
    )

    async def worker(work_ctx, tasks):
        """For each task: run `/bin/sleep 1` and accept the task's data as its result."""
        async for task in tasks:
            print("task data:", task.data, file=sys.stderr)
            work_ctx.run("/bin/sleep", "1")
            yield work_ctx.commit()
            task.accept_result(result=task.data)

    async with Executor(
        budget=10.0,
        package=vm_package,
        max_workers=1,
        subnet_tag="goth",
        timeout=timedelta(minutes=6),
        event_consumer=log_event_repr,
    ) as executor:

        # We use an async task generator that yields tasks removed from
        # an async queue. Each computed task will potentially spawn
        # new tasks -- this is made possible thanks to using async task
        # generator as an input to `executor.submit()`.

        task_queue = asyncio.Queue()

        # Seed the queue with the first task:
        await task_queue.put(Task(data=3))

        async def input_generator():
            """Task generator yields tasks removed from `task_queue`."""
            while True:
                task = await task_queue.get()
                # A task with data 0 acts as the end-of-input marker.
                if task.data == 0:
                    break
                yield task

        async for task in executor.submit(worker, input_generator()):
            print("task result:", task.result, file=sys.stderr)
            # A result `r` spawns `r` follow-up tasks, each with data `r - 1`.
            for _ in range(task.result):
                await task_queue.put(Task(data=task.result - 1))

        print("all done!", file=sys.stderr)


if __name__ == "__main__":
    # Name the log file after the directory that contains this script.
    test_dir = pathlib.Path(__file__).parent.name
    enable_default_logger(log_file=f"{test_dir}.log")
    asyncio.run(main())
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import logging
import os
from pathlib import Path

import pytest

import goth.configuration
from goth.runner import Runner
from goth.runner.log import configure_logging
from goth.runner.probe import RequestorProbe


logger = logging.getLogger("goth.test.async_task_generation")


@pytest.mark.asyncio
async def test_async_task_generation(project_dir: Path, log_dir: Path, config_overrides) -> None:
    """Run the `requestor.py` and make sure that its output is as expected.

    The goth network is started with a requestor and a single provider node.
    `requestor.py` (located next to this file) is run on the host, and its
    output is matched against the expected sequence of "task result" lines.
    """

    configure_logging(log_dir)

    # Override the default test configuration to create only one provider node
    nodes = [
        {"name": "requestor", "type": "Requestor"},
        {"name": "provider-1", "type": "VM-Wasm-Provider", "use-proxy": True},
    ]
    config_overrides.append(("nodes", nodes))
    goth_config = goth.configuration.load_yaml(
        project_dir / "tests" / "goth" / "assets" / "goth-config.yml",
        config_overrides,
    )

    runner = Runner(base_log_dir=log_dir, compose_config=goth_config.compose_config)

    async with runner(goth_config.containers):

        requestor = runner.get_probes(probe_type=RequestorProbe)[0]

        async with requestor.run_command_on_host(
            str(Path(__file__).parent / "requestor.py"), env=os.environ
        ) as (_cmd_task, cmd_monitor):
            # The requestor should print "task result: 3" once ...
            await cmd_monitor.wait_for_pattern("task result: 3", timeout=60)
            # ... then "task result: 2" three times (a result of 3 spawns three tasks) ...
            for _ in range(3):
                await cmd_monitor.wait_for_pattern("task result: 2", timeout=10)
            # ... and "task result: 1" six times (each result of 2 spawns two tasks).
            for _ in range(6):
                await cmd_monitor.wait_for_pattern("task result: 1", timeout=10)
            await cmd_monitor.wait_for_pattern("all done!", timeout=10)
23 changes: 15 additions & 8 deletions yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def strategy(self) -> MarketStrategy:
async def submit(
self,
worker: Callable[[WorkContext, AsyncIterator[Task[D, R]]], AsyncGenerator[Work, None]],
data: Iterable[Task[D, R]],
data: Union[AsyncIterator[Task[D, R]], Iterable[Task[D, R]]],
) -> AsyncIterator[Task[D, R]]:
"""Submit a computation to be executed on providers.

Expand Down Expand Up @@ -388,7 +388,7 @@ async def _find_offers(self, state: "Executor.SubmissionState") -> None:
async def _submit(
self,
worker: Callable[[WorkContext, AsyncIterator[Task[D, R]]], AsyncGenerator[Work, None]],
data: Iterable[Task[D, R]],
data: Union[AsyncIterator[Task[D, R]], Iterable[Task[D, R]]],
services: Set[asyncio.Task],
workers: Set[asyncio.Task],
) -> AsyncGenerator[Task[D, R], None]:
Expand All @@ -415,7 +415,6 @@ async def _submit(

state = Executor.SubmissionState(builder, agreements_pool)

market_api = self._market_api
activity_api: rest.Activity = self._activity_api

done_queue: asyncio.Queue[Task[D, R]] = asyncio.Queue()
Expand All @@ -426,10 +425,15 @@ def on_task_done(task: Task[D, R], status: TaskStatus) -> None:
if status == TaskStatus.ACCEPTED:
done_queue.put_nowait(task)

def input_tasks() -> Iterable[Task[D, R]]:
for task in data:
task._add_callback(on_task_done)
yield task
async def input_tasks() -> AsyncIterator[Task[D, R]]:
if isinstance(data, AsyncIterator):
async for task in data:
task._add_callback(on_task_done)
yield task
else:
for task in data:
task._add_callback(on_task_done)
yield task

work_queue = SmartQueue(input_tasks())

Expand Down Expand Up @@ -607,7 +611,10 @@ async def worker_starter() -> None:
while True:
await asyncio.sleep(2)
await agreements_pool.cycle()
if len(workers) < self._conf.max_workers and work_queue.has_unassigned_items():
if (
len(workers) < self._conf.max_workers
and await work_queue.has_unassigned_items()
):
new_task = None
try:
new_task = await agreements_pool.use_agreement(
Expand Down
Loading