Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return batch results with optional blocking #367

Merged
merged 7 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion yapapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from pathlib import Path
from pkg_resources import get_distribution

from .executor import Executor, NoPaymentAccountError, Task, WorkContext, CaptureContext
from .executor import (
CaptureContext,
ExecOptions,
Executor,
NoPaymentAccountError,
Task,
WorkContext,
)


def get_version() -> str:
Expand Down
184 changes: 127 additions & 57 deletions yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
from typing import (
AsyncContextManager,
AsyncIterator,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
cast,
Expand All @@ -29,7 +31,7 @@
from yapapi.executor.agreements_pool import AgreementsPool
from typing_extensions import Final, AsyncGenerator

from .ctx import CaptureContext, CommandContainer, Work, WorkContext
from .ctx import CaptureContext, CommandContainer, ExecOptions, Work, WorkContext
from .events import Event
from . import events
from .task import Task, TaskStatus
Expand All @@ -41,7 +43,7 @@
from ..rest.activity import CommandExecutionError
from ..rest.market import OfferProposal, Subscription
from ..storage import gftp
from ._smartq import SmartQueue, Handle
from ._smartq import Consumer, Handle, SmartQueue
from .strategy import (
DecreaseScoreForUnconfirmedAgreement,
LeastExpensiveLinearPayuMS,
Expand Down Expand Up @@ -99,6 +101,10 @@ class _ExecutorConfig:
traceback: bool = bool(os.getenv("YAPAPI_TRACEBACK", 0))


WorkItem = Union[Work, Tuple[Work, ExecOptions]]
"""The type of items yielded by a generator created by the `worker` function supplied by user."""


D = TypeVar("D") # Type var for task data
R = TypeVar("R") # Type var for task result

Expand Down Expand Up @@ -205,7 +211,10 @@ def strategy(self) -> MarketStrategy:

async def submit(
self,
worker: Callable[[WorkContext, AsyncIterator[Task[D, R]]], AsyncGenerator[Work, None]],
worker: Callable[
[WorkContext, AsyncIterator[Task[D, R]]],
AsyncGenerator[Work, Awaitable[List[events.CommandEvent]]],
],
data: Union[AsyncIterator[Task[D, R]], Iterable[Task[D, R]]],
) -> AsyncIterator[Task[D, R]]:
"""Submit a computation to be executed on providers.
Expand Down Expand Up @@ -387,7 +396,10 @@ async def _find_offers(self, state: "Executor.SubmissionState") -> None:

async def _submit(
self,
worker: Callable[[WorkContext, AsyncIterator[Task[D, R]]], AsyncGenerator[Work, None]],
worker: Callable[
[WorkContext, AsyncIterator[Task[D, R]]],
AsyncGenerator[Work, Awaitable[List[events.CommandEvent]]],
],
data: Union[AsyncIterator[Task[D, R]], Iterable[Task[D, R]]],
services: Set[asyncio.Task],
workers: Set[asyncio.Task],
Expand Down Expand Up @@ -526,7 +538,99 @@ async def accept_payment_for_agreement(agreement_id: str, *, partial: bool = Fal

storage_manager = await self._stack.enter_async_context(gftp.provider())

def unpack_work_item(item: WorkItem) -> Tuple[Work, ExecOptions]:
"""Extract `Work` object and options from a work item.
If the item does not specify options, default ones are provided.
"""
if isinstance(item, tuple):
return item
else:
return item, ExecOptions()

async def process_batches(
agreement_id: str,
activity: rest.activity.Activity,
command_generator: AsyncGenerator[Work, Awaitable[List[events.CommandEvent]]],
consumer: Consumer[Task[D, R]],
) -> None:
"""Send command batches produced by `command_generator` to `activity`."""

item = await command_generator.__anext__()

while True:

batch, exec_options = unpack_work_item(item)
if batch.timeout:
if exec_options.batch_timeout:
logger.warning(
"Overriding batch timeout set with commit(batch_timeout)"
"by the value set in exec options"
)
else:
exec_options.batch_timeout = batch.timeout

batch_deadline = (
datetime.now(timezone.utc) + exec_options.batch_timeout
if exec_options.batch_timeout
else None
)

current_worker_task = consumer.current_item
if current_worker_task:
emit(
events.TaskStarted(
agr_id=agreement_id,
task_id=current_worker_task.id,
task_data=current_worker_task.data,
)
)
task_id = current_worker_task.id if current_worker_task else None

await batch.prepare()
cc = CommandContainer()
batch.register(cc)
remote = await activity.send(cc.commands(), stream_output, deadline=batch_deadline)
cmds = cc.commands()
emit(events.ScriptSent(agr_id=agreement_id, task_id=task_id, cmds=cmds))

async def get_batch_results() -> List[events.CommandEvent]:
results = []
async for evt_ctx in remote:
evt = evt_ctx.event(agr_id=agreement_id, task_id=task_id, cmds=cmds)
emit(evt)
results.append(evt)
if isinstance(evt, events.CommandExecuted) and not evt.success:
raise CommandExecutionError(evt.command, evt.message)

emit(events.GettingResults(agr_id=agreement_id, task_id=task_id))
await batch.post()
emit(events.ScriptFinished(agr_id=agreement_id, task_id=task_id))
await accept_payment_for_agreement(agreement_id, partial=True)
return results

loop = asyncio.get_event_loop()

if exec_options.wait_for_results:
# Block until the results are available
try:
future_results = loop.create_future()
results = await get_batch_results()
future_results.set_result(results)
item = await command_generator.asend(future_results)
except StopAsyncIteration:
raise
except Exception:
# Raise the exception in `command_generator` (the `worker` coroutine).
# If the client code is able to handle it then we'll proceed with
# subsequent batches. Otherwise the worker finishes with error.
item = await command_generator.athrow(*sys.exc_info())
else:
# Schedule the coroutine in a separate asyncio task
future_results = loop.create_task(get_batch_results())
item = await command_generator.asend(future_results)

async def start_worker(agreement: rest.market.Agreement, node_info: NodeInfo) -> None:

nonlocal last_wid
wid = last_wid
last_wid += 1
Expand All @@ -543,69 +647,35 @@ async def start_worker(agreement: rest.market.Agreement, node_info: NodeInfo) ->
)
emit(events.WorkerFinished(agr_id=agreement.id))
raise

async with act:

emit(events.ActivityCreated(act_id=act.id, agr_id=agreement.id))
agreements_accepting_debit_notes.add(agreement.id)
work_context = WorkContext(
f"worker-{wid}", node_info, storage_manager, emitter=emit
)
with work_queue.new_consumer() as consumer:

command_generator = worker(
work_context,
(Task.for_handle(handle, work_queue, emit) async for handle in consumer),
)
async for batch in command_generator:
batch_deadline = (
datetime.now(timezone.utc) + batch.timeout if batch.timeout else None
with work_queue.new_consumer() as consumer:
try:
tasks = (
Task.for_handle(handle, work_queue, emit) async for handle in consumer
)
batch_generator = worker(work_context, tasks)
try:
current_worker_task = consumer.current_item
if current_worker_task:
emit(
events.TaskStarted(
agr_id=agreement.id,
task_id=current_worker_task.id,
task_data=current_worker_task.data,
)
)
task_id = current_worker_task.id if current_worker_task else None
await batch.prepare()
cc = CommandContainer()
batch.register(cc)
remote = await act.send(
cc.commands(), stream_output, deadline=batch_deadline
await process_batches(agreement.id, act, batch_generator, consumer)
except StopAsyncIteration:
pass
emit(events.WorkerFinished(agr_id=agreement.id))
except Exception:
emit(
events.WorkerFinished(
agr_id=agreement.id, exc_info=sys.exc_info() # type: ignore
)
cmds = cc.commands()
emit(events.ScriptSent(agr_id=agreement.id, task_id=task_id, cmds=cmds))

async for evt_ctx in remote:
evt = evt_ctx.event(agr_id=agreement.id, task_id=task_id, cmds=cmds)
emit(evt)
if isinstance(evt, events.CommandExecuted) and not evt.success:
raise CommandExecutionError(evt.command, evt.message)

emit(events.GettingResults(agr_id=agreement.id, task_id=task_id))
await batch.post()
emit(events.ScriptFinished(agr_id=agreement.id, task_id=task_id))
await accept_payment_for_agreement(agreement.id, partial=True)

except Exception:

try:
await command_generator.athrow(*sys.exc_info())
except Exception:
if self._conf.traceback:
traceback.print_exc()
emit(
events.WorkerFinished(
agr_id=agreement.id, exc_info=sys.exc_info() # type: ignore
)
)
raise

await accept_payment_for_agreement(agreement.id)
emit(events.WorkerFinished(agr_id=agreement.id))
)
raise
finally:
await accept_payment_for_agreement(agreement.id)

async def worker_starter() -> None:
while True:
Expand Down
13 changes: 13 additions & 0 deletions yapapi/executor/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ async def __on_json_download(on_download: Callable[[bytes], Awaitable], content:

class Steps(Work):
def __init__(self, *steps: Work, timeout: Optional[timedelta] = None):
"""Create a `Work` item consisting of a sequence of steps (subitems).

:param steps: sequence of steps to be executed
:param timeout: timeout for waiting for the steps' results
"""
self._steps: Tuple[Work, ...] = steps
self._timeout: Optional[timedelta] = timeout

Expand All @@ -252,6 +257,14 @@ async def post(self):
await step.post()


@dataclass
class ExecOptions:
"""Options related to command batch execution."""

wait_for_results: bool = True
batch_timeout: Optional[timedelta] = None


class WorkContext:
"""An object used to schedule commands to be sent to provider."""

Expand Down