diff --git a/tests/rest/test_activity.py b/tests/rest/test_activity.py new file mode 100644 index 000000000..a10b20fbc --- /dev/null +++ b/tests/rest/test_activity.py @@ -0,0 +1,126 @@ +from typing import List, Optional, Tuple, Type +from unittest.mock import Mock + +import pytest + +from ya_activity.exceptions import ApiException +from yapapi.rest.activity import BatchError, PollingBatch + + +GetExecBatchResultsSpec = Tuple[Optional[Exception], List[str]] + + +def mock_activity(specs: List[GetExecBatchResultsSpec]): + """Create a mock activity. + + The argument `specs` is a list of pairs specifying the behavior of subsequent calls + to `get_exec_batch_results()`: i-th pair corresponds to the i-th call. + The first element of the pair is an optional error raised by the call, the second element + is the activity state (the `.state` component of the object returned by `Activity.state()`). + """ + i = -1 + + async def mock_results(*_args, **_kwargs): + nonlocal specs, i + i += 1 + error = specs[i][0] + if error: + raise error + return [Mock(index=0)] + + async def mock_state(): + nonlocal specs, i + state = specs[i][1] + return Mock(state=state) + + return Mock(state=mock_state, _api=Mock(get_exec_batch_results=mock_results)) + + +GSB_ERROR = ":( GSB error: some endpoint address not found :(" + + +@pytest.mark.parametrize( + "specs, expected_error", + [ + # No errors + ([(None, ["Running", "Running"])], None), + # Exception other than ApiException should stop iteration over batch results + ( + [(ValueError("!?"), ["Running", "Running"])], + ValueError, + ), + # ApiException not related to GSB should stop iteration over batch results + ( + [(ApiException(status=400), ["Running", "Running"])], + ApiException, + ), + # As above, but with status 500 + ( + [ + ( + ApiException(http_resp=Mock(status=500, data='{"message": "???"}')), + ["Running", "Running"], + ) + ], + ApiException, + ), + # ApiException not related to GSB should raise BatchError if activity is terminated + ( + [ + ( + ApiException(http_resp=Mock(status=500, data='{"message": "???"}')), + ["Running", "Terminated"], + ) + ], + BatchError, + ), + # GSB-related ApiException should cause retrying if the activity is running + ( + [ + ( + ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')), + ["Running", "Running"], + ), + (None, ["Running", "Running"]), + ], + None, + ), + # As above, but max number of tries is reached + ( + [ + ( + ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')), + ["Running", "Running"], + ) + ] + * PollingBatch.GET_EXEC_BATCH_RESULTS_MAX_TRIES, + ApiException, + ), + # GSB-related ApiException should raise BatchError if activity is terminated + ( + [ + ( + ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')), + ["Running", "Terminated"], + ) + ], + BatchError, + ), + ], +) +@pytest.mark.asyncio +async def test_polling_batch_on_gsb_error( + specs: List[GetExecBatchResultsSpec], expected_error: Optional[Type[Exception]] +) -> None: + """Test the behavior of PollingBatch when get_exec_batch_results() raises exceptions.""" + + PollingBatch.GET_EXEC_BATCH_RESULTS_INTERVAL = 0.1 + + activity = mock_activity(specs) + batch = PollingBatch(activity, "batch_id", 1) + try: + async for _ in batch: + pass + assert expected_error is None + except Exception as error: + assert expected_error is not None and isinstance(error, expected_error) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index 6e1406126..2bae38e9c 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone import json import logging -from typing import AsyncIterator, List, Optional, Type, Any, Dict +from typing import AsyncIterator, List, Optional, Tuple, Type, Any, Dict from typing_extensions import AsyncContextManager, AsyncIterable @@ -21,6 +21,8 @@ ) from yapapi import events +from yapapi.rest.common import is_intermittent_error, SuppressedExceptions + _log = logging.getLogger("yapapi.rest") @@ -73,8 +75,8 @@ async def send(self, script: List[dict], deadline: Optional[datetime] = None) -> batch_id = await self._api.call_exec(self._id, yaa.ExeScriptRequest(text=script_txt)) if self._stream_events: - return StreamingBatch(self._api, self._id, batch_id, len(script), deadline) - return PollingBatch(self._api, self._id, batch_id, len(script), deadline) + return StreamingBatch(self, batch_id, len(script), deadline) + return PollingBatch(self, batch_id, len(script), deadline) async def __aenter__(self) -> "Activity": return self @@ -146,22 +148,19 @@ class BatchTimeoutError(BatchError): class Batch(abc.ABC, AsyncIterable[events.CommandEventContext]): """Abstract base class for iterating over events related to a batch running on provider.""" - _api: RequestorControlApi - _activity_id: str + _activity: Activity _batch_id: str _size: int _deadline: datetime def __init__( self, - api: RequestorControlApi, - activity_id: str, + activity: Activity, batch_id: str, batch_size: int, deadline: Optional[datetime] = None, ) -> None: - self._api = api - self._activity_id = activity_id + self._activity = activity self._batch_id = batch_id self._size = batch_size self._deadline = ( @@ -179,25 +178,80 @@ def id(self): return self._batch_id +def _is_gsb_endpoint_not_found_error(err: ApiException) -> bool: + """Check if `err` is caused by "Endpoint address not found" GSB error.""" + + if err.status != 500: + return False + try: + msg = json.loads(err.body)["message"] + return "GSB error" in msg and "endpoint address not found" in msg + except Exception: + _log.debug("Cannot read error message from ApiException", exc_info=True) + return False + + class PollingBatch(Batch): """A `Batch` implementation that polls the server repeatedly for command status.""" + GET_EXEC_BATCH_RESULTS_MAX_TRIES = 3 + """Max number of attempts to call GetExecBatchResults if a GSB error occurs.""" + + GET_EXEC_BATCH_RESULTS_INTERVAL = 3.0 + """Time in seconds before retrying GetExecBatchResults after a GSB error occurs.""" + + async def _activity_terminated(self) -> Tuple[bool, Optional[str], Optional[str]]: + """Check if the activity we're using is in "Terminated" state.""" + try: + state = await self._activity.state() + return "Terminated" in state.state, state.reason, state.error_message + except Exception: + _log.debug("Cannot query activity state", exc_info=True) + return False, None, None + + async def _get_results(self, timeout: float) -> List[yaa.ExeScriptCommandResult]: + """Call GetExecBatchResults with re-trying on "Endpoint address not found" GSB error.""" + + for n in range(self.GET_EXEC_BATCH_RESULTS_MAX_TRIES, 0, -1): + try: + results = await self._activity._api.get_exec_batch_results( + self._activity._id, self._batch_id, _request_timeout=min(timeout, 5) + ) + return results + except ApiException as err: + terminated, reason, error_msg = await self._activity_terminated() + if terminated: + raise BatchError("Activity terminated by provider", reason, error_msg) + # TODO: add and use a new Exception class (subclass of BatchError) + # to indicate closing the activity by the provider + if not _is_gsb_endpoint_not_found_error(err): + raise err + msg = "GetExecBatchResults failed due to GSB error" + if n > 1: + _log.debug("%s, retrying in %s s", msg, self.GET_EXEC_BATCH_RESULTS_INTERVAL) + await asyncio.sleep(self.GET_EXEC_BATCH_RESULTS_INTERVAL) + else: + _log.debug( + "%s, giving up after %d attempts", + msg, + self.GET_EXEC_BATCH_RESULTS_MAX_TRIES, + ) + raise err + + return [] + async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: last_idx = 0 + while last_idx < self._size: timeout = self.seconds_left() if timeout <= 0: raise BatchTimeoutError() - try: - results: List[yaa.ExeScriptCommandResult] = await self._api.get_exec_batch_results( - self._activity_id, self._batch_id, _request_timeout=min(timeout, 5) - ) - except asyncio.TimeoutError: - continue - except ApiException as err: - if err.status == 408: - continue - raise + + results: List[yaa.ExeScriptCommandResult] = [] + async with SuppressedExceptions(is_intermittent_error): + results = await self._get_results(timeout=min(timeout, 5)) + any_new: bool = False results = results[last_idx:] for result in results: @@ -227,13 +281,13 @@ class StreamingBatch(Batch): async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: from aiohttp_sse_client import client as sse_client # type: ignore - api_client = self._api.api_client + api_client = self._activity._api.api_client host = api_client.configuration.host headers = api_client.default_headers api_client.update_params_for_auth(headers, None, ["app_key"]) - activity_id = self._activity_id + activity_id = self._activity._id batch_id = self._batch_id last_idx = self._size - 1