From da2f159d54344f871bf9779edd705b73aa329f0a Mon Sep 17 00:00:00 2001 From: azawlocki Date: Tue, 10 Aug 2021 12:47:11 +0200 Subject: [PATCH 1/9] Implement re-tries for ApiExceptions caused by GSB Errors --- yapapi/rest/activity.py | 78 ++++++++++++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index 6e1406126..d15ca7ef3 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -21,6 +21,8 @@ ) from yapapi import events +from yapapi.rest.common import is_intermittent_error, SuppressedExceptions + _log = logging.getLogger("yapapi.rest") @@ -73,8 +75,8 @@ async def send(self, script: List[dict], deadline: Optional[datetime] = None) -> batch_id = await self._api.call_exec(self._id, yaa.ExeScriptRequest(text=script_txt)) if self._stream_events: - return StreamingBatch(self._api, self._id, batch_id, len(script), deadline) - return PollingBatch(self._api, self._id, batch_id, len(script), deadline) + return StreamingBatch(self, batch_id, len(script), deadline) + return PollingBatch(self, batch_id, len(script), deadline) async def __aenter__(self) -> "Activity": return self @@ -146,22 +148,19 @@ class BatchTimeoutError(BatchError): class Batch(abc.ABC, AsyncIterable[events.CommandEventContext]): """Abstract base class for iterating over events related to a batch running on provider.""" - _api: RequestorControlApi - _activity_id: str + _activity: Activity _batch_id: str _size: int _deadline: datetime def __init__( self, - api: RequestorControlApi, - activity_id: str, + activity: Activity, batch_id: str, batch_size: int, deadline: Optional[datetime] = None, ) -> None: - self._api = api - self._activity_id = activity_id + self._activity = activity self._batch_id = batch_id self._size = batch_size self._deadline = ( @@ -182,22 +181,61 @@ def id(self): class PollingBatch(Batch): """A `Batch` implementation that polls the server repeatedly for command status.""" + async def _activity_terminated(self) -> bool: + """Check if the activity we're using is in "Terminated" state.""" + try: + state_list = await self._activity.state().state # type: ignore + return "Terminated" in state_list + except Exception: + _log.debug("Cannot query activity state", exc_info=True) + return False + + def _is_endpoint_not_found_error(self, err: ApiException) -> bool: + """Check if `err` is caused by "Endpoint address not found" GSB error.""" + + if err.status != 500: + return False + try: + msg = json.loads(err.body)["message"] + return "GSB error" in msg and "Endpoint address not found" in msg + except Exception: + _log.debug("Cannot read error message from ApiException", exc_info=True) + return False + + async def _get_results(self, timeout: float, num_tries: int, delay: float = 3.0): + """Call GetExecBatchResults with re-trying on "Endpoint address not found" GSB error.""" + + while num_tries: + try: + results = await self._activity._api.get_exec_batch_results( + self._activity._id, self._batch_id, _request_timeout=min(timeout, 5) + ) + return results + except ApiException as err: + if await self._activity_terminated(): + _log.debug("Activity %s terminated by provider", self._activity._id) + # TODO: add and use a new Exception class (subclass of BatchError) + # to indicate closing the activity by the provider + raise err + if not self._is_endpoint_not_found_error(err): + raise err + num_tries -= 1 + if num_tries: + _log.debug("Retrying ") + await asyncio.sleep(delay) + async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: last_idx = 0 + while last_idx < self._size: timeout = self.seconds_left() if timeout <= 0: raise BatchTimeoutError() - try: - results: List[yaa.ExeScriptCommandResult] = await self._api.get_exec_batch_results( - self._activity_id, self._batch_id, _request_timeout=min(timeout, 5) - ) - except asyncio.TimeoutError: - continue - except ApiException as err: - if err.status == 408: - continue - raise + + results: List[yaa.ExeScriptCommandResult] = [] + async with SuppressedExceptions(is_intermittent_error): + results = await self._get_results(timeout=min(timeout, 5), num_tries=3) + any_new: bool = False results = results[last_idx:] for result in results: @@ -227,13 +265,13 @@ class StreamingBatch(Batch): async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: from aiohttp_sse_client import client as sse_client # type: ignore - api_client = self._api.api_client + api_client = self._activity._api.api_client host = api_client.configuration.host headers = api_client.default_headers api_client.update_params_for_auth(headers, None, ["app_key"]) - activity_id = self._activity_id + activity_id = self._activity._id batch_id = self._batch_id last_idx = self._size - 1 From 6b04dd6a1f56fbf905f7807c4a5157092a6907b6 Mon Sep 17 00:00:00 2001 From: filipgolem <44880692+filipgolem@users.noreply.github.com> Date: Tue, 10 Aug 2021 14:17:25 +0200 Subject: [PATCH 2/9] Endpoint -> endpoint --- yapapi/rest/activity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index d15ca7ef3..b35d1b720 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -197,7 +197,7 @@ def _is_endpoint_not_found_error(self, err: ApiException) -> bool: return False try: msg = json.loads(err.body)["message"] - return "GSB error" in msg and "Endpoint address not found" in msg + return "GSB error" in msg and "endpoint address not found" in msg except Exception: _log.debug("Cannot read error message from ApiException", exc_info=True) return False From c6356d95a4e897a4a6286fbe4e1280a8788553cd Mon Sep 17 00:00:00 2001 From: azawlocki Date: Tue, 10 Aug 2021 20:13:23 +0200 Subject: [PATCH 3/9] Apply fixes after code review --- yapapi/rest/activity.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index b35d1b720..4782ce653 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -184,8 +184,8 @@ class PollingBatch(Batch): async def _activity_terminated(self) -> bool: """Check if the activity we're using is in "Terminated" state.""" try: - state_list = await self._activity.state().state # type: ignore - return "Terminated" in state_list + state = await self._activity.state() + return "Terminated" in state.state except Exception: _log.debug("Cannot query activity state", exc_info=True) return False @@ -220,9 +220,12 @@ async def _get_results(self, timeout: float, num_tries: int, delay: float = 3.0) if not self._is_endpoint_not_found_error(err): raise err num_tries -= 1 + msg = "GetExecBatchResults failed due to GSB error" if num_tries: - _log.debug("Retrying ") + _log.debug("%s, retrying in %s s", msg, delay) await asyncio.sleep(delay) + else: + _log.debug("%s, giving up", msg) async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: last_idx = 0 From e5a2b3c4150faf0120272eece54f821691864e85 Mon Sep 17 00:00:00 2001 From: azawlocki Date: Tue, 10 Aug 2021 20:48:26 +0200 Subject: [PATCH 4/9] Fixes after code review: part II --- yapapi/rest/activity.py | 42 +++++++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index 4782ce653..4e7ca9e7f 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -178,6 +178,19 @@ def id(self): return self._batch_id +def _is_gsb_endpoint_not_found_error(err: ApiException) -> bool: + """Check if `err` is caused by "Endpoint address not found" GSB error.""" + + if err.status != 500: + return False + try: + msg = json.loads(err.body)["message"] + return "GSB error" in msg and "endpoint address not found" in msg + except Exception: + _log.debug("Cannot read error message from ApiException", exc_info=True) + return False + + class PollingBatch(Batch): """A `Batch` implementation that polls the server repeatedly for command status.""" @@ -190,22 +203,13 @@ async def _activity_terminated(self) -> bool: _log.debug("Cannot query activity state", exc_info=True) return False - def _is_endpoint_not_found_error(self, err: ApiException) -> bool: - """Check if `err` is caused by "Endpoint address not found" GSB error.""" - - if err.status != 500: - return False - try: - msg = json.loads(err.body)["message"] - return "GSB error" in msg and "endpoint address not found" in msg - except Exception: - _log.debug("Cannot read error message from ApiException", exc_info=True) - return False - - async def _get_results(self, timeout: float, num_tries: int, delay: float = 3.0): + async def _get_results( + self, timeout: float, num_tries: int, delay: float = 3.0 + ) -> List[yaa.ExeScriptCommandResult]: """Call GetExecBatchResults with re-trying on "Endpoint address not found" GSB error.""" - while num_tries: + for n in range(num_tries, 0, -1): + # n = num_tries, ... , 1 try: results = await self._activity._api.get_exec_batch_results( self._activity._id, self._batch_id, _request_timeout=min(timeout, 5) @@ -217,15 +221,17 @@ async def _get_results(self, timeout: float, num_tries: int, delay: float = 3.0) # TODO: add and use a new Exception class (subclass of BatchError) # to indicate closing the activity by the provider raise err - if not self._is_endpoint_not_found_error(err): + if not _is_gsb_endpoint_not_found_error(err): raise err - num_tries -= 1 msg = "GetExecBatchResults failed due to GSB error" - if num_tries: + if n > 1: _log.debug("%s, retrying in %s s", msg, delay) await asyncio.sleep(delay) else: - _log.debug("%s, giving up", msg) + _log.debug("%s, giving up after %d attempts", msg, num_tries) + raise err + + return [] async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: last_idx = 0 From 02c638b0621509e4d9e1fa37706a56c1ec6609b7 Mon Sep 17 00:00:00 2001 From: filipgolem <44880692+filipgolem@users.noreply.github.com> Date: Wed, 11 Aug 2021 11:58:59 +0200 Subject: [PATCH 5/9] debug -> warning --- yapapi/rest/activity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index 4e7ca9e7f..1f977f0cd 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -217,7 +217,7 @@ async def _get_results( return results except ApiException as err: if await self._activity_terminated(): - _log.debug("Activity %s terminated by provider", self._activity._id) + _log.warning("Activity %s terminated by provider", self._activity._id) # TODO: add and use a new Exception class (subclass of BatchError) # to indicate closing the activity by the provider raise err From 25b2a14bc5009dbb720234841fd65a015c94dde2 Mon Sep 17 00:00:00 2001 From: Filip Date: Wed, 11 Aug 2021 12:22:01 +0200 Subject: [PATCH 6/9] Improve logs when activity is prematurely terminated on the provider --- yapapi/rest/activity.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index 1f977f0cd..a20e9accb 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone import json import logging -from typing import AsyncIterator, List, Optional, Type, Any, Dict +from typing import AsyncIterator, List, Optional, Tuple, Type, Any, Dict from typing_extensions import AsyncContextManager, AsyncIterable @@ -194,14 +194,14 @@ def _is_gsb_endpoint_not_found_error(err: ApiException) -> bool: class PollingBatch(Batch): """A `Batch` implementation that polls the server repeatedly for command status.""" - async def _activity_terminated(self) -> bool: + async def _activity_terminated(self) -> Tuple[bool, Optional[str], Optional[str]]: """Check if the activity we're using is in "Terminated" state.""" try: state = await self._activity.state() - return "Terminated" in state.state + return "Terminated" in state.state, state.reason, state.error_message except Exception: _log.debug("Cannot query activity state", exc_info=True) - return False + return False, None, None async def _get_results( self, timeout: float, num_tries: int, delay: float = 3.0 @@ -216,8 +216,14 @@ async def _get_results( ) return results except ApiException as err: - if await self._activity_terminated(): - _log.warning("Activity %s terminated by provider", self._activity._id) + terminated, reason, errMsg = await self._activity_terminated() + if terminated: + _log.warning( + "Activity %s terminated by provider. Reason: %s, error: %s", + self._activity._id, + reason, + errMsg + ) # TODO: add and use a new Exception class (subclass of BatchError) # to indicate closing the activity by the provider raise err From d22fdd28ed8fcbc0b6b4153c56c13d15417e0460 Mon Sep 17 00:00:00 2001 From: Filip Date: Wed, 11 Aug 2021 12:28:21 +0200 Subject: [PATCH 7/9] Formatting --- yapapi/rest/activity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index a20e9accb..a6bb3ef7d 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -222,7 +222,7 @@ async def _get_results( "Activity %s terminated by provider. Reason: %s, error: %s", self._activity._id, reason, - errMsg + errMsg, ) # TODO: add and use a new Exception class (subclass of BatchError) # to indicate closing the activity by the provider From a5d0824f70e3410ffeaea7adb456f9b330cf7dfc Mon Sep 17 00:00:00 2001 From: azawlocki Date: Wed, 11 Aug 2021 18:16:35 +0200 Subject: [PATCH 8/9] Raise BatchError when an activity is terminated by the provider --- yapapi/rest/activity.py | 35 ++++++++++++++++++----------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/yapapi/rest/activity.py b/yapapi/rest/activity.py index a6bb3ef7d..2bae38e9c 100644 --- a/yapapi/rest/activity.py +++ b/yapapi/rest/activity.py @@ -194,6 +194,12 @@ def _is_gsb_endpoint_not_found_error(err: ApiException) -> bool: class PollingBatch(Batch): """A `Batch` implementation that polls the server repeatedly for command status.""" + GET_EXEC_BATCH_RESULTS_MAX_TRIES = 3 + """Max number of attempts to call GetExecBatchResults if a GSB error occurs.""" + + GET_EXEC_BATCH_RESULTS_INTERVAL = 3.0 + """Time in seconds before retrying GetExecBatchResults after a GSB error occurs.""" + async def _activity_terminated(self) -> Tuple[bool, Optional[str], Optional[str]]: """Check if the activity we're using is in "Terminated" state.""" try: @@ -203,38 +209,33 @@ async def _activity_terminated(self) -> Tuple[bool, Optional[str], Optional[str] _log.debug("Cannot query activity state", exc_info=True) return False, None, None - async def _get_results( - self, timeout: float, num_tries: int, delay: float = 3.0 - ) -> List[yaa.ExeScriptCommandResult]: + async def _get_results(self, timeout: float) -> List[yaa.ExeScriptCommandResult]: """Call GetExecBatchResults with re-trying on "Endpoint address not found" GSB error.""" - for n in range(num_tries, 0, -1): - # n = num_tries, ... , 1 + for n in range(self.GET_EXEC_BATCH_RESULTS_MAX_TRIES, 0, -1): try: results = await self._activity._api.get_exec_batch_results( self._activity._id, self._batch_id, _request_timeout=min(timeout, 5) ) return results except ApiException as err: - terminated, reason, errMsg = await self._activity_terminated() + terminated, reason, error_msg = await self._activity_terminated() if terminated: - _log.warning( - "Activity %s terminated by provider. Reason: %s, error: %s", - self._activity._id, - reason, - errMsg, - ) + raise BatchError("Activity terminated by provider", reason, error_msg) # TODO: add and use a new Exception class (subclass of BatchError) # to indicate closing the activity by the provider - raise err if not _is_gsb_endpoint_not_found_error(err): raise err msg = "GetExecBatchResults failed due to GSB error" if n > 1: - _log.debug("%s, retrying in %s s", msg, delay) - await asyncio.sleep(delay) + _log.debug("%s, retrying in %s s", msg, self.GET_EXEC_BATCH_RESULTS_INTERVAL) + await asyncio.sleep(self.GET_EXEC_BATCH_RESULTS_INTERVAL) else: - _log.debug("%s, giving up after %d attempts", msg, num_tries) + _log.debug( + "%s, giving up after %d attempts", + msg, + self.GET_EXEC_BATCH_RESULTS_MAX_TRIES, + ) raise err return [] @@ -249,7 +250,7 @@ async def __aiter__(self) -> AsyncIterator[events.CommandEventContext]: results: List[yaa.ExeScriptCommandResult] = [] async with SuppressedExceptions(is_intermittent_error): - results = await self._get_results(timeout=min(timeout, 5), num_tries=3) + results = await self._get_results(timeout=min(timeout, 5)) any_new: bool = False results = results[last_idx:] From 3e9c236e0c70f7eb53b68f0598d8eed304496450 Mon Sep 17 00:00:00 2001 From: azawlocki Date: Wed, 11 Aug 2021 18:18:01 +0200 Subject: [PATCH 9/9] Add unit tests for PollingBatch behavior when GSB errors occur --- tests/rest/test_activity.py | 126 ++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 tests/rest/test_activity.py diff --git a/tests/rest/test_activity.py b/tests/rest/test_activity.py new file mode 100644 index 000000000..a10b20fbc --- /dev/null +++ b/tests/rest/test_activity.py @@ -0,0 +1,126 @@ +from typing import List, Optional, Tuple, Type +from unittest.mock import Mock + +import pytest + +from ya_activity.exceptions import ApiException +from yapapi.rest.activity import BatchError, PollingBatch + + +GetExecBatchResultsSpec = Tuple[Optional[Exception], List[str]] + + +def mock_activity(specs: List[GetExecBatchResultsSpec]): + """Create a mock activity. + + The argument `specs` is a list of pairs specifying the behavior of subsequent calls + to `get_exec_batch_results()`: i-th pair corresponds to the i-th call. + The first element of the pair is an optional error raised by the call, the second element + is the activity state (the `.state` component of the object returned by `Activity.state()`). + """ + i = -1 + + async def mock_results(*_args, **_kwargs): + nonlocal specs, i + i += 1 + error = specs[i][0] + if error: + raise error + return [Mock(index=0)] + + async def mock_state(): + nonlocal specs, i + state = specs[i][1] + return Mock(state=state) + + return Mock(state=mock_state, _api=Mock(get_exec_batch_results=mock_results)) + + +GSB_ERROR = ":( GSB error: some endpoint address not found :(" + + +@pytest.mark.parametrize( + "specs, expected_error", + [ + # No errors + ([(None, ["Running", "Running"])], None), + # Exception other than ApiException should stop iteration over batch results + ( + [(ValueError("!?"), ["Running", "Running"])], + ValueError, + ), + # ApiException not related to GSB should stop iteration over batch results + ( + [(ApiException(status=400), ["Running", "Running"])], + ApiException, + ), + # As above, but with status 500 + ( + [ + ( + ApiException(http_resp=Mock(status=500, data='{"message": "???"}')), + ["Running", "Running"], + ) + ], + ApiException, + ), + # ApiException not related to GSB should raise BatchError if activity is terminated + ( + [ + ( + ApiException(http_resp=Mock(status=500, data='{"message": "???"}')), + ["Running", "Terminated"], + ) + ], + BatchError, + ), + # GSB-related ApiException should cause retrying if the activity is running + ( + [ + ( + ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')), + ["Running", "Running"], + ), + (None, ["Running", "Running"]), + ], + None, + ), + # As above, but max number of tries is reached + ( + [ + ( + ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')), + ["Running", "Running"], + ) + ] + * PollingBatch.GET_EXEC_BATCH_RESULTS_MAX_TRIES, + ApiException, + ), + # GSB-related ApiException should raise BatchError if activity is terminated + ( + [ + ( + ApiException(http_resp=Mock(status=500, data=f'{{"message": "{GSB_ERROR}"}}')), + ["Running", "Terminated"], + ) + ], + BatchError, + ), + ], +) +@pytest.mark.asyncio +async def test_polling_batch_on_gsb_error( + specs: List[GetExecBatchResultsSpec], expected_error: Optional[Type[Exception]] +) -> None: + """Test the behavior of PollingBatch when get_exec_batch_results() raises exceptions.""" + + PollingBatch.GET_EXEC_BATCH_RESULTS_INTERVAL = 0.1 + + activity = mock_activity(specs) + batch = PollingBatch(activity, "batch_id", 1) + try: + async for _ in batch: + pass + assert expected_error is None + except Exception as error: + assert expected_error is not None and isinstance(error, expected_error)