Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Retry GetExecBatchResults on ApiExceptions caused by GSB Errors #588

Merged
merged 10 commits into from
Aug 11, 2021
126 changes: 126 additions & 0 deletions tests/rest/test_activity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import List, Optional, Tuple, Type
from unittest.mock import Mock

import pytest

from ya_activity.exceptions import ApiException
from yapapi.rest.activity import BatchError, PollingBatch


GetExecBatchResultsSpec = Tuple[Optional[Exception], List[str]]


def mock_activity(specs: List[GetExecBatchResultsSpec]):
"""Create a mock activity.

The argument `specs` is a list of pairs specifying the behavior of subsequent calls
to `get_exec_batch_results()`: i-th pair corresponds to the i-th call.
The first element of the pair is an optional error raised by the call, the second element
is the activity state (the `.state` component of the object returned by `Activity.state()`).
"""
i = -1

async def mock_results(*_args, **_kwargs):
nonlocal specs, i
i += 1
error = specs[i][0]
if error:
raise error
return [Mock(index=0)]

async def mock_state():
nonlocal specs, i
state = specs[i][1]
return Mock(state=state)

return Mock(state=mock_state, _api=Mock(get_exec_batch_results=mock_results))


GSB_ERROR = ":( GSB error: some endpoint address not found :("


@pytest.mark.parametrize(
"specs, expected_error",
[
# No errors
([(None, ["Running", "Running"])], None),
# Exception other than ApiException should stop iteration over batch results
(
[(ValueError("!?"), ["Running", "Running"])],
ValueError,
),
# ApiException not related to GSB should stop iteration over batch results
(
[(ApiException(status=400), ["Running", "Running"])],
ApiException,
),
# As above, but with status 500
(
[
(
ApiException(http_resp=Mock(status=500, data='{"message": "???"}')),
["Running", "Running"],
)
],
ApiException,
),
# ApiException not related to GSB should raise BatchError if activity is terminated
(
[
(
ApiException(http_resp=Mock(status=500, data='{"message": "???"}')),
["Running", "Terminated"],
)
],
BatchError,
),
# GSB-related ApiException should cause retrying if the activity is running
(
[
(
ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')),
["Running", "Running"],
),
(None, ["Running", "Running"]),
],
None,
),
# As above, but max number of tries is reached
(
[
(
ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')),
["Running", "Running"],
)
]
* PollingBatch.GET_EXEC_BATCH_RESULTS_MAX_TRIES,
ApiException,
),
# GSB-related ApiException should raise BatchError if activity is terminated
(
[
(
ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')),
["Running", "Terminated"],
)
],
BatchError,
),
],
)
@pytest.mark.asyncio
async def test_polling_batch_on_gsb_error(
specs: List[GetExecBatchResultsSpec], expected_error: Optional[Type[Exception]]
) -> None:
"""Test the behavior of PollingBatch when get_exec_batch_results() raises exceptions."""

PollingBatch.GET_EXEC_BATCH_RESULTS_INTERVAL = 0.1

activity = mock_activity(specs)
batch = PollingBatch(activity, "batch_id", 1)
try:
async for _ in batch:
pass
assert expected_error is None
except Exception as error:
assert expected_error is not None and isinstance(error, expected_error)
96 changes: 75 additions & 21 deletions yapapi/rest/activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import datetime, timedelta, timezone
import json
import logging
from typing import AsyncIterator, List, Optional, Type, Any, Dict
from typing import AsyncIterator, List, Optional, Tuple, Type, Any, Dict

from typing_extensions import AsyncContextManager, AsyncIterable

Expand All @@ -21,6 +21,8 @@
)

from yapapi import events
from yapapi.rest.common import is_intermittent_error, SuppressedExceptions


_log = logging.getLogger("yapapi.rest")

Expand Down Expand Up @@ -73,8 +75,8 @@ async def send(self, script: List[dict], deadline: Optional[datetime] = None) ->
batch_id = await self._api.call_exec(self._id, yaa.ExeScriptRequest(text=script_txt))

if self._stream_events:
return StreamingBatch(self._api, self._id, batch_id, len(script), deadline)
return PollingBatch(self._api, self._id, batch_id, len(script), deadline)
return StreamingBatch(self, batch_id, len(script), deadline)
return PollingBatch(self, batch_id, len(script), deadline)

async def __aenter__(self) -> "Activity":
return self
Expand Down Expand Up @@ -146,22 +148,19 @@ class BatchTimeoutError(BatchError):
class Batch(abc.ABC, AsyncIterable[events.CommandEventContext]):
"""Abstract base class for iterating over events related to a batch running on provider."""

