Skip to content

Commit

Permalink
Fix error handling in process_batches()
Browse files Browse the repository at this point in the history
  • Loading branch information
azawlocki committed May 14, 2021
1 parent a7972eb commit 5ffc639
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 252 deletions.
173 changes: 0 additions & 173 deletions examples/blender/blender-async-results.py

This file was deleted.

9 changes: 8 additions & 1 deletion yapapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
from pathlib import Path
from pkg_resources import get_distribution

from .executor import Executor, NoPaymentAccountError, Task, WorkContext, CaptureContext
from .executor import (
CaptureContext,
ExecOptions,
Executor,
NoPaymentAccountError,
Task,
WorkContext,
)


def get_version() -> str:
Expand Down
134 changes: 66 additions & 68 deletions yapapi/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,90 +548,85 @@ def unpack_work_item(item: WorkItem) -> Tuple[Work, ExecOptions]:
return item, ExecOptions()

async def process_batches(
agreement: rest.market.Agreement,
act: rest.activity.Activity,
agreement_id: str,
activity: rest.activity.Activity,
command_generator: AsyncGenerator[Work, Awaitable[List[events.CommandEvent]]],
consumer: Consumer[Task[D, R]],
) -> None:
"""Send command batches produced by `command_generator` to `activity`."""

try:
item = await command_generator.__anext__()
except StopAsyncIteration:
return

batch, exec_options = unpack_work_item(item)
if batch.timeout:
if exec_options.batch_timeout:
logger.warning(
"Overriding batch timeout set with commit(batch_timeout)"
"by the value set in exec options"
)
else:
exec_options.batch_timeout = batch.timeout
item = await command_generator.__anext__()

while True:

batch, exec_options = unpack_work_item(item)
if batch.timeout:
if exec_options.batch_timeout:
logger.warning(
"Overriding batch timeout set with commit(batch_timeout)"
"by the value set in exec options"
)
else:
exec_options.batch_timeout = batch.timeout

batch_deadline = (
datetime.now(timezone.utc) + exec_options.batch_timeout
if exec_options.batch_timeout
else None
)
try:
current_worker_task = consumer.current_item
if current_worker_task:
emit(
events.TaskStarted(
agr_id=agreement.id,
task_id=current_worker_task.id,
task_data=current_worker_task.data,
)

current_worker_task = consumer.current_item
if current_worker_task:
emit(
events.TaskStarted(
agr_id=agreement_id,
task_id=current_worker_task.id,
task_data=current_worker_task.data,
)
task_id = current_worker_task.id if current_worker_task else None
await batch.prepare()
cc = CommandContainer()
batch.register(cc)
remote = await act.send(cc.commands(), stream_output, deadline=batch_deadline)
cmds = cc.commands()
emit(events.ScriptSent(agr_id=agreement.id, task_id=task_id, cmds=cmds))

async def get_batch_results() -> List[events.CommandEvent]:
results = []
async for evt_ctx in remote:
evt = evt_ctx.event(agr_id=agreement.id, task_id=task_id, cmds=cmds)
emit(evt)
results.append(evt)
if isinstance(evt, events.CommandExecuted) and not evt.success:
raise CommandExecutionError(evt.command, evt.message)

emit(events.GettingResults(agr_id=agreement.id, task_id=task_id))
await batch.post()
emit(events.ScriptFinished(agr_id=agreement.id, task_id=task_id))
await accept_payment_for_agreement(agreement.id, partial=True)
return results

loop = asyncio.get_event_loop()

if exec_options.wait_for_results:
# Block until the results are available
results = await get_batch_results()
)
task_id = current_worker_task.id if current_worker_task else None
await batch.prepare()
cc = CommandContainer()
batch.register(cc)
remote = await activity.send(cc.commands(), stream_output, deadline=batch_deadline)
cmds = cc.commands()
emit(events.ScriptSent(agr_id=agreement_id, task_id=task_id, cmds=cmds))

async def get_batch_results() -> List[events.CommandEvent]:
results = []
async for evt_ctx in remote:
evt = evt_ctx.event(agr_id=agreement_id, task_id=task_id, cmds=cmds)
emit(evt)
results.append(evt)
if isinstance(evt, events.CommandExecuted) and not evt.success:
raise CommandExecutionError(evt.command, evt.message)

emit(events.GettingResults(agr_id=agreement_id, task_id=task_id))
await batch.post()
emit(events.ScriptFinished(agr_id=agreement_id, task_id=task_id))
await accept_payment_for_agreement(agreement_id, partial=True)
return results

loop = asyncio.get_event_loop()

if exec_options.wait_for_results:
# Block until the results are available
try:
future_results = loop.create_future()
results = await get_batch_results()
future_results.set_result(results)
else:
# Schedule the coroutine in a separate asyncio task
future_results = loop.create_task(get_batch_results())

try:
item = await command_generator.asend(future_results)
except StopAsyncIteration:
break

batch, exec_options = unpack_work_item(item)

except Exception:
# Raise the exception in the command_generator (the `worker` coroutine).
# If the client code is able to handle it then we'll proceed with
# subsequent batches. Otherwise the worker finishes with error.
await command_generator.athrow(*sys.exc_info())
raise
except Exception:
# Raise the exception in `command_generator` (the `worker` coroutine).
# If the client code is able to handle it then we'll proceed with
# subsequent batches. Otherwise the worker finishes with error.
item = await command_generator.athrow(*sys.exc_info())
else:
# Schedule the coroutine in a separate asyncio task
future_results = loop.create_task(get_batch_results())
item = await command_generator.asend(future_results)

async def start_worker(agreement: rest.market.Agreement, node_info: NodeInfo) -> None:

Expand Down Expand Up @@ -666,7 +661,10 @@ async def start_worker(agreement: rest.market.Agreement, node_info: NodeInfo) ->
Task.for_handle(handle, work_queue, emit) async for handle in consumer
)
batch_generator = worker(work_context, tasks)
await process_batches(agreement, act, batch_generator, consumer)
try:
await process_batches(agreement.id, act, batch_generator, consumer)
except StopAsyncIteration:
pass
emit(events.WorkerFinished(agr_id=agreement.id))
except Exception:
emit(
Expand Down
12 changes: 2 additions & 10 deletions yapapi/executor/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ def timeout(self) -> Optional[timedelta]:
"""Return the optional timeout set for execution of this work."""
return None

@property
def contains_init_step(self) -> bool:
"""Return `True` iff this work item contains the initialization step."""
return False


class _InitStep(Work):
def register(self, commands: CommandContainer):
Expand Down Expand Up @@ -246,11 +241,6 @@ def timeout(self) -> Optional[timedelta]:
"""Return the optional timeout set for execution of all steps."""
return self._timeout

@property
def contains_init_step(self) -> bool:
"""Return `True` iff the steps include an initialization step."""
return any(isinstance(step, _InitStep) for step in self._steps)

async def prepare(self):
"""Execute the `prepare` hook for all the defined steps."""
for step in self._steps:
Expand All @@ -269,6 +259,8 @@ async def post(self):

@dataclass
class ExecOptions:
"""Options related to command batch execution."""

wait_for_results: bool = True
batch_timeout: Optional[timedelta] = None

Expand Down

0 comments on commit 5ffc639

Please sign in to comment.