Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: bump rec_limit and allow override #43

Merged
merged 14 commits into from
Sep 27, 2023
8 changes: 8 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,18 @@ jobs:
python -m pip install --upgrade pip
pip install pytest coverage codecov
pip install .
type python
type pip
type pytest
pip freeze
- name: Test with pytest
env:
SYPHT_API_BASE_ENDPOINT: ${{ vars.SYPHT_API_BASE_ENDPOINT }}
SYPHT_API_KEY: ${{ secrets.SYPHT_API_KEY }}
SYPHT_AUTH_ENDPOINT: ${{ vars.SYPHT_AUTH_ENDPOINT }}
run: |
type python
type pip
type pytest
pip freeze
pytest -s tests/*.py
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
"Programming Language :: Python :: 3.7",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
install_requires=["requests>=2.26.0", "urllib3>=1.26.5"],
install_requires=["requests>=2.26.0", "urllib3>=1.26.5,<2.0.0"],
)
34 changes: 31 additions & 3 deletions sypht/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from urllib.parse import quote_plus, urlencode, urljoin

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from .util import fetch_all_pages
from sypht.util import fetch_all_pages

SYPHT_API_BASE_ENDPOINT = "https://api.sypht.com"
SYPHT_AUTH_ENDPOINT = "https://auth.sypht.com/oauth2/token"
Expand Down Expand Up @@ -73,9 +75,27 @@ def __init__(
self._company_id = None
self._authenticate_client()

@property
def _retry_adapter(self):
retry_strategy = Retry(
total=None, # set connect, read, redirect, status, other instead
connect=3,
read=3,
redirect=0,
status=3,
status_forcelist=[429, 502, 503, 504],
other=0, # catch-all for other errors
allowed_methods=["GET"],
respect_retry_after_header=False,
backoff_factor=0.5, # 0.0, 0.5, 1.0, 2.0, 4.0
)
return HTTPAdapter(max_retries=retry_strategy)

@property
def _create_session(self):
return requests.Session()
session = requests.Session()
session.mount(self.base_endpoint, self._retry_adapter)
return session

def _authenticate_v2(self, endpoint, client_id, client_secret, audience):
basic_auth_slug = b64encode(
Expand Down Expand Up @@ -418,11 +438,14 @@ def get_annotations(
from_date=None,
to_date=None,
endpoint=None,
rec_limit=None,
company_id=None,
):
page_iter = fetch_all_pages(
name="get_annotations",
fetch_page=self._get_annotations,
get_page=lambda response: response["annotations"],
rec_limit=rec_limit,
)
annotations = []
for response in page_iter(
Expand All @@ -433,6 +456,7 @@ def get_annotations(
from_date=from_date,
to_date=to_date,
endpoint=endpoint,
company_id=company_id,
):
annotations.extend(response["annotations"])
return {"annotations": annotations}
Expand All @@ -446,6 +470,7 @@ def _get_annotations(
from_date=None,
to_date=None,
endpoint=None,
company_id=None,
offset=0,
):
"""Fetch a single page of annotations skipping the given offset number of pages first. Use get_annotations to fetch all pages."""
Expand All @@ -462,6 +487,8 @@ def _get_annotations(
filters.append("fromDate=" + from_date)
if to_date is not None:
filters.append("toDate=" + to_date)
if company_id is not None:
filters.append("companyId=" + company_id)

endpoint = urljoin(
endpoint or self.base_endpoint, ("/app/annotations?" + "&".join(filters))
Expand All @@ -471,11 +498,12 @@ def _get_annotations(
headers["Content-Type"] = "application/json"
return self._parse_response(self.requests.get(endpoint, headers=headers))

def get_annotations_for_docs(self, doc_ids, endpoint=None):
def get_annotations_for_docs(self, doc_ids, endpoint=None, rec_limit=None):
page_iter = fetch_all_pages(
name="get_annotations_for_docs",
fetch_page=self._get_annotations_for_docs,
get_page=lambda response: response["annotations"],
rec_limit=rec_limit,
)
annotations = []
for response in page_iter(
Expand Down
16 changes: 13 additions & 3 deletions sypht/util.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import logging
from typing import Any, Callable, Iterator, List

DEFAULT_REC_LIMIT = 100_000


def fetch_all_pages(
name: str,
fetch_page: Callable[..., Any],
get_page: Callable[..., List[Any]] = lambda x: x,
rec_limit=20000,
rec_limit=DEFAULT_REC_LIMIT,
) -> Callable[..., Iterator[Any]]:
"""Returns an iterator that calls fetch_page with an offset that we increment by the number of pages fetched. Stop if page returns empty list.

:param fetch_page: a function that makes an api call to fetch a page of results (using zero-based offset)
:param get_page: a function that extracts the page from the response which should be a list
"""

# Enforce a default so that the loop will stop.
if rec_limit is None:
rec_limit = DEFAULT_REC_LIMIT

def fetch_all_pages(*args, **kwargs) -> Iterator[Any]:
page_count = 0
recs = 0
Expand All @@ -31,17 +38,20 @@ def fetch_all_pages(*args, **kwargs) -> Iterator[Any]:
)
except Exception as err:
raise Exception(
f"Failed fetching for {name} for offset={page_count - 1} (records fetched so far:{recs})"
f"Failed fetching for {name} for offset={page_count - 1} (page={page_count}) (records fetched so far:{recs}). Cause: {err}"
) from err
try:
page = get_page(response)
except Exception as err:
raise Exception(
f"get_page failed to extract page from response for {name} for offset={page_count - 1} (records fetched so far:{recs})"
f"get_page failed to extract page from response for {name} for offset={page_count - 1} (page={page_count}) (records fetched so far:{recs}). Cause: {err}"
) from err
if len(page) == 0:
break
recs += len(page)
logging.info(
f"fetch_all_pages({name}): fetched page {page_count} (records={recs})"
)
yield response

return fetch_all_pages
64 changes: 63 additions & 1 deletion tests/tests_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import unittest
import warnings
from datetime import datetime, timedelta
from http.client import HTTPMessage
from unittest.mock import ANY, Mock, call, patch
from uuid import UUID, uuid4

from sypht.client import SyphtClient
Expand Down Expand Up @@ -93,5 +94,66 @@ def test_reauthentication(self):
self.assertFalse(self.sypht_client._is_token_expired())


class RetryTest(unittest.TestCase):
"""Test the global retry logic works as we expect it to."""

@patch.object(SyphtClient, "_authenticate_v2", return_value=("access_token", 100))
@patch.object(SyphtClient, "_authenticate_v1", return_value=("access_token2", 100))
@patch("urllib3.connectionpool.HTTPConnectionPool._get_conn")
def test_it_should_eventually_fail_for_50x(
self, getconn_mock: Mock, auth_v1: Mock, auth_v2: Mock
):
"""See https://stackoverflow.com/questions/66497627/how-to-test-retry-attempts-in-python-using-the-request-library ."""

# arrange
getconn_mock.return_value.getresponse.side_effect = [
Mock(status=502, msg=HTTPMessage()),
# Retries start from here...
# There should be n for where Retry(status=n).
Mock(status=502, msg=HTTPMessage()),
Mock(status=503, msg=HTTPMessage()),
Mock(status=504, msg=HTTPMessage()),
]
sypht_client = SyphtClient()

# act / assert
with self.assertRaisesRegex(Exception, "Max retries exceeded") as e:
sypht_client.get_annotations(
from_date=datetime(
year=2021, month=1, day=1, hour=0, minute=0, second=0
).strftime("%Y-%m-%d"),
to_date=datetime(
year=2021, month=1, day=1, hour=0, minute=0, second=0
).strftime("%Y-%m-%d"),
)
assert getconn_mock.return_value.request.mock_calls == [
call(
"GET",
"/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01",
body=None,
headers=ANY,
),
# Retries start here...
call(
"GET",
"/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01",
body=None,
headers=ANY,
),
call(
"GET",
"/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01",
body=None,
headers=ANY,
),
call(
"GET",
"/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01",
body=None,
headers=ANY,
),
]


if __name__ == "__main__":
unittest.main()
28 changes: 26 additions & 2 deletions tests/tests_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from sypht.util import fetch_all_pages
from sypht.util import DEFAULT_REC_LIMIT, fetch_all_pages


def test_fetch_all_pages_can_fetch_one_page():
Expand Down Expand Up @@ -73,6 +73,7 @@ def fetch_something(offset, pages=1):

def test_fetch_all_pages_never_ending():
"""Fail if fetch more than n pages."""

# arrange
def never_ending(*args, **kwargs):
return [0, 1, 2]
Expand All @@ -85,7 +86,30 @@ def never_ending(*args, **kwargs):
results += page

# assert
assert "more than the limit: 20000" in str(exc_info)
assert f"more than the limit: {DEFAULT_REC_LIMIT}" in str(exc_info)


def test_fetch_with_rec_limit():
# arrange
page_size = 5

def fetch_something(offset, pages=1):
pages0 = pages - 1
if offset > pages0:
return []
start = offset * page_size
page = range(start, start + page_size)
return list(page)

# act
page_iter = fetch_all_pages(name="test1", fetch_page=fetch_something, rec_limit=2)
results = []
with pytest.raises(Exception) as exc_info:
for page in page_iter():
results += page

# assert
assert f"fetched 5 records which is more than the limit: 2" in str(exc_info)


def test_fetch_all_pages_handle_error():
Expand Down