From 6a0067fbbd800642e5f595b2c02d8be535ff9630 Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Fri, 2 Feb 2024 09:23:50 -0800 Subject: [PATCH 1/4] feat: Add `GET /documents/{id}/status` This returns just the ingestion status information. The intention is to use this for cases where a client wants to wait for ingestion, since it won't include the full document-text, etc. --- .../api/default/get_document_status.py | 155 ++++++++++++++++++ dewy-client/dewy_client/models/__init__.py | 2 + dewy-client/dewy_client/models/document.py | 2 +- .../dewy_client/models/document_status.py | 90 ++++++++++ dewy/common/db_migration.py | 2 +- dewy/document/models.py | 8 +- dewy/document/router.py | 18 +- dewy/migrations/0001_schema.sql | 2 +- openapi.yaml | 46 +++++- pyproject.toml | 2 +- tests/test_e2e.py | 17 +- 11 files changed, 333 insertions(+), 11 deletions(-) create mode 100644 dewy-client/dewy_client/api/default/get_document_status.py create mode 100644 dewy-client/dewy_client/models/document_status.py diff --git a/dewy-client/dewy_client/api/default/get_document_status.py b/dewy-client/dewy_client/api/default/get_document_status.py new file mode 100644 index 0000000..f38f12d --- /dev/null +++ b/dewy-client/dewy_client/api/default/get_document_status.py @@ -0,0 +1,155 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.document_status import DocumentStatus +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + id: int, +) -> Dict[str, Any]: + _kwargs: Dict[str, Any] = { + "method": "get", + "url": f"/api/documents/{id}/status", + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[DocumentStatus, HTTPValidationError]]: + if response.status_code == HTTPStatus.OK: + response_200 = DocumentStatus.from_dict(response.json()) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[DocumentStatus, HTTPValidationError]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[DocumentStatus, HTTPValidationError]]: + """Get Document Status + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[DocumentStatus, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[DocumentStatus, HTTPValidationError]]: + """Get Document Status + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[DocumentStatus, HTTPValidationError] + """ + + return sync_detailed( + id=id, + client=client, + ).parsed + + +async def asyncio_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[DocumentStatus, HTTPValidationError]]: + """Get Document Status + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[DocumentStatus, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[DocumentStatus, HTTPValidationError]]: + """Get Document Status + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[DocumentStatus, HTTPValidationError] + """ + + return ( + await asyncio_detailed( + id=id, + client=client, + ) + ).parsed diff --git a/dewy-client/dewy_client/models/__init__.py b/dewy-client/dewy_client/models/__init__.py index 9e17403..6cb4226 100644 --- a/dewy-client/dewy_client/models/__init__.py +++ b/dewy-client/dewy_client/models/__init__.py @@ -5,6 +5,7 @@ from .collection_create import CollectionCreate from .distance_metric import DistanceMetric from .document import Document +from .document_status import DocumentStatus from .http_validation_error import HTTPValidationError from .image_chunk import ImageChunk from .image_result import ImageResult @@ -21,6 +22,7 @@ "CollectionCreate", "DistanceMetric", "Document", + "DocumentStatus", "HTTPValidationError", "ImageChunk", "ImageResult", diff --git a/dewy-client/dewy_client/models/document.py b/dewy-client/dewy_client/models/document.py index 78c5bce..d7fbcc8 100644 --- a/dewy-client/dewy_client/models/document.py +++ b/dewy-client/dewy_client/models/document.py @@ -11,7 +11,7 @@ @_attrs_define class Document: - """Schema for documents in the SQL DB. + """Model for documents in Dewy. Attributes: collection_id (int): diff --git a/dewy-client/dewy_client/models/document_status.py b/dewy-client/dewy_client/models/document_status.py new file mode 100644 index 0000000..8357800 --- /dev/null +++ b/dewy-client/dewy_client/models/document_status.py @@ -0,0 +1,90 @@ +from typing import Any, Dict, List, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..models.ingest_state import IngestState +from ..types import UNSET, Unset + +T = TypeVar("T", bound="DocumentStatus") + + +@_attrs_define +class DocumentStatus: + """Model for document status. + + Attributes: + id (int): + ingest_state (IngestState): + ingest_error (Union[None, Unset, str]): + """ + + id: int + ingest_state: IngestState + ingest_error: Union[None, Unset, str] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + id = self.id + + ingest_state = self.ingest_state.value + + ingest_error: Union[None, Unset, str] + if isinstance(self.ingest_error, Unset): + ingest_error = UNSET + else: + ingest_error = self.ingest_error + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "id": id, + "ingest_state": ingest_state, + } + ) + if ingest_error is not UNSET: + field_dict["ingest_error"] = ingest_error + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + id = d.pop("id") + + ingest_state = IngestState(d.pop("ingest_state")) + + def _parse_ingest_error(data: object) -> Union[None, Unset, str]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, str], data) + + ingest_error = _parse_ingest_error(d.pop("ingest_error", UNSET)) + + document_status = cls( + id=id, + ingest_state=ingest_state, + ingest_error=ingest_error, + ) + + document_status.additional_properties = d + return document_status + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy/common/db_migration.py b/dewy/common/db_migration.py index 90336da..65ef448 100644 --- a/dewy/common/db_migration.py +++ b/dewy/common/db_migration.py @@ -38,7 +38,7 @@ async def apply_migrations( applied += 1 if applied_migrations: - logger.warn("Unrecognized migrations applied: {}", applied_migrations) + logger.warning("Unrecognized migrations applied: {}", applied_migrations) logger.info( "Migrations complete. {} total, {} newly applied", diff --git a/dewy/document/models.py b/dewy/document/models.py index 603cf67..92ac032 100644 --- a/dewy/document/models.py +++ b/dewy/document/models.py @@ -24,7 +24,7 @@ class IngestState(Enum): class Document(BaseModel): - """Schema for documents in the SQL DB.""" + """Model for documents in Dewy.""" id: Optional[int] = None collection_id: int @@ -41,3 +41,9 @@ class Document(BaseModel): ingest_state: Optional[IngestState] = None ingest_error: Optional[str] = None + +class DocumentStatus(BaseModel): + """Model for document status.""" + id: int + ingest_state: IngestState + ingest_error: Optional[str] = None \ No newline at end of file diff --git a/dewy/document/router.py b/dewy/document/router.py index f3c5e54..4fff54c 100644 --- a/dewy/document/router.py +++ b/dewy/document/router.py @@ -8,7 +8,7 @@ from dewy.common.db import PgConnectionDep, PgPoolDep from dewy.document.models import Document -from .models import AddDocumentRequest +from .models import AddDocumentRequest, DocumentStatus, IngestState router = APIRouter(prefix="/documents") @@ -125,3 +125,19 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document: id, ) return Document.model_validate(dict(result)) + +@router.get("/{id}/status") +async def get_document_status(conn: PgConnectionDep, id: PathDocumentId) -> DocumentStatus: + result = await conn.fetchrow( + """ + SELECT ingest_state, ingest_error + FROM document + WHERE id = $1 + """, + id + ) + return DocumentStatus( + id=id, + ingest_state = result["ingest_state"], + ingest_error = result["ingest_error"] + ) \ No newline at end of file diff --git a/dewy/migrations/0001_schema.sql b/dewy/migrations/0001_schema.sql index b818749..82af2b9 100644 --- a/dewy/migrations/0001_schema.sql +++ b/dewy/migrations/0001_schema.sql @@ -36,7 +36,7 @@ CREATE TABLE document( -- The state of the most recent ingestion of this document. -- TODO: Should we have a separate `ingestion` table and associate -- many ingestions with each document ID? - ingest_state ingest_state, + ingest_state ingest_state NOT NULL, -- The error (if any) resulting from the most recent ingestion. ingest_error VARCHAR, diff --git a/openapi.yaml b/openapi.yaml index dbba7d8..1d763a9 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -176,6 +176,32 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /api/documents/{id}/status: + get: + summary: Get Document Status + operationId: getDocumentStatus + parameters: + - name: id + in: path + required: true + schema: + type: integer + description: The document ID. + title: Id + description: The document ID. + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DocumentStatus' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /api/chunks/: get: summary: List Chunks @@ -400,7 +426,25 @@ components: - collection_id - url title: Document - description: Schema for documents in the SQL DB. + description: Model for documents in Dewy. + DocumentStatus: + properties: + id: + type: integer + title: Id + ingest_state: + $ref: '#/components/schemas/IngestState' + ingest_error: + anyOf: + - type: string + - type: 'null' + title: Ingest Error + type: object + required: + - id + - ingest_state + title: DocumentStatus + description: Model for document status. HTTPValidationError: properties: detail: diff --git a/pyproject.toml b/pyproject.toml index dd60b92..ca63953 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,4 +119,4 @@ show_error_context = true [tool.pytest.ini_options] testpaths = ["tests"] asyncio_mode = "auto" -timeout = 5 \ No newline at end of file +timeout = 10 \ No newline at end of file diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 99b0a2c..cb68eef 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -7,6 +7,7 @@ add_collection, add_document, get_document, + get_document_status, list_chunks, retrieve_chunks, ) @@ -39,9 +40,12 @@ async def test_index_retrieval(client, embedding_model): ), ) - while document.ingest_state != IngestState.INGESTED: + status = None + while getattr(status, "ingest_state", IngestState.PENDING) == IngestState.PENDING: time.sleep(0.5) - document = await get_document.asyncio(document.id, client=client) + status = await get_document_status.asyncio(document.id, client=client) + + document = await get_document.asyncio(document.id, client=client) assert document.extracted_text.startswith("Skeleton-of-Thought") chunks = await list_chunks.asyncio( @@ -77,10 +81,15 @@ async def test_ingest_error(client): body=AddDocumentRequest(url=f"error://{MESSAGE}", collection_id=collection.id), ) - while document.ingest_state == IngestState.PENDING: + status = None + while getattr(status, "ingest_state", IngestState.PENDING) == IngestState.PENDING: time.sleep(0.2) - document = await get_document.asyncio(document.id, client=client) + status = await get_document_status.asyncio(document.id, client=client) + + assert status.ingest_state == IngestState.FAILED + assert status.ingest_error == MESSAGE + document = await get_document.asyncio(document.id, client=client) assert document.ingest_state == IngestState.FAILED assert document.ingest_error == MESSAGE From f769d0c830aae27eef41fd8e0303f1a0cd152131 Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Sat, 3 Feb 2024 06:35:03 -0800 Subject: [PATCH 2/4] fixes --- dewy/document/models.py | 4 +++- dewy/document/router.py | 15 ++++++++------- pyproject.toml | 4 +++- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/dewy/document/models.py b/dewy/document/models.py index 92ac032..edd1f05 100644 --- a/dewy/document/models.py +++ b/dewy/document/models.py @@ -42,8 +42,10 @@ class Document(BaseModel): ingest_state: Optional[IngestState] = None ingest_error: Optional[str] = None + class DocumentStatus(BaseModel): """Model for document status.""" + id: int ingest_state: IngestState - ingest_error: Optional[str] = None \ No newline at end of file + ingest_error: Optional[str] = None diff --git a/dewy/document/router.py b/dewy/document/router.py index 4fff54c..2b0ae62 100644 --- a/dewy/document/router.py +++ b/dewy/document/router.py @@ -8,7 +8,7 @@ from dewy.common.db import PgConnectionDep, PgPoolDep from dewy.document.models import Document -from .models import AddDocumentRequest, DocumentStatus, IngestState +from .models import AddDocumentRequest, DocumentStatus router = APIRouter(prefix="/documents") @@ -126,18 +126,19 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document: ) return Document.model_validate(dict(result)) + @router.get("/{id}/status") -async def get_document_status(conn: PgConnectionDep, id: PathDocumentId) -> DocumentStatus: +async def get_document_status( + conn: PgConnectionDep, id: PathDocumentId +) -> DocumentStatus: result = await conn.fetchrow( """ SELECT ingest_state, ingest_error FROM document WHERE id = $1 """, - id + id, ) return DocumentStatus( - id=id, - ingest_state = result["ingest_state"], - ingest_error = result["ingest_error"] - ) \ No newline at end of file + id=id, ingest_state=result["ingest_state"], ingest_error=result["ingest_error"] + ) diff --git a/pyproject.toml b/pyproject.toml index ca63953..4f5d003 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,4 +119,6 @@ show_error_context = true [tool.pytest.ini_options] testpaths = ["tests"] asyncio_mode = "auto" -timeout = 10 \ No newline at end of file +# 10s seems to be enough locally, but CI is slower. +# set it high enough for both +timeout = 30 \ No newline at end of file From 84bb171f395881ff633d36f900f4e854ec296f6a Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Sat, 3 Feb 2024 08:49:08 -0800 Subject: [PATCH 3/4] even higher timeout --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4f5d003..1e30c95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -121,4 +121,4 @@ testpaths = ["tests"] asyncio_mode = "auto" # 10s seems to be enough locally, but CI is slower. # set it high enough for both -timeout = 30 \ No newline at end of file +timeout = 60 \ No newline at end of file From b4f89471eaa5e2867a8f1271402a421d2a831b94 Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Sat, 3 Feb 2024 10:35:18 -0800 Subject: [PATCH 4/4] tweak timeouts further --- pyproject.toml | 4 +--- tests/test_e2e.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1e30c95..ca63953 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,4 @@ show_error_context = true [tool.pytest.ini_options] testpaths = ["tests"] asyncio_mode = "auto" -# 10s seems to be enough locally, but CI is slower. -# set it high enough for both -timeout = 60 \ No newline at end of file +timeout = 10 \ No newline at end of file diff --git a/tests/test_e2e.py b/tests/test_e2e.py index cb68eef..cf04e7e 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize( "embedding_model", ["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"] ) -@pytest.mark.timeout(60) # slow due to embedding +@pytest.mark.timeout(120) # slow due to embedding (especially in CI) async def test_index_retrieval(client, embedding_model): name = "".join(random.choices(string.ascii_lowercase, k=5))