diff --git a/dewy-client/dewy_client/api/default/get_document_status.py b/dewy-client/dewy_client/api/default/get_document_status.py
new file mode 100644
index 0000000..f38f12d
--- /dev/null
+++ b/dewy-client/dewy_client/api/default/get_document_status.py
@@ -0,0 +1,155 @@
+from http import HTTPStatus
+from typing import Any, Dict, Optional, Union
+
+import httpx
+
+from ... import errors
+from ...client import AuthenticatedClient, Client
+from ...models.document_status import DocumentStatus
+from ...models.http_validation_error import HTTPValidationError
+from ...types import Response
+
+
+def _get_kwargs(
+    id: int,
+) -> Dict[str, Any]:
+    _kwargs: Dict[str, Any] = {
+        "method": "get",
+        "url": f"/api/documents/{id}/status",
+    }
+
+    return _kwargs
+
+
+def _parse_response(
+    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
+) -> Optional[Union[DocumentStatus, HTTPValidationError]]:
+    if response.status_code == HTTPStatus.OK:
+        response_200 = DocumentStatus.from_dict(response.json())
+
+        return response_200
+    if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY:
+        response_422 = HTTPValidationError.from_dict(response.json())
+
+        return response_422
+    if client.raise_on_unexpected_status:
+        raise errors.UnexpectedStatus(response.status_code, response.content)
+    else:
+        return None
+
+
+def _build_response(
+    *, client: Union[AuthenticatedClient, Client], response: httpx.Response
+) -> Response[Union[DocumentStatus, HTTPValidationError]]:
+    return Response(
+        status_code=HTTPStatus(response.status_code),
+        content=response.content,
+        headers=response.headers,
+        parsed=_parse_response(client=client, response=response),
+    )
+
+
+def sync_detailed(
+    id: int,
+    *,
+    client: Union[AuthenticatedClient, Client],
+) -> Response[Union[DocumentStatus, HTTPValidationError]]:
+    """Get Document Status
+
+    Args:
+        id (int): The document ID.
+
+    Raises:
+        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
+        httpx.TimeoutException: If the request takes longer than Client.timeout.
+
+    Returns:
+        Response[Union[DocumentStatus, HTTPValidationError]]
+    """
+
+    kwargs = _get_kwargs(
+        id=id,
+    )
+
+    response = client.get_httpx_client().request(
+        **kwargs,
+    )
+
+    return _build_response(client=client, response=response)
+
+
+def sync(
+    id: int,
+    *,
+    client: Union[AuthenticatedClient, Client],
+) -> Optional[Union[DocumentStatus, HTTPValidationError]]:
+    """Get Document Status
+
+    Args:
+        id (int): The document ID.
+
+    Raises:
+        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
+        httpx.TimeoutException: If the request takes longer than Client.timeout.
+
+    Returns:
+        Union[DocumentStatus, HTTPValidationError]
+    """
+
+    return sync_detailed(
+        id=id,
+        client=client,
+    ).parsed
+
+
+async def asyncio_detailed(
+    id: int,
+    *,
+    client: Union[AuthenticatedClient, Client],
+) -> Response[Union[DocumentStatus, HTTPValidationError]]:
+    """Get Document Status
+
+    Args:
+        id (int): The document ID.
+
+    Raises:
+        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
+        httpx.TimeoutException: If the request takes longer than Client.timeout.
+
+    Returns:
+        Response[Union[DocumentStatus, HTTPValidationError]]
+    """
+
+    kwargs = _get_kwargs(
+        id=id,
+    )
+
+    response = await client.get_async_httpx_client().request(**kwargs)
+
+    return _build_response(client=client, response=response)
+
+
+async def asyncio(
+    id: int,
+    *,
+    client: Union[AuthenticatedClient, Client],
+) -> Optional[Union[DocumentStatus, HTTPValidationError]]:
+    """Get Document Status
+
+    Args:
+        id (int): The document ID.
+
+    Raises:
+        errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True.
+        httpx.TimeoutException: If the request takes longer than Client.timeout.
+
+    Returns:
+        Union[DocumentStatus, HTTPValidationError]
+    """
+
+    return (
+        await asyncio_detailed(
+            id=id,
+            client=client,
+        )
+    ).parsed
diff --git a/dewy-client/dewy_client/models/__init__.py b/dewy-client/dewy_client/models/__init__.py
index 9e17403..6cb4226 100644
--- a/dewy-client/dewy_client/models/__init__.py
+++ b/dewy-client/dewy_client/models/__init__.py
@@ -5,6 +5,7 @@
 from .collection_create import CollectionCreate
 from .distance_metric import DistanceMetric
 from .document import Document
+from .document_status import DocumentStatus
 from .http_validation_error import HTTPValidationError
 from .image_chunk import ImageChunk
 from .image_result import ImageResult
@@ -21,6 +22,7 @@
     "CollectionCreate",
     "DistanceMetric",
     "Document",
+    "DocumentStatus",
     "HTTPValidationError",
     "ImageChunk",
     "ImageResult",
