feat: get annotations in pages #41

Merged: 9 commits, Aug 30, 2023
Changes from 2 commits
78 changes: 69 additions & 9 deletions sypht/client.py
@@ -1,12 +1,14 @@
import json
import os
from base64 import b64encode
from datetime import datetime, timedelta
from typing import List, Optional
from urllib.parse import quote_plus, urlencode, urljoin

import requests

from .util import fetch_all_pages

SYPHT_API_BASE_ENDPOINT = "https://api.sypht.com"
SYPHT_AUTH_ENDPOINT = "https://auth.sypht.com/oauth2/token"
SYPHT_LEGACY_AUTH_ENDPOINT = "https://login.sypht.com/oauth/token"
@@ -379,7 +381,9 @@ def get_file_data(self, file_id, endpoint=None, headers=None):

        return response.content

    def fetch_results(
        self, file_id, timeout=None, endpoint=None, verbose=False, headers=None
    ):
        """
        :param file_id: the id of the document that was uploaded and extracted
        :param timeout: a timeout in milliseconds to wait for the results
@@ -415,7 +419,38 @@ def get_annotations(
        to_date=None,
        endpoint=None,
    ):
        page_iter = fetch_all_pages(
            name="get_annotations",
            fetch_page=lambda *args, **kwargs: self._get_annotations(*args, **kwargs)[
                "annotations"
            ],
        )
        annotations = []
        for page in page_iter(
            doc_id=doc_id,
            task_id=task_id,
            user_id=user_id,
            specification=specification,
            from_date=from_date,
            to_date=to_date,
            endpoint=endpoint,
        ):
            annotations.extend(page)
        return {"annotations": annotations}

    def _get_annotations(
        self,
        doc_id=None,
        task_id=None,
        user_id=None,
        specification=None,
        from_date=None,
        to_date=None,
        endpoint=None,
        offset=0,
    ):
        """Fetch a single page of annotations skipping the given offset number of annotations first. Use get_annotations to fetch all pages."""
        filters = ["offset=" + str(offset)]
        if doc_id is not None:
            filters.append("docId=" + doc_id)
        if task_id is not None:
@@ -438,7 +473,23 @@ def get_annotations(
        return self._parse_response(self.requests.get(endpoint, headers=headers))

    def get_annotations_for_docs(self, doc_ids, endpoint=None):
        page_iter = fetch_all_pages(
            name="get_annotations_for_docs",
            fetch_page=lambda *args, **kwargs: self._get_annotations_for_docs(
                *args, **kwargs
            )["annotations"],
        )
        annotations = []
        for page in page_iter(
            doc_ids=doc_ids,
            endpoint=endpoint,
        ):
            annotations.extend(page)
        return {"annotations": annotations}

    def _get_annotations_for_docs(self, doc_ids, endpoint=None, offset=0):
        """Fetch a single page of annotations skipping the given offset number of annotations first. Use get_annotations_for_docs to fetch all pages."""
        body = json.dumps({"docIds": doc_ids, "offset": offset})
        endpoint = urljoin(endpoint or self.base_endpoint, ("/app/annotations/search"))
        headers = self._get_headers()
        headers["Accept"] = "application/json"
@@ -814,7 +865,13 @@ def submit_task(
            self.requests.post(endpoint, data=json.dumps(task), headers=headers)
        )

    def add_tags_to_tasks(
        self,
        task_ids: List[str],
        tags: List[str],
        company_id: Optional[str] = None,
        endpoint: Optional[str] = None,
    ):
        company_id = company_id or self.company_id
        endpoint = urljoin(
            endpoint or self.base_endpoint,
@@ -825,12 +882,15 @@ def add_tags_to_tasks(self, task_ids: List[str], tags: List[str], company_id: Op
headers["Content-Type"] = "application/json"
data = {"taskIds": task_ids, "add": tags, "remove": []}
return self._parse_response(
self.requests.post(
endpoint, data=json.dumps(data), headers=headers
)
self.requests.post(endpoint, data=json.dumps(data), headers=headers)
)

def get_tags_for_task(self, task_id: str, company_id: Optional[str]=None, endpoint: Optional[str]=None):
def get_tags_for_task(
self,
task_id: str,
company_id: Optional[str] = None,
endpoint: Optional[str] = None,
):
company_id = company_id or self.company_id
endpoint = urljoin(
endpoint or self.base_endpoint,
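
With this change, get_annotations and get_annotations_for_docs keep their original signatures but walk every page of results before returning, so existing callers get the combined list without any extra code. A minimal usage sketch follows; it assumes a SyphtClient that picks up credentials from the environment, and the date format passed to from_date/to_date is an assumption, not something this diff defines:

from sypht.client import SyphtClient

# Hypothetical setup: credentials are assumed to come from the environment.
client = SyphtClient()

# Internally this calls _get_annotations repeatedly via fetch_all_pages and
# concatenates every page into a single "annotations" list.
result = client.get_annotations(
    from_date="2023-08-01",  # date format assumed for illustration only
    to_date="2023-08-30",
)
print(len(result["annotations"]))

Note that all pages are accumulated in memory before the dict is returned; the page_limit guard in fetch_all_pages (1000 pages by default) is the only backstop against unbounded pagination.
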
35 changes: 35 additions & 0 deletions sypht/util.py
@@ -0,0 +1,35 @@
from typing import Any, Callable, Iterator, List


def fetch_all_pages(
    name: str, fetch_page: Callable[..., List[Any]], page_limit=1000
) -> Callable[..., Iterator[List[Any]]]:
    """Returns an iterator that calls fetch_page with an offset that we increment by the number of records returned from the last call to fetch_page. Stop if page returns empty list."""

    def fetch_all_pages(*args, **kwargs) -> Iterator[List[Any]]:
        offset = 0
        page_count = 0
        while True:
            page_count += 1
            if page_count > page_limit:
                # Don't want to DOS ourselves...
                raise Exception(
                    f"fetch_all_pages({name}): fetched more than {page_limit} pages - you sure this thing is gonna stop? Consider using a date range to reduce the number of pages fetched."
                )
            try:
                result = fetch_page(
                    *args,
                    **kwargs,
                    offset=offset,
                )
            except Exception as err:
                raise Exception(
                    f"Failed fetching for {name} for page={page_count} offset={offset}"
                ) from err
            if not result:
                break
            offset += len(result)
            yield result
        return None

    return fetch_all_pages
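
The contract fetch_all_pages expects from fetch_page is the key detail: it must accept an offset keyword argument and return a list-like page, with an empty page signalling the end. A small self-contained sketch against a fake in-memory source (the DATA list and page_size parameter are illustrative only):

from sypht.util import fetch_all_pages

DATA = list(range(10))  # stand-in for records living behind a paged API

def fetch_page(offset=0, page_size=4):
    # Return up to page_size records starting at offset; [] once exhausted.
    return DATA[offset : offset + page_size]

page_iter = fetch_all_pages(name="demo", fetch_page=fetch_page)
collected = []
for page in page_iter(page_size=4):
    collected += page

assert collected == DATA  # pages of 4, 4, 2, then [] stops the loop

Because the offset advances by len(result) rather than a fixed page size, short pages from the server do not cause records to be skipped or double-counted.
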
77 changes: 77 additions & 0 deletions tests/tests_util.py
@@ -0,0 +1,77 @@
import pytest

from sypht.util import fetch_all_pages


def test_fetch_all_pages_can_fetch_one_page():
    # arrange
    num_pages = 1

    def fetch_something(offset, n=0):
        result = range(offset + n, offset + n + 3)
        if offset > (len(result) * num_pages) - 1:
            return []
        return result

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=fetch_something)
    results = []
    for page in page_iter(n=10):
        results += page

    # assert
    assert results == [10, 11, 12]


def test_fetch_all_pages_can_fetch_several_pages():
    # arrange
    num_pages = 3

    def fetch_something(offset, n=0):
        result = range(offset + n, offset + n + 3)
        if offset > (len(result) * num_pages) - 1:
            return []
        return result

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=fetch_something)
    results = []
    for page in page_iter(n=2):
        results += page

    # assert
    assert results == [2, 3, 4, 5, 6, 7, 8, 9, 10]


def test_fetch_all_pages_never_ending():
    """Fail if fetch more than n pages."""
    # arrange
    def never_ending(*args, **kwargs):
        return [0, 1, 2]

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=never_ending)
    results = []
    with pytest.raises(Exception) as exc_info:
        for page in page_iter(n=2):
            results += page

    # assert
    assert "fetched more than 1000 pages" in str(exc_info)


def test_fetch_all_pages_handle_error():
    # arrange
    def failing(*args, **kwargs):
        raise Exception("fetch error")

    # act
    page_iter = fetch_all_pages(name="test1", fetch_page=failing)
    results = []
    with pytest.raises(Exception) as exc_info:
        for page in page_iter(n=2):
            results += page

    # assert
    assert "fetch error" in str(exc_info.value.__cause__)
    assert "Failed fetching for test1" in str(exc_info)