feat: get annotations in pages #41
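This PR adds transparent pagination to the annotation endpoints: get_annotations and get_annotations_for_docs in sypht/client.py now loop over result pages via a new fetch_all_pages helper in sypht/util.py, with unit tests in tests/tests_util.py.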

Merged · 9 commits · Aug 30, 2023
76 changes: 67 additions & 9 deletions sypht/client.py
@@ -1,12 +1,14 @@
 import json
 import os
-from typing import List, Optional
 from base64 import b64encode
 from datetime import datetime, timedelta
+from typing import List, Optional
 from urllib.parse import quote_plus, urlencode, urljoin

 import requests

+from .util import fetch_all_pages
+
 SYPHT_API_BASE_ENDPOINT = "https://api.sypht.com"
 SYPHT_AUTH_ENDPOINT = "https://auth.sypht.com/oauth2/token"
 SYPHT_LEGACY_AUTH_ENDPOINT = "https://login.sypht.com/oauth/token"
@@ -379,7 +381,9 @@ def get_file_data(self, file_id, endpoint=None, headers=None):

        return response.content

-    def fetch_results(self, file_id, timeout=None, endpoint=None, verbose=False, headers=None):
+    def fetch_results(
+        self, file_id, timeout=None, endpoint=None, verbose=False, headers=None
+    ):
        """
        :param file_id: the id of the document that was uploaded and extracted
        :param timeout: a timeout in milliseconds to wait for the results
@@ -415,7 +419,37 @@ def get_annotations(
        to_date=None,
        endpoint=None,
    ):
-        filters = []
+        page_iter = fetch_all_pages(
+            name="get_annotations",
+            fetch_page=self._get_annotations,
+            get_page=lambda response: response["annotations"],
+        )
+        annotations = []
+        for response in page_iter(
+            doc_id=doc_id,
+            task_id=task_id,
+            user_id=user_id,
+            specification=specification,
+            from_date=from_date,
+            to_date=to_date,
+            endpoint=endpoint,
+        ):
+            annotations.extend(response["annotations"])
+        return {"annotations": annotations}
+
+    def _get_annotations(
+        self,
+        doc_id=None,
+        task_id=None,
+        user_id=None,
+        specification=None,
+        from_date=None,
+        to_date=None,
+        endpoint=None,
+        offset=0,
+    ):
+        """Fetch a single page of annotations, skipping the first offset pages (offset is a zero-based page index). Use get_annotations to fetch all pages."""
+        filters = ["offset=" + str(offset)]
        if doc_id is not None:
            filters.append("docId=" + doc_id)
        if task_id is not None:
@@ -438,7 +472,22 @@ def get_annotations(
        return self._parse_response(self.requests.get(endpoint, headers=headers))

    def get_annotations_for_docs(self, doc_ids, endpoint=None):
-        body = json.dumps({"docIds": doc_ids})
+        page_iter = fetch_all_pages(
+            name="get_annotations_for_docs",
+            fetch_page=self._get_annotations_for_docs,
+            get_page=lambda response: response["annotations"],
+        )
+        annotations = []
+        for response in page_iter(
+            doc_ids=doc_ids,
+            endpoint=endpoint,
+        ):
+            annotations.extend(response["annotations"])
+        return {"annotations": annotations}
+
+    def _get_annotations_for_docs(self, doc_ids, endpoint=None, offset=0):
+        """Fetch a single page of annotations, skipping the first offset pages (offset is a zero-based page index). Use get_annotations_for_docs to fetch all pages."""
+        body = json.dumps({"docIds": doc_ids, "offset": offset})
        endpoint = urljoin(endpoint or self.base_endpoint, ("/app/annotations/search"))
        headers = self._get_headers()
        headers["Accept"] = "application/json"
@@ -814,7 +863,13 @@ def submit_task(
            self.requests.post(endpoint, data=json.dumps(task), headers=headers)
        )

-    def add_tags_to_tasks(self, task_ids: List[str], tags: List[str], company_id: Optional[str]=None, endpoint: Optional[str]=None):
+    def add_tags_to_tasks(
+        self,
+        task_ids: List[str],
+        tags: List[str],
+        company_id: Optional[str] = None,
+        endpoint: Optional[str] = None,
+    ):
        company_id = company_id or self.company_id
        endpoint = urljoin(
            endpoint or self.base_endpoint,
@@ -825,12 +880,15 @@ def add_tags_to_tasks(
        headers["Content-Type"] = "application/json"
        data = {"taskIds": task_ids, "add": tags, "remove": []}
        return self._parse_response(
-            self.requests.post(
-                endpoint, data=json.dumps(data), headers=headers
-            )
+            self.requests.post(endpoint, data=json.dumps(data), headers=headers)
        )

-    def get_tags_for_task(self, task_id: str, company_id: Optional[str]=None, endpoint: Optional[str]=None):
+    def get_tags_for_task(
+        self,
+        task_id: str,
+        company_id: Optional[str] = None,
+        endpoint: Optional[str] = None,
+    ):
        company_id = company_id or self.company_id
        endpoint = urljoin(
            endpoint or self.base_endpoint,
47 changes: 47 additions & 0 deletions sypht/util.py
@@ -0,0 +1,47 @@
from typing import Any, Callable, Iterator, List


def fetch_all_pages(
    name: str,
    fetch_page: Callable[..., Any],
    get_page: Callable[..., List[Any]] = lambda x: x,
    rec_limit=20000,
) -> Callable[..., Iterator[Any]]:
    """Return an iterator that calls fetch_page with a zero-based page offset, incrementing it once per page fetched, and stopping when a page comes back empty.

    :param fetch_page: a function that makes an api call to fetch a page of results (using a zero-based page offset)
    :param get_page: a function that extracts the page from the response, which should be a list
    """

    def fetch_all_pages(*args, **kwargs) -> Iterator[Any]:
        page_count = 0
        recs = 0
        while True:
            page_count += 1
            if recs > rec_limit:
                # Don't want to DoS ourselves...
                raise Exception(
                    f"fetch_all_pages({name}): fetched {recs} records which is more than the limit: {rec_limit}. Consider adding or adjusting a filter to reduce the total number of items fetched."
                )
            try:
                response = fetch_page(
                    *args,
                    **kwargs,
                    offset=page_count - 1,
                )
            except Exception as err:
                raise Exception(
                    f"Failed fetching for {name} for offset={page_count - 1} (records fetched so far: {recs})"
                ) from err
            try:
                page = get_page(response)
            except Exception as err:
                raise Exception(
                    f"get_page failed to extract page from response for {name} for offset={page_count - 1} (records fetched so far: {recs})"
                ) from err
            if len(page) == 0:
                break
            recs += len(page)
            yield response

    return fetch_all_pages
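A subtlety worth noting for reviewers: offset here is a zero-based page index, not a record offset — fetch_page is called with offset=0, 1, 2, … and iteration stops at the first empty page (or raises once more than rec_limit records have been fetched). A self-contained sketch against a hypothetical in-memory source:

from sypht.util import fetch_all_pages

def fetch_numbers(offset, page_size=3):
    # Hypothetical data source: 7 records served in pages of 3.
    start = offset * page_size  # offset counts pages, not records
    data = list(range(7))
    return data[start : start + page_size]

page_iter = fetch_all_pages(name="numbers", fetch_page=fetch_numbers)
all_numbers = []
for page in page_iter():
    all_numbers += page  # default get_page returns the response unchanged

assert all_numbers == [0, 1, 2, 3, 4, 5, 6]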
105 changes: 105 additions & 0 deletions tests/tests_util.py
@@ -0,0 +1,105 @@
import pytest

from sypht.util import fetch_all_pages


def test_fetch_all_pages_can_fetch_one_page():
    # arrange
    page_size = 5

    def fetch_something(offset, pages=1):
        pages0 = pages - 1
        if offset > pages0:
            return []
        start = offset * page_size
        page = range(start, start + page_size)
        return list(page)

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=fetch_something)
    results = []
    for page in page_iter(pages=1):
        results += page

    # assert
    assert results == [0, 1, 2, 3, 4]


def test_fetch_all_pages_can_fetch_one_page_with_get_page():
    # arrange
    page_size = 5

    def fetch_something(offset, pages=1):
        pages0 = pages - 1
        if offset > pages0:
            return {"results": []}
        start = offset * page_size
        page = range(start, start + page_size)
        return {"results": list(page)}

    # act
    page_iter = fetch_all_pages(
        name="test1", fetch_page=fetch_something, get_page=lambda resp: resp["results"]
    )
    results = []
    for resp in page_iter(pages=1):
        results += resp["results"]

    # assert
    assert results == [0, 1, 2, 3, 4]


def test_fetch_all_pages_can_fetch_several_pages():
    # arrange
    page_size = 5

    def fetch_something(offset, pages=1):
        pages0 = pages - 1
        if offset > pages0:
            return []
        start = offset * page_size
        page = range(start, start + page_size)
        return list(page)

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=fetch_something)
    results = []
    for page in page_iter(pages=2):
        results += page

    # assert
    assert results == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


def test_fetch_all_pages_never_ending():
    """Fail if we fetch more records than the limit allows."""
    # arrange
    def never_ending(*args, **kwargs):
        return [0, 1, 2]

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=never_ending)
    results = []
    with pytest.raises(Exception) as exc_info:
        for page in page_iter():
            results += page

    # assert
    assert "more than the limit: 20000" in str(exc_info)


def test_fetch_all_pages_handle_error():
    # arrange
    def failing(*args, **kwargs):
        raise Exception("fetch error")

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=failing)
    results = []
    with pytest.raises(Exception) as exc_info:
        for page in page_iter():
            results += page

    # assert
    assert "fetch error" in str(exc_info.value.__cause__)
    assert "Failed fetching for test1" in str(exc_info)