Skip to content

Commit

Permalink
Merge pull request #43 from sypht-team/chore/bump-rec-limit
Browse files Browse the repository at this point in the history
chore: bump rec_limit and allow override
  • Loading branch information
danielbush authored Sep 27, 2023
2 parents 090eae6 + d0808a7 commit 3f8d6b1
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 11 deletions.
12 changes: 10 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,20 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest coverage codecov
pip install pytest coverage codecov httpretty
pip install .
type python
type pip
type pytest
pip freeze
- name: Test with pytest
env:
SYPHT_API_BASE_ENDPOINT: ${{ vars.SYPHT_API_BASE_ENDPOINT }}
SYPHT_API_KEY: ${{ secrets.SYPHT_API_KEY }}
SYPHT_AUTH_ENDPOINT: ${{ vars.SYPHT_AUTH_ENDPOINT }}
run: |
pytest -s tests/*.py
type python
type pip
type pytest
pip freeze
pytest -s tests/test*.py
32 changes: 29 additions & 3 deletions sypht/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from urllib.parse import quote_plus, urlencode, urljoin

import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from .util import fetch_all_pages
from sypht.util import fetch_all_pages

SYPHT_API_BASE_ENDPOINT = "https://api.sypht.com"
SYPHT_AUTH_ENDPOINT = "https://auth.sypht.com/oauth2/token"
Expand Down Expand Up @@ -75,7 +77,23 @@ def __init__(

@property
def _create_session(self):
    """Return a new ``requests.Session`` that retries failed GET requests.

    A fresh session is built on every access. An ``HTTPAdapter`` carrying
    the retry policy below is mounted on ``self.base_endpoint``, so only
    URLs starting with that prefix get the retry behaviour.

    Retry policy:
      * up to 3 retries each for connection errors, read errors and the
        HTTP statuses 429, 502, 503 and 504;
      * redirects and any other error class are not retried;
      * only GET requests are retried;
      * any ``Retry-After`` response header is ignored.
    """
    retries = Retry(
        total=None,  # set connect, read, redirect, status, other instead
        connect=3,
        read=3,
        redirect=0,
        status=3,
        status_forcelist=[429, 502, 503, 504],
        other=0,  # catch-all for other errors
        allowed_methods=["GET"],
        respect_retry_after_header=False,
        # Per urllib3's documented formula the sleep before the n-th
        # consecutive retry is backoff_factor * 2 ** (n - 1), with no sleep
        # before the first retry: 0.0s, 1.0s, 2.0s for the 3 retries here.
        backoff_factor=0.5,
        # Support manual status handling in _parse_response.
        raise_on_status=False,
    )
    session = requests.Session()
    session.mount(self.base_endpoint, HTTPAdapter(max_retries=retries))
    return session

def _authenticate_v2(self, endpoint, client_id, client_secret, audience):
basic_auth_slug = b64encode(
Expand Down Expand Up @@ -418,11 +436,14 @@ def get_annotations(
from_date=None,
to_date=None,
endpoint=None,
rec_limit=None,
company_id=None,
):
page_iter = fetch_all_pages(
name="get_annotations",
fetch_page=self._get_annotations,
get_page=lambda response: response["annotations"],
rec_limit=rec_limit,
)
annotations = []
for response in page_iter(
Expand All @@ -433,6 +454,7 @@ def get_annotations(
from_date=from_date,
to_date=to_date,
endpoint=endpoint,
company_id=company_id,
):
annotations.extend(response["annotations"])
return {"annotations": annotations}
Expand All @@ -446,6 +468,7 @@ def _get_annotations(
from_date=None,
to_date=None,
endpoint=None,
company_id=None,
offset=0,
):
"""Fetch a single page of annotations skipping the given offset number of pages first. Use get_annotations to fetch all pages."""
Expand All @@ -462,6 +485,8 @@ def _get_annotations(
filters.append("fromDate=" + from_date)
if to_date is not None:
filters.append("toDate=" + to_date)
if company_id is not None:
filters.append("companyId=" + company_id)

endpoint = urljoin(
endpoint or self.base_endpoint, ("/app/annotations?" + "&".join(filters))
Expand All @@ -471,11 +496,12 @@ def _get_annotations(
headers["Content-Type"] = "application/json"
return self._parse_response(self.requests.get(endpoint, headers=headers))

def get_annotations_for_docs(self, doc_ids, endpoint=None):
def get_annotations_for_docs(self, doc_ids, endpoint=None, rec_limit=None):
page_iter = fetch_all_pages(
name="get_annotations_for_docs",
fetch_page=self._get_annotations_for_docs,
get_page=lambda response: response["annotations"],
rec_limit=rec_limit,
)
annotations = []
for response in page_iter(
Expand Down
16 changes: 13 additions & 3 deletions sypht/util.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import logging
from typing import Any, Callable, Iterator, List

DEFAULT_REC_LIMIT = 100_000


def fetch_all_pages(
name: str,
fetch_page: Callable[..., Any],
get_page: Callable[..., List[Any]] = lambda x: x,
rec_limit=20000,
rec_limit=DEFAULT_REC_LIMIT,
) -> Callable[..., Iterator[Any]]:
"""Returns an iterator that calls fetch_page with an offset that we increment by the number of pages fetched. Stop if page returns empty list.
:param fetch_page: a function that makes an api call to fetch a page of results (using zero-based offset)
:param get_page: a function that extracts the page from the response which should be a list
"""

# Enforce a default so that the loop will stop.
if rec_limit is None:
rec_limit = DEFAULT_REC_LIMIT

def fetch_all_pages(*args, **kwargs) -> Iterator[Any]:
page_count = 0
recs = 0
Expand All @@ -31,17 +38,20 @@ def fetch_all_pages(*args, **kwargs) -> Iterator[Any]:
)
except Exception as err:
raise Exception(
f"Failed fetching for {name} for offset={page_count - 1} (records fetched so far:{recs})"
f"Failed fetching for {name} for offset={page_count - 1} (page={page_count}) (records fetched so far:{recs}). Cause: {err}"
) from err
try:
page = get_page(response)
except Exception as err:
raise Exception(
f"get_page failed to extract page from response for {name} for offset={page_count - 1} (records fetched so far:{recs})"
f"get_page failed to extract page from response for {name} for offset={page_count - 1} (page={page_count}) (records fetched so far:{recs}). Cause: {err}"
) from err
if len(page) == 0:
break
recs += len(page)
logging.info(
f"fetch_all_pages({name}): fetched page {page_count} (records={recs})"
)
yield response

return fetch_all_pages
76 changes: 75 additions & 1 deletion tests/tests_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import os
import json
import unittest
import warnings
from datetime import datetime, timedelta
from unittest.mock import Mock, patch
from uuid import UUID, uuid4

import httpretty
import pytest

from sypht.client import SyphtClient


Expand Down Expand Up @@ -93,5 +97,75 @@ def test_reauthentication(self):
self.assertFalse(self.sypht_client._is_token_expired())


class RetryTest(unittest.TestCase):
    """Exercise the client's session-level retry policy against a fake API."""

    # Endpoint and query string the client is expected to request.
    ANNOTATIONS_URL = "https://api.sypht.com/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01"

    @staticmethod
    def _first_of_january_2021():
        """Return the date used for both from_date and to_date, as YYYY-MM-DD."""
        midnight = datetime(year=2021, month=1, day=1, hour=0, minute=0, second=0)
        return midnight.strftime("%Y-%m-%d")

    @patch.object(SyphtClient, "_authenticate_v2", return_value=("access_token", 100))
    @patch.object(SyphtClient, "_authenticate_v1", return_value=("access_token2", 100))
    @httpretty.activate(verbose=True, allow_net_connect=False)
    def test_it_should_retry_n_times(self, auth_v1: Mock, auth_v2: Mock):
        # arrange
        self.count = 0

        def get_annotations(request, uri, response_headers):
            self.count += 1
            # 1 req + 3 retries = 4
            if self.count != 4:
                return [502, response_headers, json.dumps({})]
            return [200, response_headers, json.dumps({"annotations": []})]

        httpretty.register_uri(
            httpretty.GET,
            self.ANNOTATIONS_URL,
            body=get_annotations,
        )
        sypht_client = SyphtClient(base_endpoint="https://api.sypht.com")
        day = self._first_of_january_2021()

        # act
        response = sypht_client.get_annotations(from_date=day, to_date=day)

        # assert
        assert response == {"annotations": []}

    @patch.object(SyphtClient, "_authenticate_v2", return_value=("access_token", 100))
    @patch.object(SyphtClient, "_authenticate_v1", return_value=("access_token2", 100))
    @httpretty.activate(verbose=True, allow_net_connect=False)
    def test_retry_should_eventually_fail_for_50x(self, auth_v1: Mock, auth_v2: Mock):
        # arrange
        self.count = 0

        def get_annotations(request, uri, response_headers):
            # Always fail so the retry budget is exhausted.
            self.count += 1
            return [502, response_headers, json.dumps({})]

        httpretty.register_uri(
            httpretty.GET,
            self.ANNOTATIONS_URL,
            body=get_annotations,
        )
        sypht_client = SyphtClient(base_endpoint="https://api.sypht.com")
        day = self._first_of_january_2021()

        # act / assert
        with self.assertRaisesRegex(Exception, "."):
            sypht_client.get_annotations(from_date=day, to_date=day)

        assert self.count == 4, "should be 1 req + 3 retries"


# Allow running this test module directly as a script via unittest's runner.
if __name__ == "__main__":
unittest.main()
28 changes: 26 additions & 2 deletions tests/tests_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from sypht.util import fetch_all_pages
from sypht.util import DEFAULT_REC_LIMIT, fetch_all_pages


def test_fetch_all_pages_can_fetch_one_page():
Expand Down Expand Up @@ -73,6 +73,7 @@ def fetch_something(offset, pages=1):

def test_fetch_all_pages_never_ending():
"""Fail if fetch more than n pages."""

# arrange
def never_ending(*args, **kwargs):
return [0, 1, 2]
Expand All @@ -85,7 +86,30 @@ def never_ending(*args, **kwargs):
results += page

# assert
assert "more than the limit: 20000" in str(exc_info)
assert f"more than the limit: {DEFAULT_REC_LIMIT}" in str(exc_info)


def test_fetch_with_rec_limit():
    """fetch_all_pages should raise when more records than a caller-supplied rec_limit are fetched."""
    # arrange
    page_size = 5

    def fetch_something(offset, pages=1):
        """Return zero-based page `offset` of `page_size` consecutive ints, or [] past the last page."""
        if offset >= pages:
            return []
        start = offset * page_size
        return list(range(start, start + page_size))

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=fetch_something, rec_limit=2)
    results = []
    with pytest.raises(Exception) as exc_info:
        for page in page_iter():
            results += page

    # assert
    # Check the exception's own message (exc_info.value), not the text of the
    # pytest ExceptionInfo wrapper. A plain string is used because there is
    # nothing to interpolate.
    assert "fetched 5 records which is more than the limit: 2" in str(exc_info.value)


def test_fetch_all_pages_handle_error():
Expand Down

0 comments on commit 3f8d6b1

Please sign in to comment.