diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1d9adb6..1e573f5 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -26,12 +26,20 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install pytest coverage codecov + pip install pytest coverage codecov httpretty pip install . + type python + type pip + type pytest + pip freeze - name: Test with pytest env: SYPHT_API_BASE_ENDPOINT: ${{ vars.SYPHT_API_BASE_ENDPOINT }} SYPHT_API_KEY: ${{ secrets.SYPHT_API_KEY }} SYPHT_AUTH_ENDPOINT: ${{ vars.SYPHT_AUTH_ENDPOINT }} run: | - pytest -s tests/*.py + type python + type pip + type pytest + pip freeze + pytest -s tests/test*.py diff --git a/sypht/client.py b/sypht/client.py index 81593b6..9a6454d 100644 --- a/sypht/client.py +++ b/sypht/client.py @@ -6,8 +6,10 @@ from urllib.parse import quote_plus, urlencode, urljoin import requests +from requests.adapters import HTTPAdapter +from urllib3.util import Retry -from .util import fetch_all_pages +from sypht.util import fetch_all_pages SYPHT_API_BASE_ENDPOINT = "https://api.sypht.com" SYPHT_AUTH_ENDPOINT = "https://auth.sypht.com/oauth2/token" @@ -75,7 +77,23 @@ def __init__( @property def _create_session(self): - return requests.Session() + session = requests.Session() + retries = Retry( + total=None, # set connect, read, redirect, status, other instead + connect=3, + read=3, + redirect=0, + status=3, + status_forcelist=[429, 502, 503, 504], + other=0, # catch-all for other errors + allowed_methods=["GET"], + respect_retry_after_header=False, + backoff_factor=0.5, # 0.0, 0.5, 1.0, 2.0, 4.0 + # Support manual status handling in _parse_response. + raise_on_status=False, + ) + session.mount(self.base_endpoint, HTTPAdapter(max_retries=retries)) + return session def _authenticate_v2(self, endpoint, client_id, client_secret, audience): basic_auth_slug = b64encode( @@ -418,11 +436,14 @@ def get_annotations( from_date=None, to_date=None, endpoint=None, + rec_limit=None, + company_id=None, ): page_iter = fetch_all_pages( name="get_annotations", fetch_page=self._get_annotations, get_page=lambda response: response["annotations"], + rec_limit=rec_limit, ) annotations = [] for response in page_iter( @@ -433,6 +454,7 @@ def get_annotations( from_date=from_date, to_date=to_date, endpoint=endpoint, + company_id=company_id, ): annotations.extend(response["annotations"]) return {"annotations": annotations} @@ -446,6 +468,7 @@ def _get_annotations( from_date=None, to_date=None, endpoint=None, + company_id=None, offset=0, ): """Fetch a single page of annotations skipping the given offset number of pages first. Use get_annotations to fetch all pages.""" @@ -462,6 +485,8 @@ def _get_annotations( filters.append("fromDate=" + from_date) if to_date is not None: filters.append("toDate=" + to_date) + if company_id is not None: + filters.append("companyId=" + company_id) endpoint = urljoin( endpoint or self.base_endpoint, ("/app/annotations?" + "&".join(filters)) @@ -471,11 +496,12 @@ def _get_annotations( headers["Content-Type"] = "application/json" return self._parse_response(self.requests.get(endpoint, headers=headers)) - def get_annotations_for_docs(self, doc_ids, endpoint=None): + def get_annotations_for_docs(self, doc_ids, endpoint=None, rec_limit=None): page_iter = fetch_all_pages( name="get_annotations_for_docs", fetch_page=self._get_annotations_for_docs, get_page=lambda response: response["annotations"], + rec_limit=rec_limit, ) annotations = [] for response in page_iter( diff --git a/sypht/util.py b/sypht/util.py index 01e2b1d..16bbfb9 100644 --- a/sypht/util.py +++ b/sypht/util.py @@ -1,11 +1,14 @@ +import logging from typing import Any, Callable, Iterator, List +DEFAULT_REC_LIMIT = 100_000 + def fetch_all_pages( name: str, fetch_page: Callable[..., Any], get_page: Callable[..., List[Any]] = lambda x: x, - rec_limit=20000, + rec_limit=DEFAULT_REC_LIMIT, ) -> Callable[..., Iterator[Any]]: """Returns an iterator that calls fetch_page with an offset that we increment by the number of pages fetched. Stop if page returns empty list. @@ -13,6 +16,10 @@ def fetch_all_pages( :param get_page: a function that extracts the page from the response which should be a list """ + # Enforce a default so that the loop will stop. + if rec_limit is None: + rec_limit = DEFAULT_REC_LIMIT + def fetch_all_pages(*args, **kwargs) -> Iterator[Any]: page_count = 0 recs = 0 @@ -31,17 +38,20 @@ def fetch_all_pages(*args, **kwargs) -> Iterator[Any]: ) except Exception as err: raise Exception( - f"Failed fetching for {name} for offset={page_count - 1} (records fetched so far:{recs})" + f"Failed fetching for {name} for offset={page_count - 1} (page={page_count}) (records fetched so far:{recs}). Cause: {err}" ) from err try: page = get_page(response) except Exception as err: raise Exception( - f"get_page failed to extract page from response for {name} for offset={page_count - 1} (records fetched so far:{recs})" + f"get_page failed to extract page from response for {name} for offset={page_count - 1} (page={page_count}) (records fetched so far:{recs}). Cause: {err}" ) from err if len(page) == 0: break recs += len(page) + logging.info( + f"fetch_all_pages({name}): fetched page {page_count} (records={recs})" + ) yield response return fetch_all_pages diff --git a/tests/tests_client.py b/tests/tests_client.py index c01c9c2..d21ed6a 100644 --- a/tests/tests_client.py +++ b/tests/tests_client.py @@ -1,9 +1,13 @@ -import os +import json import unittest import warnings from datetime import datetime, timedelta +from unittest.mock import Mock, patch from uuid import UUID, uuid4 +import httpretty +import pytest + from sypht.client import SyphtClient @@ -93,5 +97,75 @@ def test_reauthentication(self): self.assertFalse(self.sypht_client._is_token_expired()) +class RetryTest(unittest.TestCase): + """Test the global retry logic works as we expect it to.""" + + @patch.object(SyphtClient, "_authenticate_v2", return_value=("access_token", 100)) + @patch.object(SyphtClient, "_authenticate_v1", return_value=("access_token2", 100)) + @httpretty.activate(verbose=True, allow_net_connect=False) + def test_it_should_retry_n_times(self, auth_v1: Mock, auth_v2: Mock): + # arrange + self.count = 0 + + def get_annotations(request, uri, response_headers): + self.count += 1 + # 1 req + 3 retries = 4 + if self.count == 4: + return [200, response_headers, json.dumps({"annotations": []})] + return [502, response_headers, json.dumps({})] + + httpretty.register_uri( + httpretty.GET, + "https://api.sypht.com/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01", + body=get_annotations, + ) + + sypht_client = SyphtClient(base_endpoint="https://api.sypht.com") + + # act / assert + response = sypht_client.get_annotations( + from_date=datetime( + year=2021, month=1, day=1, hour=0, minute=0, second=0 + ).strftime("%Y-%m-%d"), + to_date=datetime( + year=2021, month=1, day=1, hour=0, minute=0, second=0 + ).strftime("%Y-%m-%d"), + ) + + assert response == {"annotations": []} + + @patch.object(SyphtClient, "_authenticate_v2", return_value=("access_token", 100)) + @patch.object(SyphtClient, "_authenticate_v1", return_value=("access_token2", 100)) + @httpretty.activate(verbose=True, allow_net_connect=False) + def test_retry_should_eventually_fail_for_50x(self, auth_v1: Mock, auth_v2: Mock): + # arrange + self.count = 0 + + def get_annotations(request, uri, response_headers): + self.count += 1 + return [502, response_headers, json.dumps({})] + + httpretty.register_uri( + httpretty.GET, + "https://api.sypht.com/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01", + body=get_annotations, + ) + + sypht_client = SyphtClient(base_endpoint="https://api.sypht.com") + + # act / assert + with self.assertRaisesRegex(Exception, ".") as e: + sypht_client.get_annotations( + from_date=datetime( + year=2021, month=1, day=1, hour=0, minute=0, second=0 + ).strftime("%Y-%m-%d"), + to_date=datetime( + year=2021, month=1, day=1, hour=0, minute=0, second=0 + ).strftime("%Y-%m-%d"), + ) + + assert self.count == 4, "should be 1 req + 3 retries" + + if __name__ == "__main__": unittest.main() diff --git a/tests/tests_util.py b/tests/tests_util.py index 9408a91..661ac26 100644 --- a/tests/tests_util.py +++ b/tests/tests_util.py @@ -1,6 +1,6 @@ import pytest -from sypht.util import fetch_all_pages +from sypht.util import DEFAULT_REC_LIMIT, fetch_all_pages def test_fetch_all_pages_can_fetch_one_page(): @@ -73,6 +73,7 @@ def fetch_something(offset, pages=1): def test_fetch_all_pages_never_ending(): """Fail if fetch more than n pages.""" + # arrange def never_ending(*args, **kwargs): return [0, 1, 2] @@ -85,7 +86,30 @@ def never_ending(*args, **kwargs): results += page # assert - assert "more than the limit: 20000" in str(exc_info) + assert f"more than the limit: {DEFAULT_REC_LIMIT}" in str(exc_info) + + +def test_fetch_with_rec_limit(): + # arrange + page_size = 5 + + def fetch_something(offset, pages=1): + pages0 = pages - 1 + if offset > pages0: + return [] + start = offset * page_size + page = range(start, start + page_size) + return list(page) + + # act + page_iter = fetch_all_pages(name="test1", fetch_page=fetch_something, rec_limit=2) + results = [] + with pytest.raises(Exception) as exc_info: + for page in page_iter(): + results += page + + # assert + assert f"fetched 5 records which is more than the limit: 2" in str(exc_info) def test_fetch_all_pages_handle_error():