diff --git a/dewy-client/dewy_client/models/document.py b/dewy-client/dewy_client/models/document.py
index 78c5bce..d7fbcc8 100644
--- a/dewy-client/dewy_client/models/document.py
+++ b/dewy-client/dewy_client/models/document.py
@@ -11,7 +11,7 @@
 
 @_attrs_define
 class Document:
-    """Schema for documents in the SQL DB.
+    """Model for documents in Dewy.
 
     Attributes:
         collection_id (int):
diff --git a/dewy-client/dewy_client/models/document_status.py b/dewy-client/dewy_client/models/document_status.py
new file mode 100644
index 0000000..8357800
--- /dev/null
+++ b/dewy-client/dewy_client/models/document_status.py
@@ -0,0 +1,90 @@
+from typing import Any, Dict, List, Type, TypeVar, Union, cast
+
+from attrs import define as _attrs_define
+from attrs import field as _attrs_field
+
+from ..models.ingest_state import IngestState
+from ..types import UNSET, Unset
+
+T = TypeVar("T", bound="DocumentStatus")
+
+
+@_attrs_define
+class DocumentStatus:
+    """Model for document status.
+
+    Attributes:
+        id (int):
+        ingest_state (IngestState):
+        ingest_error (Union[None, Unset, str]):
+    """
+
+    id: int
+    ingest_state: IngestState
+    ingest_error: Union[None, Unset, str] = UNSET
+    additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        id = self.id
+
+        ingest_state = self.ingest_state.value
+
+        ingest_error: Union[None, Unset, str]
+        if isinstance(self.ingest_error, Unset):
+            ingest_error = UNSET
+        else:
+            ingest_error = self.ingest_error
+
+        field_dict: Dict[str, Any] = {}
+        field_dict.update(self.additional_properties)
+        field_dict.update(
+            {
+                "id": id,
+                "ingest_state": ingest_state,
+            }
+        )
+        if ingest_error is not UNSET:
+            field_dict["ingest_error"] = ingest_error
+
+        return field_dict
+
+    @classmethod
+    def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
+        d = src_dict.copy()
+        id = d.pop("id")
+
+        ingest_state = IngestState(d.pop("ingest_state"))
+
+        def _parse_ingest_error(data: object) -> Union[None, Unset, str]:
+            if data is None:
+                return data
+            if isinstance(data, Unset):
+                return data
+            return cast(Union[None, Unset, str], data)
+
+        ingest_error = _parse_ingest_error(d.pop("ingest_error", UNSET))
+
+        document_status = cls(
+            id=id,
+            ingest_state=ingest_state,
+            ingest_error=ingest_error,
+        )
+
+        document_status.additional_properties = d
+        return document_status
+
+    @property
+    def additional_keys(self) -> List[str]:
+        return list(self.additional_properties.keys())
+
+    def __getitem__(self, key: str) -> Any:
+        return self.additional_properties[key]
+
+    def __setitem__(self, key: str, value: Any) -> None:
+        self.additional_properties[key] = value
+
+    def __delitem__(self, key: str) -> None:
+        del self.additional_properties[key]
+
+    def __contains__(self, key: str) -> bool:
+        return key in self.additional_properties
diff --git a/dewy/common/db_migration.py b/dewy/common/db_migration.py
index 90336da..65ef448 100644
--- a/dewy/common/db_migration.py
+++ b/dewy/common/db_migration.py
@@ -38,7 +38,7 @@ async def apply_migrations(
             applied += 1
 
     if applied_migrations:
-        logger.warn("Unrecognized migrations applied: {}", applied_migrations)
+        logger.warning("Unrecognized migrations applied: {}", applied_migrations)
 
     logger.info(
         "Migrations complete. {} total, {} newly applied",
diff --git a/dewy/document/models.py b/dewy/document/models.py
index 603cf67..edd1f05 100644
--- a/dewy/document/models.py
+++ b/dewy/document/models.py
@@ -24,7 +24,7 @@ class IngestState(Enum):
 
 
 class Document(BaseModel):
-    """Schema for documents in the SQL DB."""
+    """Model for documents in Dewy."""
 
     id: Optional[int] = None
     collection_id: int
@@ -41,3 +41,11 @@ class Document(BaseModel):
 
     ingest_state: Optional[IngestState] = None
     ingest_error: Optional[str] = None
+
+
+class DocumentStatus(BaseModel):
+    """Model for document status."""
+
+    id: int
+    ingest_state: IngestState
+    ingest_error: Optional[str] = None
diff --git a/dewy/document/router.py b/dewy/document/router.py
index f3c5e54..2b0ae62 100644
--- a/dewy/document/router.py
+++ b/dewy/document/router.py
@@ -8,7 +8,7 @@
 from dewy.common.db import PgConnectionDep, PgPoolDep
 from dewy.document.models import Document
 
-from .models import AddDocumentRequest
+from .models import AddDocumentRequest, DocumentStatus
 
 router = APIRouter(prefix="/documents")
 
@@ -125,3 +125,28 @@ async def get_document(conn: PgConnectionDep, id: PathDocumentId) -> Document:
         id,
     )
     return Document.model_validate(dict(result))
+
+
+# Returns only the ingestion state and error for one document.
+@router.get("/{id}/status")
+async def get_document_status(
+    conn: PgConnectionDep, id: PathDocumentId
+) -> DocumentStatus:
+    # Function-scope import keeps this handler self-contained.
+    from fastapi import HTTPException
+
+    result = await conn.fetchrow(
+        """
+        SELECT ingest_state, ingest_error
+        FROM document
+        WHERE id = $1
+        """,
+        id,
+    )
+    # fetchrow yields None when no row matches; subscripting None would raise
+    # TypeError and surface as a 500, so report a missing document as 404.
+    if result is None:
+        raise HTTPException(status_code=404, detail=f"Document {id} not found")
+    return DocumentStatus(
+        id=id, ingest_state=result["ingest_state"], ingest_error=result["ingest_error"]
+    )
diff --git a/dewy/migrations/0001_schema.sql b/dewy/migrations/0001_schema.sql
index b818749..82af2b9 100644
--- a/dewy/migrations/0001_schema.sql
+++ b/dewy/migrations/0001_schema.sql
@@ -36,7 +36,7 @@ CREATE TABLE document(
     -- The state of the most recent ingestion of this document.
     -- TODO: Should we have a separate `ingestion` table and associate
     -- many ingestions with each document ID?
-    ingest_state ingest_state,
+    ingest_state ingest_state NOT NULL,
 
     -- The error (if any) resulting from the most recent ingestion.
     ingest_error VARCHAR,
diff --git a/openapi.yaml b/openapi.yaml
index dbba7d8..1d763a9 100644
--- a/openapi.yaml
+++ b/openapi.yaml
@@ -176,6 +176,32 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+  /api/documents/{id}/status:
+    get:
+      summary: Get Document Status
+      operationId: getDocumentStatus
+      parameters:
+      - name: id
+        in: path
+        required: true
+        schema:
+          type: integer
+          description: The document ID.
+          title: Id
+        description: The document ID.
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DocumentStatus'
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /api/chunks/:
     get:
       summary: List Chunks
@@ -400,7 +426,25 @@ components:
       - collection_id
       - url
       title: Document
-      description: Schema for documents in the SQL DB.
+      description: Model for documents in Dewy.
+    DocumentStatus:
+      properties:
+        id:
+          type: integer
+          title: Id
+        ingest_state:
+          $ref: '#/components/schemas/IngestState'
+        ingest_error:
+          anyOf:
+          - type: string
+          - type: 'null'
+          title: Ingest Error
+      type: object
+      required:
+      - id
+      - ingest_state
+      title: DocumentStatus
+      description: Model for document status.
     HTTPValidationError:
       properties:
         detail:
diff --git a/pyproject.toml b/pyproject.toml
index dd60b92..ca63953 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,4 +119,4 @@ show_error_context = true
 [tool.pytest.ini_options]
 testpaths = ["tests"]
 asyncio_mode = "auto"
-timeout = 5
\ No newline at end of file
+timeout = 10
\ No newline at end of file
diff --git a/tests/test_e2e.py b/tests/test_e2e.py
index 99b0a2c..cf04e7e 100644
--- a/tests/test_e2e.py
+++ b/tests/test_e2e.py
@@ -7,6 +7,7 @@
     add_collection,
     add_document,
     get_document,
+    get_document_status,
     list_chunks,
     retrieve_chunks,
 )
@@ -23,7 +24,7 @@
 @pytest.mark.parametrize(
     "embedding_model", ["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"]
 )
-@pytest.mark.timeout(60)  # slow due to embedding
+@pytest.mark.timeout(120)  # slow due to embedding (especially in CI)
 async def test_index_retrieval(client, embedding_model):
     name = "".join(random.choices(string.ascii_lowercase, k=5))
 
@@ -39,9 +40,12 @@ async def test_index_retrieval(client, embedding_model):
         ),
     )
 
-    while document.ingest_state != IngestState.INGESTED:
+    status = None
+    while getattr(status, "ingest_state", IngestState.PENDING) == IngestState.PENDING:
         time.sleep(0.5)
-        document = await get_document.asyncio(document.id, client=client)
+        status = await get_document_status.asyncio(document.id, client=client)
+
+    document = await get_document.asyncio(document.id, client=client)
     assert document.extracted_text.startswith("Skeleton-of-Thought")
 
     chunks = await list_chunks.asyncio(
@@ -77,10 +81,15 @@ async def test_ingest_error(client):
         body=AddDocumentRequest(url=f"error://{MESSAGE}", collection_id=collection.id),
     )
 
-    while document.ingest_state == IngestState.PENDING:
+    status = None
+    while getattr(status, "ingest_state", IngestState.PENDING) == IngestState.PENDING:
         time.sleep(0.2)
-        document = await get_document.asyncio(document.id, client=client)
+        status = await get_document_status.asyncio(document.id, client=client)
+
+    assert status.ingest_state == IngestState.FAILED
+    assert status.ingest_error == MESSAGE
 
+    document = await get_document.asyncio(document.id, client=client)
     assert document.ingest_state == IngestState.FAILED
     assert document.ingest_error == MESSAGE
 