_api: RequestorControlApi
_activity_id: str
_activity: Activity
_batch_id: str
_size: int
_deadline: datetime

def __init__(
self,
api: RequestorControlApi,
activity_id: str,
activity: Activity,
batch_id: str,
batch_size: int,
deadline: Optional[datetime] = None,
) -> None:
self._api = api
self._activity_id = activity_id
self._activity = activity
self._batch_id = batch_id
self._size = batch_size
self._deadline = (
Expand All @@ -179,25 +178,80 @@ def id(self):
return self._batch_id


def _is_gsb_endpoint_not_found_error(err: ApiException) -> bool:
"""Check if `err` is caused by "Endpoint address not found" GSB error."""

if err.status != 500:
return False
try:
msg = json.loads(err.body)["message"]
return "GSB error" in msg and "endpoint address not found" in msg
except Exception:
_log.debug("Cannot read error message from ApiException", exc_info=True)
return False


class PollingBatch(Batch):
"""A `Batch` implementation that polls the server repeatedly for command status."""

GET_EXEC_BATCH_RESULTS_MAX_TRIES = 3
"""Max number of attempts to call GetExecBatchResults if a GSB error occurs."""

GET_EXEC_BATCH_RESULTS_INTERVAL = 3.0
"""Time in seconds before retrying GetExecBatchResults after a GSB error occurs."""

async def _activity_terminated(self) -> Tuple[bool, Optional[str], Optional[str]]:
"""Check if the activity we're using is in "Terminated" state."""
try:
state = await self._activity.state()
return "Terminated" in state.state, state.reason, state.error_message
except Exception:
_log.debug("Cannot query activity state", exc_info=True)
return False, None, None

async def _get_results(self, timeout: float) -> List[yaa.ExeScriptCommandResult]:
"""Call GetExecBatchResults with re-trying on "Endpoint address not found" GSB error."""

for n in range(self.GET_EXEC_BATCH_RESULTS_MAX_TRIES, 0, -1):
try:
results = await self._activity._api.get_exec_batch_results(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conflicts with changes in #548 , which should be merged first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#548 was a PR to master, and the current PR is to b0.6. When we will merge changes in b0.6 to master we'll have to merge those two sets of changes, but no need to worry about this now. I don't think we plan to backport #548 to b0.6, do we?

Copy link
Contributor

@filipgolem filipgolem Aug 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A similar PR will be a part of the next yajsapi release. IMO it would be great to backport #548.

self._activity._id, self._batch_id, _request_timeout=min(timeout, 5)
)
return results
except ApiException as err:
terminated, reason, error_msg = await self._activity_terminated()
if terminated:
raise BatchError("Activity terminated by provider", reason, error_msg)
# TODO: add and use a new Exception class (subclass of BatchError)
# to indicate closing the activity by the provider
if not _is_gsb_endpoint_not_found_error(err):
raise err
msg = "GetExecBatchResults failed due to GSB error"
if n > 1:
_log.debug("%s, retrying in %s s", msg, self.GET_EXEC_BATCH_RESULTS_INTERVAL)
await asyncio.sleep(self.GET_EXEC_BATCH_RESULTS_INTERVAL)
else:
_log.debug(
"%s, giving up after %d attempts",
msg,
self.GET_EXEC_BATCH_RESULTS_MAX_TRIES,
)
raise err

return []

async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]:
last_idx = 0

while last_idx < self._size:
timeout = self.seconds_left()
if timeout <= 0:
raise BatchTimeoutError()
try:
results: List[yaa.ExeScriptCommandResult] = await self._api.get_exec_batch_results(
self._activity_id, self._batch_id, _request_timeout=min(timeout, 5)
)
except asyncio.TimeoutError:
continue
except ApiException as err:
if err.status == 408:
continue
raise

results: List[yaa.ExeScriptCommandResult] = []
async with SuppressedExceptions(is_intermittent_error):
results = await self._get_results(timeout=min(timeout, 5))

any_new: bool = False
results = results[last_idx:]
for result in results:
Expand Down Expand Up @@ -227,13 +281,13 @@ class StreamingBatch(Batch):
async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]:
from aiohttp_sse_client import client as sse_client # type: ignore

api_client = self._api.api_client
api_client = self._activity._api.api_client
host = api_client.configuration.host
headers = api_client.default_headers

api_client.update_params_for_auth(headers, None, ["app_key"])

activity_id = self._activity_id
activity_id = self._activity._id
batch_id = self._batch_id
last_idx = self._size - 1

Expand Down