From edf1433ab42595491bd2f517e23222cdeecce6b8 Mon Sep 17 00:00:00 2001 From: Ben Chambers <35960+bjchambers@users.noreply.github.com> Date: Tue, 30 Jan 2024 13:28:48 -0800 Subject: [PATCH] feat: Generate python dewy-client (#45) * feat: Generate and use Python client in tests * format, lint --- .github/workflows/ci.yml | 10 +- .gitignore | 3 +- dewy-client/.gitignore | 23 + dewy-client/README.md | 124 ++++ dewy-client/dewy_client/__init__.py | 7 + dewy-client/dewy_client/api/__init__.py | 1 + .../dewy_client/api/default/__init__.py | 0 .../dewy_client/api/default/add_collection.py | 173 +++++ .../dewy_client/api/default/add_document.py | 173 +++++ .../dewy_client/api/default/get_chunk.py | 172 +++++ .../dewy_client/api/default/get_collection.py | 163 +++++ .../dewy_client/api/default/get_document.py | 155 +++++ .../dewy_client/api/default/list_chunks.py | 258 +++++++ .../api/default/list_collections.py | 139 ++++ .../dewy_client/api/default/list_documents.py | 185 +++++ .../api/default/retrieve_chunks.py | 173 +++++ dewy-client/dewy_client/client.py | 268 ++++++++ dewy-client/dewy_client/errors.py | 14 + dewy-client/dewy_client/models/__init__.py | 33 + .../models/add_document_request.py | 80 +++ dewy-client/dewy_client/models/collection.py | 93 +++ .../dewy_client/models/collection_create.py | 87 +++ .../dewy_client/models/distance_metric.py | 10 + dewy-client/dewy_client/models/document.py | 160 +++++ .../models/http_validation_error.py | 74 ++ dewy-client/dewy_client/models/image_chunk.py | 133 ++++ .../dewy_client/models/image_result.py | 130 ++++ .../dewy_client/models/ingest_state.py | 10 + .../dewy_client/models/retrieve_request.py | 105 +++ .../dewy_client/models/retrieve_response.py | 106 +++ dewy-client/dewy_client/models/text_chunk.py | 133 ++++ dewy-client/dewy_client/models/text_result.py | 132 ++++ .../dewy_client/models/validation_error.py | 87 +++ dewy-client/dewy_client/py.typed | 1 + dewy-client/dewy_client/types.py | 44 ++ 
dewy-client/pyproject.toml | 25 + dewy/config.py | 4 +- openapi.yaml | 641 ++++++++++++++++++ openapi_client_config.yaml | 2 + poetry.lock | 58 +- pyproject.toml | 11 + tests/conftest.py | 7 +- tests/test_collection.py | 31 +- tests/test_e2e.py | 130 ++-- 44 files changed, 4256 insertions(+), 112 deletions(-) create mode 100644 dewy-client/.gitignore create mode 100644 dewy-client/README.md create mode 100644 dewy-client/dewy_client/__init__.py create mode 100644 dewy-client/dewy_client/api/__init__.py create mode 100644 dewy-client/dewy_client/api/default/__init__.py create mode 100644 dewy-client/dewy_client/api/default/add_collection.py create mode 100644 dewy-client/dewy_client/api/default/add_document.py create mode 100644 dewy-client/dewy_client/api/default/get_chunk.py create mode 100644 dewy-client/dewy_client/api/default/get_collection.py create mode 100644 dewy-client/dewy_client/api/default/get_document.py create mode 100644 dewy-client/dewy_client/api/default/list_chunks.py create mode 100644 dewy-client/dewy_client/api/default/list_collections.py create mode 100644 dewy-client/dewy_client/api/default/list_documents.py create mode 100644 dewy-client/dewy_client/api/default/retrieve_chunks.py create mode 100644 dewy-client/dewy_client/client.py create mode 100644 dewy-client/dewy_client/errors.py create mode 100644 dewy-client/dewy_client/models/__init__.py create mode 100644 dewy-client/dewy_client/models/add_document_request.py create mode 100644 dewy-client/dewy_client/models/collection.py create mode 100644 dewy-client/dewy_client/models/collection_create.py create mode 100644 dewy-client/dewy_client/models/distance_metric.py create mode 100644 dewy-client/dewy_client/models/document.py create mode 100644 dewy-client/dewy_client/models/http_validation_error.py create mode 100644 dewy-client/dewy_client/models/image_chunk.py create mode 100644 dewy-client/dewy_client/models/image_result.py create mode 100644 
dewy-client/dewy_client/models/ingest_state.py create mode 100644 dewy-client/dewy_client/models/retrieve_request.py create mode 100644 dewy-client/dewy_client/models/retrieve_response.py create mode 100644 dewy-client/dewy_client/models/text_chunk.py create mode 100644 dewy-client/dewy_client/models/text_result.py create mode 100644 dewy-client/dewy_client/models/validation_error.py create mode 100644 dewy-client/dewy_client/py.typed create mode 100644 dewy-client/dewy_client/types.py create mode 100644 dewy-client/pyproject.toml create mode 100644 openapi.yaml create mode 100644 openapi_client_config.yaml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8b2a9b3..be16a07 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,11 +33,19 @@ jobs: path: ./.venv key: venv-${{ hashFiles('poetry.lock') }} - name: Install the project dependencies - run: poetry install + run: poetry install --with=dev - name: pytest run: poetry run pytest -v env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + - name: check openapi client up to date + run: | + poetry run poe extract-openapi + poetry run poe update-client + # Record intent to add any new files in client + git add -N dewy-client + # Diff, and report any changes (including any new files in dewy-client) + git diff --exit-code python_lint: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 78b6a8d..6150e7b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ .ruff_cache __pycache__ .env -.vscode/ -openapi.yaml \ No newline at end of file +.vscode/ \ No newline at end of file diff --git a/dewy-client/.gitignore b/dewy-client/.gitignore new file mode 100644 index 0000000..79a2c3d --- /dev/null +++ b/dewy-client/.gitignore @@ -0,0 +1,23 @@ +__pycache__/ +build/ +dist/ +*.egg-info/ +.pytest_cache/ + +# pyenv +.python-version + +# Environments +.env +.venv + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# JetBrains +.idea/ + +/coverage.xml +/.coverage diff --git 
a/dewy-client/README.md b/dewy-client/README.md new file mode 100644 index 0000000..5bd624d --- /dev/null +++ b/dewy-client/README.md @@ -0,0 +1,124 @@ +# dewy-client +A client library for accessing Dewy Knowledge Base API + +## Usage +First, create a client: + +```python +from dewy_client import Client + +client = Client(base_url="https://api.example.com") +``` + +If the endpoints you're going to hit require authentication, use `AuthenticatedClient` instead: + +```python +from dewy_client import AuthenticatedClient + +client = AuthenticatedClient(base_url="https://api.example.com", token="SuperSecretToken") +``` + +Now call your endpoint and use your models: + +```python +from dewy_client.models import MyDataModel +from dewy_client.api.my_tag import get_my_data_model +from dewy_client.types import Response + +with client as client: + my_data: MyDataModel = get_my_data_model.sync(client=client) + # or if you need more info (e.g. status_code) + response: Response[MyDataModel] = get_my_data_model.sync_detailed(client=client) +``` + +Or do the same thing with an async version: + +```python +from dewy_client.models import MyDataModel +from dewy_client.api.my_tag import get_my_data_model +from dewy_client.types import Response + +async with client as client: + my_data: MyDataModel = await get_my_data_model.asyncio(client=client) + response: Response[MyDataModel] = await get_my_data_model.asyncio_detailed(client=client) +``` + +By default, when you're calling an HTTPS API it will attempt to verify that SSL is working correctly. Using certificate verification is highly recommended most of the time, but sometimes you may need to authenticate to a server (especially an internal server) using a custom certificate bundle. 
+ +```python +client = AuthenticatedClient( + base_url="https://internal_api.example.com", + token="SuperSecretToken", + verify_ssl="/path/to/certificate_bundle.pem", +) +``` + +You can also disable certificate validation altogether, but beware that **this is a security risk**. + +```python +client = AuthenticatedClient( + base_url="https://internal_api.example.com", + token="SuperSecretToken", + verify_ssl=False +) +``` + +Things to know: +1. Every path/method combo becomes a Python module with four functions: + 1. `sync`: Blocking request that returns parsed data (if successful) or `None` + 1. `sync_detailed`: Blocking request that always returns a `Response`, optionally with `parsed` set if the request was successful. + 1. `asyncio`: Like `sync` but async instead of blocking + 1. `asyncio_detailed`: Like `sync_detailed` but async instead of blocking + +1. All path/query params, and bodies become method arguments. +1. If your endpoint had any tags on it, the first tag will be used as a module name for the function (my_tag above) +1. Any endpoint which did not have a tag will be in `dewy_client.api.default` + +## Advanced customizations + +There are more settings on the generated `Client` class which let you control more runtime behavior, check out the docstring on that class for more info. 
You can also customize the underlying `httpx.Client` or `httpx.AsyncClient` (depending on your use-case): + +```python +from dewy_client import Client + +def log_request(request): + print(f"Request event hook: {request.method} {request.url} - Waiting for response") + +def log_response(response): + request = response.request + print(f"Response event hook: {request.method} {request.url} - Status {response.status_code}") + +client = Client( + base_url="https://api.example.com", + httpx_args={"event_hooks": {"request": [log_request], "response": [log_response]}}, +) + +# Or get the underlying httpx client to modify directly with client.get_httpx_client() or client.get_async_httpx_client() +``` + +You can even set the httpx client directly, but beware that this will override any existing settings (e.g., base_url): + +```python +import httpx +from dewy_client import Client + +client = Client( + base_url="https://api.example.com", +) +# Note that base_url needs to be re-set, as would any shared cookies, headers, etc. +client.set_httpx_client(httpx.Client(base_url="https://api.example.com", proxies="http://localhost:8030")) +``` + +## Building / publishing this package +This project uses [Poetry](https://python-poetry.org/) to manage dependencies and packaging. Here are the basics: +1. Update the metadata in pyproject.toml (e.g. authors, version) +1. If you're using a private repository, configure it with Poetry + 1. `poetry config repositories.<your-repository-name> <url-to-your-repository>` + 1. `poetry config http-basic.<your-repository-name> <username> <password>` +1. Publish the client with `poetry publish --build -r <your-repository-name>` or, if for public PyPI, just `poetry publish --build` + +If you want to install this client into another project without publishing it (e.g. for development) then: +1. If that project **is using Poetry**, you can simply do `poetry add <path-to-this-client>` from that project +1. If that project is not using Poetry: + 1. Build a wheel with `poetry build -f wheel` + 1. 
Install that wheel from the other project `pip install <path-to-wheel>` diff --git a/dewy-client/dewy_client/__init__.py b/dewy-client/dewy_client/__init__.py new file mode 100644 index 0000000..a96f200 --- /dev/null +++ b/dewy-client/dewy_client/__init__.py @@ -0,0 +1,7 @@ +""" A client library for accessing Dewy Knowledge Base API """ +from .client import AuthenticatedClient, Client + +__all__ = ( + "AuthenticatedClient", + "Client", +) diff --git a/dewy-client/dewy_client/api/__init__.py b/dewy-client/dewy_client/api/__init__.py new file mode 100644 index 0000000..dc035f4 --- /dev/null +++ b/dewy-client/dewy_client/api/__init__.py @@ -0,0 +1 @@ +""" Contains methods for accessing the API """ diff --git a/dewy-client/dewy_client/api/default/__init__.py b/dewy-client/dewy_client/api/default/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dewy-client/dewy_client/api/default/add_collection.py b/dewy-client/dewy_client/api/default/add_collection.py new file mode 100644 index 0000000..067b3ba --- /dev/null +++ b/dewy-client/dewy_client/api/default/add_collection.py @@ -0,0 +1,173 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... 
import errors +from ...client import AuthenticatedClient, Client +from ...models.collection import Collection +from ...models.collection_create import CollectionCreate +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + *, + body: CollectionCreate, +) -> Dict[str, Any]: + headers: Dict[str, Any] = {} + + _kwargs: Dict[str, Any] = { + "method": "put", + "url": "/api/collections/", + } + + _body = body.to_dict() + + _kwargs["json"] = _body + headers["Content-Type"] = "application/json" + + _kwargs["headers"] = headers + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[Collection, HTTPValidationError]]: + if response.status_code == HTTPStatus.OK: + response_200 = Collection.from_dict(response.json()) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[Collection, HTTPValidationError]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: CollectionCreate, +) -> Response[Union[Collection, HTTPValidationError]]: + """Add Collection + + Create a collection. + + Args: + body (CollectionCreate): The request to create a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Response[Union[Collection, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], + body: CollectionCreate, +) -> Optional[Union[Collection, HTTPValidationError]]: + """Add Collection + + Create a collection. + + Args: + body (CollectionCreate): The request to create a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Collection, HTTPValidationError] + """ + + return sync_detailed( + client=client, + body=body, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: CollectionCreate, +) -> Response[Union[Collection, HTTPValidationError]]: + """Add Collection + + Create a collection. + + Args: + body (CollectionCreate): The request to create a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[Collection, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], + body: CollectionCreate, +) -> Optional[Union[Collection, HTTPValidationError]]: + """Add Collection + + Create a collection. + + Args: + body (CollectionCreate): The request to create a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. 
+ httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Collection, HTTPValidationError] + """ + + return ( + await asyncio_detailed( + client=client, + body=body, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/add_document.py b/dewy-client/dewy_client/api/default/add_document.py new file mode 100644 index 0000000..495143d --- /dev/null +++ b/dewy-client/dewy_client/api/default/add_document.py @@ -0,0 +1,173 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.add_document_request import AddDocumentRequest +from ...models.document import Document +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + *, + body: AddDocumentRequest, +) -> Dict[str, Any]: + headers: Dict[str, Any] = {} + + _kwargs: Dict[str, Any] = { + "method": "put", + "url": "/api/documents/", + } + + _body = body.to_dict() + + _kwargs["json"] = _body + headers["Content-Type"] = "application/json" + + _kwargs["headers"] = headers + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[Document, HTTPValidationError]]: + if response.status_code == HTTPStatus.OK: + response_200 = Document.from_dict(response.json()) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[Document, HTTPValidationError]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + 
headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: AddDocumentRequest, +) -> Response[Union[Document, HTTPValidationError]]: + """Add Document + + Add a document. + + Args: + body (AddDocumentRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[Document, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], + body: AddDocumentRequest, +) -> Optional[Union[Document, HTTPValidationError]]: + """Add Document + + Add a document. + + Args: + body (AddDocumentRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Document, HTTPValidationError] + """ + + return sync_detailed( + client=client, + body=body, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: AddDocumentRequest, +) -> Response[Union[Document, HTTPValidationError]]: + """Add Document + + Add a document. + + Args: + body (AddDocumentRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Response[Union[Document, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], + body: AddDocumentRequest, +) -> Optional[Union[Document, HTTPValidationError]]: + """Add Document + + Add a document. + + Args: + body (AddDocumentRequest): + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Document, HTTPValidationError] + """ + + return ( + await asyncio_detailed( + client=client, + body=body, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/get_chunk.py b/dewy-client/dewy_client/api/default/get_chunk.py new file mode 100644 index 0000000..a65db15 --- /dev/null +++ b/dewy-client/dewy_client/api/default/get_chunk.py @@ -0,0 +1,172 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... 
import errors +from ...client import AuthenticatedClient, Client +from ...models.http_validation_error import HTTPValidationError +from ...models.image_chunk import ImageChunk +from ...models.text_chunk import TextChunk +from ...types import Response + + +def _get_kwargs( + id: int, +) -> Dict[str, Any]: + _kwargs: Dict[str, Any] = { + "method": "get", + "url": f"/api/chunks/{id}", + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[HTTPValidationError, Union["ImageChunk", "TextChunk"]]]: + if response.status_code == HTTPStatus.OK: + + def _parse_response_200(data: object) -> Union["ImageChunk", "TextChunk"]: + try: + if not isinstance(data, dict): + raise TypeError() + response_200_type_0 = TextChunk.from_dict(data) + + return response_200_type_0 + except: # noqa: E722 + pass + if not isinstance(data, dict): + raise TypeError() + response_200_type_1 = ImageChunk.from_dict(data) + + return response_200_type_1 + + response_200 = _parse_response_200(response.json()) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[HTTPValidationError, Union["ImageChunk", "TextChunk"]]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[HTTPValidationError, Union["ImageChunk", "TextChunk"]]]: + """Get Chunk + + Args: + id (int): The chunk ID. 
+ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[HTTPValidationError, Union['ImageChunk', 'TextChunk']]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[HTTPValidationError, Union["ImageChunk", "TextChunk"]]]: + """Get Chunk + + Args: + id (int): The chunk ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[HTTPValidationError, Union['ImageChunk', 'TextChunk']] + """ + + return sync_detailed( + id=id, + client=client, + ).parsed + + +async def asyncio_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[HTTPValidationError, Union["ImageChunk", "TextChunk"]]]: + """Get Chunk + + Args: + id (int): The chunk ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[HTTPValidationError, Union['ImageChunk', 'TextChunk']]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[HTTPValidationError, Union["ImageChunk", "TextChunk"]]]: + """Get Chunk + + Args: + id (int): The chunk ID. 
+ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[HTTPValidationError, Union['ImageChunk', 'TextChunk']] + """ + + return ( + await asyncio_detailed( + id=id, + client=client, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/get_collection.py b/dewy-client/dewy_client/api/default/get_collection.py new file mode 100644 index 0000000..76bdd00 --- /dev/null +++ b/dewy-client/dewy_client/api/default/get_collection.py @@ -0,0 +1,163 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.collection import Collection +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + id: int, +) -> Dict[str, Any]: + _kwargs: Dict[str, Any] = { + "method": "get", + "url": f"/api/collections/{id}", + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[Collection, HTTPValidationError]]: + if response.status_code == HTTPStatus.OK: + response_200 = Collection.from_dict(response.json()) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[Collection, HTTPValidationError]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, 
response=response), + ) + + +def sync_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[Collection, HTTPValidationError]]: + """Get Collection + + Get a specific collection. + + Args: + id (int): The collection ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[Collection, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[Collection, HTTPValidationError]]: + """Get Collection + + Get a specific collection. + + Args: + id (int): The collection ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Collection, HTTPValidationError] + """ + + return sync_detailed( + id=id, + client=client, + ).parsed + + +async def asyncio_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[Collection, HTTPValidationError]]: + """Get Collection + + Get a specific collection. + + Args: + id (int): The collection ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Response[Union[Collection, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[Collection, HTTPValidationError]]: + """Get Collection + + Get a specific collection. + + Args: + id (int): The collection ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Collection, HTTPValidationError] + """ + + return ( + await asyncio_detailed( + id=id, + client=client, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/get_document.py b/dewy-client/dewy_client/api/default/get_document.py new file mode 100644 index 0000000..97b3504 --- /dev/null +++ b/dewy-client/dewy_client/api/default/get_document.py @@ -0,0 +1,155 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... 
import errors +from ...client import AuthenticatedClient, Client +from ...models.document import Document +from ...models.http_validation_error import HTTPValidationError +from ...types import Response + + +def _get_kwargs( + id: int, +) -> Dict[str, Any]: + _kwargs: Dict[str, Any] = { + "method": "get", + "url": f"/api/documents/{id}", + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[Document, HTTPValidationError]]: + if response.status_code == HTTPStatus.OK: + response_200 = Document.from_dict(response.json()) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[Document, HTTPValidationError]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[Document, HTTPValidationError]]: + """Get Document + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Response[Union[Document, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[Document, HTTPValidationError]]: + """Get Document + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[Document, HTTPValidationError] + """ + + return sync_detailed( + id=id, + client=client, + ).parsed + + +async def asyncio_detailed( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Response[Union[Document, HTTPValidationError]]: + """Get Document + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[Document, HTTPValidationError]] + """ + + kwargs = _get_kwargs( + id=id, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + id: int, + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[Union[Document, HTTPValidationError]]: + """Get Document + + Args: + id (int): The document ID. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Union[Document, HTTPValidationError] + """ + + return ( + await asyncio_detailed( + id=id, + client=client, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/list_chunks.py b/dewy-client/dewy_client/api/default/list_chunks.py new file mode 100644 index 0000000..8e756ca --- /dev/null +++ b/dewy-client/dewy_client/api/default/list_chunks.py @@ -0,0 +1,258 @@ +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Union + +import httpx + +from ... import errors +from ...client import AuthenticatedClient, Client +from ...models.http_validation_error import HTTPValidationError +from ...models.image_chunk import ImageChunk +from ...models.text_chunk import TextChunk +from ...types import UNSET, Response, Unset + + +def _get_kwargs( + *, + collection_id: Union[None, Unset, int] = UNSET, + document_id: Union[None, Unset, int] = UNSET, + page: Union[None, Unset, int] = 1, + per_page: Union[None, Unset, int] = 10, +) -> Dict[str, Any]: + params: Dict[str, Any] = {} + + json_collection_id: Union[None, Unset, int] + if isinstance(collection_id, Unset): + json_collection_id = UNSET + else: + json_collection_id = collection_id + params["collection_id"] = json_collection_id + + json_document_id: Union[None, Unset, int] + if isinstance(document_id, Unset): + json_document_id = UNSET + else: + json_document_id = document_id + params["document_id"] = json_document_id + + json_page: Union[None, Unset, int] + if isinstance(page, Unset): + json_page = UNSET + else: + json_page = page + params["page"] = json_page + + json_per_page: Union[None, Unset, int] + if isinstance(per_page, Unset): + json_per_page = UNSET + else: + json_per_page = per_page + params["perPage"] = json_per_page + + params = {k: v for k, v in params.items() if v is not UNSET and v is not None} + + _kwargs: Dict[str, Any] = { + "method": "get", + "url": "/api/chunks/", + "params": params, + } + + return _kwargs + + +def _parse_response( + *, client: 
Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[HTTPValidationError, List[Union["ImageChunk", "TextChunk"]]]]: + if response.status_code == HTTPStatus.OK: + response_200 = [] + _response_200 = response.json() + for response_200_item_data in _response_200: + + def _parse_response_200_item(data: object) -> Union["ImageChunk", "TextChunk"]: + try: + if not isinstance(data, dict): + raise TypeError() + response_200_item_type_0 = TextChunk.from_dict(data) + + return response_200_item_type_0 + except: # noqa: E722 + pass + if not isinstance(data, dict): + raise TypeError() + response_200_item_type_1 = ImageChunk.from_dict(data) + + return response_200_item_type_1 + + response_200_item = _parse_response_200_item(response_200_item_data) + + response_200.append(response_200_item) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[HTTPValidationError, List[Union["ImageChunk", "TextChunk"]]]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, + document_id: Union[None, Unset, int] = UNSET, + page: Union[None, Unset, int] = 1, + per_page: Union[None, Unset, int] = 10, +) -> Response[Union[HTTPValidationError, List[Union["ImageChunk", "TextChunk"]]]]: + """List Chunks + + List chunks. 
+ + Args: + collection_id (Union[None, Unset, int]): Limit to chunks associated with this collection + document_id (Union[None, Unset, int]): Limit to chunks associated with this document + page (Union[None, Unset, int]): Default: 1. + per_page (Union[None, Unset, int]): Default: 10. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[HTTPValidationError, List[Union['ImageChunk', 'TextChunk']]]] + """ + + kwargs = _get_kwargs( + collection_id=collection_id, + document_id=document_id, + page=page, + per_page=per_page, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, + document_id: Union[None, Unset, int] = UNSET, + page: Union[None, Unset, int] = 1, + per_page: Union[None, Unset, int] = 10, +) -> Optional[Union[HTTPValidationError, List[Union["ImageChunk", "TextChunk"]]]]: + """List Chunks + + List chunks. + + Args: + collection_id (Union[None, Unset, int]): Limit to chunks associated with this collection + document_id (Union[None, Unset, int]): Limit to chunks associated with this document + page (Union[None, Unset, int]): Default: 1. + per_page (Union[None, Unset, int]): Default: 10. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Union[HTTPValidationError, List[Union['ImageChunk', 'TextChunk']]] + """ + + return sync_detailed( + client=client, + collection_id=collection_id, + document_id=document_id, + page=page, + per_page=per_page, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, + document_id: Union[None, Unset, int] = UNSET, + page: Union[None, Unset, int] = 1, + per_page: Union[None, Unset, int] = 10, +) -> Response[Union[HTTPValidationError, List[Union["ImageChunk", "TextChunk"]]]]: + """List Chunks + + List chunks. + + Args: + collection_id (Union[None, Unset, int]): Limit to chunks associated with this collection + document_id (Union[None, Unset, int]): Limit to chunks associated with this document + page (Union[None, Unset, int]): Default: 1. + per_page (Union[None, Unset, int]): Default: 10. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[HTTPValidationError, List[Union['ImageChunk', 'TextChunk']]]] + """ + + kwargs = _get_kwargs( + collection_id=collection_id, + document_id=document_id, + page=page, + per_page=per_page, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, + document_id: Union[None, Unset, int] = UNSET, + page: Union[None, Unset, int] = 1, + per_page: Union[None, Unset, int] = 10, +) -> Optional[Union[HTTPValidationError, List[Union["ImageChunk", "TextChunk"]]]]: + """List Chunks + + List chunks. 
+ + Args: + collection_id (Union[None, Unset, int]): Limit to chunks associated with this collection + document_id (Union[None, Unset, int]): Limit to chunks associated with this document + page (Union[None, Unset, int]): Default: 1. + per_page (Union[None, Unset, int]): Default: 10. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[HTTPValidationError, List[Union['ImageChunk', 'TextChunk']]] + """ + + return ( + await asyncio_detailed( + client=client, + collection_id=collection_id, + document_id=document_id, + page=page, + per_page=per_page, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/list_collections.py b/dewy-client/dewy_client/api/default/list_collections.py new file mode 100644 index 0000000..de7c36f --- /dev/null +++ b/dewy-client/dewy_client/api/default/list_collections.py @@ -0,0 +1,139 @@ +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Union + +import httpx + +from ... 
import errors +from ...client import AuthenticatedClient, Client +from ...models.collection import Collection +from ...types import Response + + +def _get_kwargs() -> Dict[str, Any]: + _kwargs: Dict[str, Any] = { + "method": "get", + "url": "/api/collections/", + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[List["Collection"]]: + if response.status_code == HTTPStatus.OK: + response_200 = [] + _response_200 = response.json() + for response_200_item_data in _response_200: + response_200_item = Collection.from_dict(response_200_item_data) + + response_200.append(response_200_item) + + return response_200 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[List["Collection"]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[List["Collection"]]: + """List Collections + + List collections. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[List['Collection']] + """ + + kwargs = _get_kwargs() + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[List["Collection"]]: + """List Collections + + List collections. 
+ + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + List['Collection'] + """ + + return sync_detailed( + client=client, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], +) -> Response[List["Collection"]]: + """List Collections + + List collections. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[List['Collection']] + """ + + kwargs = _get_kwargs() + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], +) -> Optional[List["Collection"]]: + """List Collections + + List collections. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + List['Collection'] + """ + + return ( + await asyncio_detailed( + client=client, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/list_documents.py b/dewy-client/dewy_client/api/default/list_documents.py new file mode 100644 index 0000000..11484cb --- /dev/null +++ b/dewy-client/dewy_client/api/default/list_documents.py @@ -0,0 +1,185 @@ +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Union + +import httpx + +from ... 
import errors +from ...client import AuthenticatedClient, Client +from ...models.document import Document +from ...models.http_validation_error import HTTPValidationError +from ...types import UNSET, Response, Unset + + +def _get_kwargs( + *, + collection_id: Union[None, Unset, int] = UNSET, +) -> Dict[str, Any]: + params: Dict[str, Any] = {} + + json_collection_id: Union[None, Unset, int] + if isinstance(collection_id, Unset): + json_collection_id = UNSET + else: + json_collection_id = collection_id + params["collection_id"] = json_collection_id + + params = {k: v for k, v in params.items() if v is not UNSET and v is not None} + + _kwargs: Dict[str, Any] = { + "method": "get", + "url": "/api/documents/", + "params": params, + } + + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[HTTPValidationError, List["Document"]]]: + if response.status_code == HTTPStatus.OK: + response_200 = [] + _response_200 = response.json() + for response_200_item_data in _response_200: + response_200_item = Document.from_dict(response_200_item_data) + + response_200.append(response_200_item) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[HTTPValidationError, List["Document"]]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, +) -> Response[Union[HTTPValidationError, 
List["Document"]]]: + """List Documents + + List documents. + + Args: + collection_id (Union[None, Unset, int]): Limit to documents associated with this + collection + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[HTTPValidationError, List['Document']]] + """ + + kwargs = _get_kwargs( + collection_id=collection_id, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, +) -> Optional[Union[HTTPValidationError, List["Document"]]]: + """List Documents + + List documents. + + Args: + collection_id (Union[None, Unset, int]): Limit to documents associated with this + collection + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[HTTPValidationError, List['Document']] + """ + + return sync_detailed( + client=client, + collection_id=collection_id, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, +) -> Response[Union[HTTPValidationError, List["Document"]]]: + """List Documents + + List documents. + + Args: + collection_id (Union[None, Unset, int]): Limit to documents associated with this + collection + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. 
+ + Returns: + Response[Union[HTTPValidationError, List['Document']]] + """ + + kwargs = _get_kwargs( + collection_id=collection_id, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], + collection_id: Union[None, Unset, int] = UNSET, +) -> Optional[Union[HTTPValidationError, List["Document"]]]: + """List Documents + + List documents. + + Args: + collection_id (Union[None, Unset, int]): Limit to documents associated with this + collection + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[HTTPValidationError, List['Document']] + """ + + return ( + await asyncio_detailed( + client=client, + collection_id=collection_id, + ) + ).parsed diff --git a/dewy-client/dewy_client/api/default/retrieve_chunks.py b/dewy-client/dewy_client/api/default/retrieve_chunks.py new file mode 100644 index 0000000..48a316d --- /dev/null +++ b/dewy-client/dewy_client/api/default/retrieve_chunks.py @@ -0,0 +1,173 @@ +from http import HTTPStatus +from typing import Any, Dict, Optional, Union + +import httpx + +from ... 
import errors +from ...client import AuthenticatedClient, Client +from ...models.http_validation_error import HTTPValidationError +from ...models.retrieve_request import RetrieveRequest +from ...models.retrieve_response import RetrieveResponse +from ...types import Response + + +def _get_kwargs( + *, + body: RetrieveRequest, +) -> Dict[str, Any]: + headers: Dict[str, Any] = {} + + _kwargs: Dict[str, Any] = { + "method": "post", + "url": "/api/chunks/retrieve", + } + + _body = body.to_dict() + + _kwargs["json"] = _body + headers["Content-Type"] = "application/json" + + _kwargs["headers"] = headers + return _kwargs + + +def _parse_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Optional[Union[HTTPValidationError, RetrieveResponse]]: + if response.status_code == HTTPStatus.OK: + response_200 = RetrieveResponse.from_dict(response.json()) + + return response_200 + if response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY: + response_422 = HTTPValidationError.from_dict(response.json()) + + return response_422 + if client.raise_on_unexpected_status: + raise errors.UnexpectedStatus(response.status_code, response.content) + else: + return None + + +def _build_response( + *, client: Union[AuthenticatedClient, Client], response: httpx.Response +) -> Response[Union[HTTPValidationError, RetrieveResponse]]: + return Response( + status_code=HTTPStatus(response.status_code), + content=response.content, + headers=response.headers, + parsed=_parse_response(client=client, response=response), + ) + + +def sync_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: RetrieveRequest, +) -> Response[Union[HTTPValidationError, RetrieveResponse]]: + """Retrieve Chunks + + Retrieve chunks based on a given query. + + Args: + body (RetrieveRequest): A request for retrieving chunks from a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. 
+ httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[HTTPValidationError, RetrieveResponse]] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = client.get_httpx_client().request( + **kwargs, + ) + + return _build_response(client=client, response=response) + + +def sync( + *, + client: Union[AuthenticatedClient, Client], + body: RetrieveRequest, +) -> Optional[Union[HTTPValidationError, RetrieveResponse]]: + """Retrieve Chunks + + Retrieve chunks based on a given query. + + Args: + body (RetrieveRequest): A request for retrieving chunks from a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[HTTPValidationError, RetrieveResponse] + """ + + return sync_detailed( + client=client, + body=body, + ).parsed + + +async def asyncio_detailed( + *, + client: Union[AuthenticatedClient, Client], + body: RetrieveRequest, +) -> Response[Union[HTTPValidationError, RetrieveResponse]]: + """Retrieve Chunks + + Retrieve chunks based on a given query. + + Args: + body (RetrieveRequest): A request for retrieving chunks from a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Response[Union[HTTPValidationError, RetrieveResponse]] + """ + + kwargs = _get_kwargs( + body=body, + ) + + response = await client.get_async_httpx_client().request(**kwargs) + + return _build_response(client=client, response=response) + + +async def asyncio( + *, + client: Union[AuthenticatedClient, Client], + body: RetrieveRequest, +) -> Optional[Union[HTTPValidationError, RetrieveResponse]]: + """Retrieve Chunks + + Retrieve chunks based on a given query. 
+ + Args: + body (RetrieveRequest): A request for retrieving chunks from a collection. + + Raises: + errors.UnexpectedStatus: If the server returns an undocumented status code and Client.raise_on_unexpected_status is True. + httpx.TimeoutException: If the request takes longer than Client.timeout. + + Returns: + Union[HTTPValidationError, RetrieveResponse] + """ + + return ( + await asyncio_detailed( + client=client, + body=body, + ) + ).parsed diff --git a/dewy-client/dewy_client/client.py b/dewy-client/dewy_client/client.py new file mode 100644 index 0000000..74b476c --- /dev/null +++ b/dewy-client/dewy_client/client.py @@ -0,0 +1,268 @@ +import ssl +from typing import Any, Dict, Optional, Union + +import httpx +from attrs import define, evolve, field + + +@define +class Client: + """A class for keeping track of data related to the API + + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. + + + Attributes: + raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a + status code that was not documented in the source OpenAPI document. Can also be provided as a keyword + argument to the constructor. 
+ """ + + raise_on_unexpected_status: bool = field(default=False, kw_only=True) + _base_url: str + _cookies: Dict[str, str] = field(factory=dict, kw_only=True) + _headers: Dict[str, str] = field(factory=dict, kw_only=True) + _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True) + _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True) + _follow_redirects: bool = field(default=False, kw_only=True) + _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True) + _client: Optional[httpx.Client] = field(default=None, init=False) + _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) + + def with_headers(self, headers: Dict[str, str]) -> "Client": + """Get a new client matching this one with additional headers""" + if self._client is not None: + self._client.headers.update(headers) + if self._async_client is not None: + self._async_client.headers.update(headers) + return evolve(self, headers={**self._headers, **headers}) + + def with_cookies(self, cookies: Dict[str, str]) -> "Client": + """Get a new client matching this one with additional cookies""" + if self._client is not None: + self._client.cookies.update(cookies) + if self._async_client is not None: + self._async_client.cookies.update(cookies) + return evolve(self, cookies={**self._cookies, **cookies}) + + def with_timeout(self, timeout: httpx.Timeout) -> "Client": + """Get a new client matching this one with a new timeout (in seconds)""" + if self._client is not None: + self._client.timeout = timeout + if self._async_client is not None: + self._async_client.timeout = timeout + return evolve(self, timeout=timeout) + + def set_httpx_client(self, client: httpx.Client) -> "Client": + """Manually set the underlying httpx.Client + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. 
+ """ + self._client = client + return self + + def get_httpx_client(self) -> httpx.Client: + """Get the underlying httpx.Client, constructing a new one if not previously set""" + if self._client is None: + self._client = httpx.Client( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._client + + def __enter__(self) -> "Client": + """Enter a context manager for self.client—you cannot enter twice (see httpx docs)""" + self.get_httpx_client().__enter__() + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for internal httpx.Client (see httpx docs)""" + self.get_httpx_client().__exit__(*args, **kwargs) + + def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "Client": + """Manually set the underlying httpx.AsyncClient + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. 
+ """ + self._async_client = async_client + return self + + def get_async_httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx.AsyncClient, constructing a new one if not previously set""" + if self._async_client is None: + self._async_client = httpx.AsyncClient( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._async_client + + async def __aenter__(self) -> "Client": + """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)""" + await self.get_async_httpx_client().__aenter__() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)""" + await self.get_async_httpx_client().__aexit__(*args, **kwargs) + + +@define +class AuthenticatedClient: + """A Client which has been authenticated for use on secured endpoints + + The following are accepted as keyword arguments and will be used to construct httpx Clients internally: + + ``base_url``: The base URL for the API, all requests are made to a relative path to this URL + + ``cookies``: A dictionary of cookies to be sent with every request + + ``headers``: A dictionary of headers to be sent with every request + + ``timeout``: The maximum amount of time a request can take. API functions will raise + httpx.TimeoutException if this is exceeded. + + ``verify_ssl``: Whether or not to verify the SSL certificate of the API server. This should be True in production, + but can be set to False for testing purposes. + + ``follow_redirects``: Whether or not to follow redirects. Default value is False. + + ``httpx_args``: A dictionary of additional arguments to be passed to the ``httpx.Client`` and ``httpx.AsyncClient`` constructor. 
+ + + Attributes: + raise_on_unexpected_status: Whether or not to raise an errors.UnexpectedStatus if the API returns a + status code that was not documented in the source OpenAPI document. Can also be provided as a keyword + argument to the constructor. + token: The token to use for authentication + prefix: The prefix to use for the Authorization header + auth_header_name: The name of the Authorization header + """ + + raise_on_unexpected_status: bool = field(default=False, kw_only=True) + _base_url: str + _cookies: Dict[str, str] = field(factory=dict, kw_only=True) + _headers: Dict[str, str] = field(factory=dict, kw_only=True) + _timeout: Optional[httpx.Timeout] = field(default=None, kw_only=True) + _verify_ssl: Union[str, bool, ssl.SSLContext] = field(default=True, kw_only=True) + _follow_redirects: bool = field(default=False, kw_only=True) + _httpx_args: Dict[str, Any] = field(factory=dict, kw_only=True) + _client: Optional[httpx.Client] = field(default=None, init=False) + _async_client: Optional[httpx.AsyncClient] = field(default=None, init=False) + + token: str + prefix: str = "Bearer" + auth_header_name: str = "Authorization" + + def with_headers(self, headers: Dict[str, str]) -> "AuthenticatedClient": + """Get a new client matching this one with additional headers""" + if self._client is not None: + self._client.headers.update(headers) + if self._async_client is not None: + self._async_client.headers.update(headers) + return evolve(self, headers={**self._headers, **headers}) + + def with_cookies(self, cookies: Dict[str, str]) -> "AuthenticatedClient": + """Get a new client matching this one with additional cookies""" + if self._client is not None: + self._client.cookies.update(cookies) + if self._async_client is not None: + self._async_client.cookies.update(cookies) + return evolve(self, cookies={**self._cookies, **cookies}) + + def with_timeout(self, timeout: httpx.Timeout) -> "AuthenticatedClient": + """Get a new client matching this one with a new 
timeout (in seconds)""" + if self._client is not None: + self._client.timeout = timeout + if self._async_client is not None: + self._async_client.timeout = timeout + return evolve(self, timeout=timeout) + + def set_httpx_client(self, client: httpx.Client) -> "AuthenticatedClient": + """Manually the underlying httpx.Client + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. + """ + self._client = client + return self + + def get_httpx_client(self) -> httpx.Client: + """Get the underlying httpx.Client, constructing a new one if not previously set""" + if self._client is None: + self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token + self._client = httpx.Client( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._client + + def __enter__(self) -> "AuthenticatedClient": + """Enter a context manager for self.client—you cannot enter twice (see httpx docs)""" + self.get_httpx_client().__enter__() + return self + + def __exit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for internal httpx.Client (see httpx docs)""" + self.get_httpx_client().__exit__(*args, **kwargs) + + def set_async_httpx_client(self, async_client: httpx.AsyncClient) -> "AuthenticatedClient": + """Manually the underlying httpx.AsyncClient + + **NOTE**: This will override any other settings on the client, including cookies, headers, and timeout. 
+ """ + self._async_client = async_client + return self + + def get_async_httpx_client(self) -> httpx.AsyncClient: + """Get the underlying httpx.AsyncClient, constructing a new one if not previously set""" + if self._async_client is None: + self._headers[self.auth_header_name] = f"{self.prefix} {self.token}" if self.prefix else self.token + self._async_client = httpx.AsyncClient( + base_url=self._base_url, + cookies=self._cookies, + headers=self._headers, + timeout=self._timeout, + verify=self._verify_ssl, + follow_redirects=self._follow_redirects, + **self._httpx_args, + ) + return self._async_client + + async def __aenter__(self) -> "AuthenticatedClient": + """Enter a context manager for underlying httpx.AsyncClient—you cannot enter twice (see httpx docs)""" + await self.get_async_httpx_client().__aenter__() + return self + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit a context manager for underlying httpx.AsyncClient (see httpx docs)""" + await self.get_async_httpx_client().__aexit__(*args, **kwargs) diff --git a/dewy-client/dewy_client/errors.py b/dewy-client/dewy_client/errors.py new file mode 100644 index 0000000..426f8a2 --- /dev/null +++ b/dewy-client/dewy_client/errors.py @@ -0,0 +1,14 @@ +""" Contains shared errors types that can be raised from API functions """ + + +class UnexpectedStatus(Exception): + """Raised by api functions when the response status an undocumented status and Client.raise_on_unexpected_status is True""" + + def __init__(self, status_code: int, content: bytes): + self.status_code = status_code + self.content = content + + super().__init__(f"Unexpected status code: {status_code}") + + +__all__ = ["UnexpectedStatus"] diff --git a/dewy-client/dewy_client/models/__init__.py b/dewy-client/dewy_client/models/__init__.py new file mode 100644 index 0000000..9e17403 --- /dev/null +++ b/dewy-client/dewy_client/models/__init__.py @@ -0,0 +1,33 @@ +""" Contains all the data models used in inputs/outputs """ + +from 
.add_document_request import AddDocumentRequest +from .collection import Collection +from .collection_create import CollectionCreate +from .distance_metric import DistanceMetric +from .document import Document +from .http_validation_error import HTTPValidationError +from .image_chunk import ImageChunk +from .image_result import ImageResult +from .ingest_state import IngestState +from .retrieve_request import RetrieveRequest +from .retrieve_response import RetrieveResponse +from .text_chunk import TextChunk +from .text_result import TextResult +from .validation_error import ValidationError + +__all__ = ( + "AddDocumentRequest", + "Collection", + "CollectionCreate", + "DistanceMetric", + "Document", + "HTTPValidationError", + "ImageChunk", + "ImageResult", + "IngestState", + "RetrieveRequest", + "RetrieveResponse", + "TextChunk", + "TextResult", + "ValidationError", +) diff --git a/dewy-client/dewy_client/models/add_document_request.py b/dewy-client/dewy_client/models/add_document_request.py new file mode 100644 index 0000000..48f9311 --- /dev/null +++ b/dewy-client/dewy_client/models/add_document_request.py @@ -0,0 +1,80 @@ +from typing import Any, Dict, List, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="AddDocumentRequest") + + +@_attrs_define +class AddDocumentRequest: + """ + Attributes: + url (str): + collection_id (Union[None, Unset, int]): + """ + + url: str + collection_id: Union[None, Unset, int] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + url = self.url + + collection_id: Union[None, Unset, int] + if isinstance(self.collection_id, Unset): + collection_id = UNSET + else: + collection_id = self.collection_id + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "url": url, + } + ) + if 
collection_id is not UNSET: + field_dict["collection_id"] = collection_id + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + url = d.pop("url") + + def _parse_collection_id(data: object) -> Union[None, Unset, int]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, int], data) + + collection_id = _parse_collection_id(d.pop("collection_id", UNSET)) + + add_document_request = cls( + url=url, + collection_id=collection_id, + ) + + add_document_request.additional_properties = d + return add_document_request + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/collection.py b/dewy-client/dewy_client/models/collection.py new file mode 100644 index 0000000..1105a21 --- /dev/null +++ b/dewy-client/dewy_client/models/collection.py @@ -0,0 +1,93 @@ +from typing import Any, Dict, List, Type, TypeVar, Union + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..models.distance_metric import DistanceMetric +from ..types import UNSET, Unset + +T = TypeVar("T", bound="Collection") + + +@_attrs_define +class Collection: + """ + Attributes: + id (int): + name (str): + text_embedding_model (str): + text_distance_metric (Union[Unset, DistanceMetric]): Default: DistanceMetric.COSINE. 
+ """ + + id: int + name: str + text_embedding_model: str + text_distance_metric: Union[Unset, DistanceMetric] = DistanceMetric.COSINE + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + id = self.id + + name = self.name + + text_embedding_model = self.text_embedding_model + + text_distance_metric: Union[Unset, str] = UNSET + if not isinstance(self.text_distance_metric, Unset): + text_distance_metric = self.text_distance_metric.value + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "id": id, + "name": name, + "text_embedding_model": text_embedding_model, + } + ) + if text_distance_metric is not UNSET: + field_dict["text_distance_metric"] = text_distance_metric + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + id = d.pop("id") + + name = d.pop("name") + + text_embedding_model = d.pop("text_embedding_model") + + _text_distance_metric = d.pop("text_distance_metric", UNSET) + text_distance_metric: Union[Unset, DistanceMetric] + if isinstance(_text_distance_metric, Unset): + text_distance_metric = UNSET + else: + text_distance_metric = DistanceMetric(_text_distance_metric) + + collection = cls( + id=id, + name=name, + text_embedding_model=text_embedding_model, + text_distance_metric=text_distance_metric, + ) + + collection.additional_properties = d + return collection + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git 
a/dewy-client/dewy_client/models/collection_create.py b/dewy-client/dewy_client/models/collection_create.py new file mode 100644 index 0000000..916690a --- /dev/null +++ b/dewy-client/dewy_client/models/collection_create.py @@ -0,0 +1,87 @@ +from typing import Any, Dict, List, Type, TypeVar, Union + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..models.distance_metric import DistanceMetric +from ..types import UNSET, Unset + +T = TypeVar("T", bound="CollectionCreate") + + +@_attrs_define +class CollectionCreate: + """The request to create a collection. + + Attributes: + name (str): + text_embedding_model (Union[Unset, str]): Default: 'openai:text-embedding-ada-002'. + text_distance_metric (Union[Unset, DistanceMetric]): Default: DistanceMetric.COSINE. + """ + + name: str + text_embedding_model: Union[Unset, str] = "openai:text-embedding-ada-002" + text_distance_metric: Union[Unset, DistanceMetric] = DistanceMetric.COSINE + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + name = self.name + + text_embedding_model = self.text_embedding_model + + text_distance_metric: Union[Unset, str] = UNSET + if not isinstance(self.text_distance_metric, Unset): + text_distance_metric = self.text_distance_metric.value + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "name": name, + } + ) + if text_embedding_model is not UNSET: + field_dict["text_embedding_model"] = text_embedding_model + if text_distance_metric is not UNSET: + field_dict["text_distance_metric"] = text_distance_metric + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + name = d.pop("name") + + text_embedding_model = d.pop("text_embedding_model", UNSET) + + _text_distance_metric = d.pop("text_distance_metric", UNSET) + text_distance_metric: Union[Unset, 
DistanceMetric] + if isinstance(_text_distance_metric, Unset): + text_distance_metric = UNSET + else: + text_distance_metric = DistanceMetric(_text_distance_metric) + + collection_create = cls( + name=name, + text_embedding_model=text_embedding_model, + text_distance_metric=text_distance_metric, + ) + + collection_create.additional_properties = d + return collection_create + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/distance_metric.py b/dewy-client/dewy_client/models/distance_metric.py new file mode 100644 index 0000000..3c154af --- /dev/null +++ b/dewy-client/dewy_client/models/distance_metric.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class DistanceMetric(str, Enum): + COSINE = "cosine" + IP = "ip" + L2 = "l2" + + def __str__(self) -> str: + return str(self.value) diff --git a/dewy-client/dewy_client/models/document.py b/dewy-client/dewy_client/models/document.py new file mode 100644 index 0000000..78c5bce --- /dev/null +++ b/dewy-client/dewy_client/models/document.py @@ -0,0 +1,160 @@ +from typing import Any, Dict, List, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..models.ingest_state import IngestState +from ..types import UNSET, Unset + +T = TypeVar("T", bound="Document") + + +@_attrs_define +class Document: + """Schema for documents in the SQL DB. 
+ + Attributes: + collection_id (int): + url (str): + id (Union[None, Unset, int]): + extracted_text (Union[None, Unset, str]): + ingest_state (Union[IngestState, None, Unset]): + ingest_error (Union[None, Unset, str]): + """ + + collection_id: int + url: str + id: Union[None, Unset, int] = UNSET + extracted_text: Union[None, Unset, str] = UNSET + ingest_state: Union[IngestState, None, Unset] = UNSET + ingest_error: Union[None, Unset, str] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + collection_id = self.collection_id + + url = self.url + + id: Union[None, Unset, int] + if isinstance(self.id, Unset): + id = UNSET + else: + id = self.id + + extracted_text: Union[None, Unset, str] + if isinstance(self.extracted_text, Unset): + extracted_text = UNSET + else: + extracted_text = self.extracted_text + + ingest_state: Union[None, Unset, str] + if isinstance(self.ingest_state, Unset): + ingest_state = UNSET + elif isinstance(self.ingest_state, IngestState): + ingest_state = self.ingest_state.value + else: + ingest_state = self.ingest_state + + ingest_error: Union[None, Unset, str] + if isinstance(self.ingest_error, Unset): + ingest_error = UNSET + else: + ingest_error = self.ingest_error + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "collection_id": collection_id, + "url": url, + } + ) + if id is not UNSET: + field_dict["id"] = id + if extracted_text is not UNSET: + field_dict["extracted_text"] = extracted_text + if ingest_state is not UNSET: + field_dict["ingest_state"] = ingest_state + if ingest_error is not UNSET: + field_dict["ingest_error"] = ingest_error + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + collection_id = d.pop("collection_id") + + url = d.pop("url") + + def _parse_id(data: object) -> Union[None, Unset, int]: + if data is None: + 
return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, int], data) + + id = _parse_id(d.pop("id", UNSET)) + + def _parse_extracted_text(data: object) -> Union[None, Unset, str]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, str], data) + + extracted_text = _parse_extracted_text(d.pop("extracted_text", UNSET)) + + def _parse_ingest_state(data: object) -> Union[IngestState, None, Unset]: + if data is None: + return data + if isinstance(data, Unset): + return data + try: + if not isinstance(data, str): + raise TypeError() + ingest_state_type_0 = IngestState(data) + + return ingest_state_type_0 + except: # noqa: E722 + pass + return cast(Union[IngestState, None, Unset], data) + + ingest_state = _parse_ingest_state(d.pop("ingest_state", UNSET)) + + def _parse_ingest_error(data: object) -> Union[None, Unset, str]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, str], data) + + ingest_error = _parse_ingest_error(d.pop("ingest_error", UNSET)) + + document = cls( + collection_id=collection_id, + url=url, + id=id, + extracted_text=extracted_text, + ingest_state=ingest_state, + ingest_error=ingest_error, + ) + + document.additional_properties = d + return document + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/http_validation_error.py b/dewy-client/dewy_client/models/http_validation_error.py new file mode 100644 index 0000000..dd5f086 --- /dev/null +++ 
b/dewy-client/dewy_client/models/http_validation_error.py @@ -0,0 +1,74 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +if TYPE_CHECKING: + from ..models.validation_error import ValidationError + + +T = TypeVar("T", bound="HTTPValidationError") + + +@_attrs_define +class HTTPValidationError: + """ + Attributes: + detail (Union[Unset, List['ValidationError']]): + """ + + detail: Union[Unset, List["ValidationError"]] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + detail: Union[Unset, List[Dict[str, Any]]] = UNSET + if not isinstance(self.detail, Unset): + detail = [] + for detail_item_data in self.detail: + detail_item = detail_item_data.to_dict() + detail.append(detail_item) + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update({}) + if detail is not UNSET: + field_dict["detail"] = detail + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + from ..models.validation_error import ValidationError + + d = src_dict.copy() + detail = [] + _detail = d.pop("detail", UNSET) + for detail_item_data in _detail or []: + detail_item = ValidationError.from_dict(detail_item_data) + + detail.append(detail_item) + + http_validation_error = cls( + detail=detail, + ) + + http_validation_error.additional_properties = d + return http_validation_error + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> 
bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/image_chunk.py b/dewy-client/dewy_client/models/image_chunk.py new file mode 100644 index 0000000..b71ded5 --- /dev/null +++ b/dewy-client/dewy_client/models/image_chunk.py @@ -0,0 +1,133 @@ +from typing import Any, Dict, List, Literal, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="ImageChunk") + + +@_attrs_define +class ImageChunk: + """ + Attributes: + id (int): + document_id (int): + image (Union[None, str]): Image of the node. + image_mimetype (Union[None, str]): Mimetype of the image. + image_path (Union[None, str]): Path of the image. + image_url (Union[None, str]): URL of the image. + kind (Union[Literal['image'], Unset]): Default: 'image'. + """ + + id: int + document_id: int + image: Union[None, str] + image_mimetype: Union[None, str] + image_path: Union[None, str] + image_url: Union[None, str] + kind: Union[Literal["image"], Unset] = "image" + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + id = self.id + + document_id = self.document_id + + image: Union[None, str] + image = self.image + + image_mimetype: Union[None, str] + image_mimetype = self.image_mimetype + + image_path: Union[None, str] + image_path = self.image_path + + image_url: Union[None, str] + image_url = self.image_url + + kind = self.kind + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "id": id, + "document_id": document_id, + "image": image, + "image_mimetype": image_mimetype, + "image_path": image_path, + "image_url": image_url, + } + ) + if kind is not UNSET: + field_dict["kind"] = kind + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + id = d.pop("id") + + 
document_id = d.pop("document_id") + + def _parse_image(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image = _parse_image(d.pop("image")) + + def _parse_image_mimetype(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image_mimetype = _parse_image_mimetype(d.pop("image_mimetype")) + + def _parse_image_path(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image_path = _parse_image_path(d.pop("image_path")) + + def _parse_image_url(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image_url = _parse_image_url(d.pop("image_url")) + + kind = d.pop("kind", UNSET) + + image_chunk = cls( + id=id, + document_id=document_id, + image=image, + image_mimetype=image_mimetype, + image_path=image_path, + image_url=image_url, + kind=kind, + ) + + image_chunk.additional_properties = d + return image_chunk + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/image_result.py b/dewy-client/dewy_client/models/image_result.py new file mode 100644 index 0000000..f8e6336 --- /dev/null +++ b/dewy-client/dewy_client/models/image_result.py @@ -0,0 +1,130 @@ +from typing import Any, Dict, List, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +T = TypeVar("T", bound="ImageResult") + + +@_attrs_define +class ImageResult: + """ + Attributes: + 
chunk_id (int): + document_id (int): + score (float): + image (Union[None, str]): Image of the node. + image_mimetype (Union[None, str]): Mimetype of the image. + image_path (Union[None, str]): Path of the image. + image_url (Union[None, str]): URL of the image. + """ + + chunk_id: int + document_id: int + score: float + image: Union[None, str] + image_mimetype: Union[None, str] + image_path: Union[None, str] + image_url: Union[None, str] + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + chunk_id = self.chunk_id + + document_id = self.document_id + + score = self.score + + image: Union[None, str] + image = self.image + + image_mimetype: Union[None, str] + image_mimetype = self.image_mimetype + + image_path: Union[None, str] + image_path = self.image_path + + image_url: Union[None, str] + image_url = self.image_url + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "chunk_id": chunk_id, + "document_id": document_id, + "score": score, + "image": image, + "image_mimetype": image_mimetype, + "image_path": image_path, + "image_url": image_url, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + chunk_id = d.pop("chunk_id") + + document_id = d.pop("document_id") + + score = d.pop("score") + + def _parse_image(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image = _parse_image(d.pop("image")) + + def _parse_image_mimetype(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image_mimetype = _parse_image_mimetype(d.pop("image_mimetype")) + + def _parse_image_path(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image_path = _parse_image_path(d.pop("image_path")) + + def 
_parse_image_url(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + image_url = _parse_image_url(d.pop("image_url")) + + image_result = cls( + chunk_id=chunk_id, + document_id=document_id, + score=score, + image=image, + image_mimetype=image_mimetype, + image_path=image_path, + image_url=image_url, + ) + + image_result.additional_properties = d + return image_result + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/ingest_state.py b/dewy-client/dewy_client/models/ingest_state.py new file mode 100644 index 0000000..be3a0e5 --- /dev/null +++ b/dewy-client/dewy_client/models/ingest_state.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class IngestState(str, Enum): + FAILED = "failed" + INGESTED = "ingested" + PENDING = "pending" + + def __str__(self) -> str: + return str(self.value) diff --git a/dewy-client/dewy_client/models/retrieve_request.py b/dewy-client/dewy_client/models/retrieve_request.py new file mode 100644 index 0000000..051766c --- /dev/null +++ b/dewy-client/dewy_client/models/retrieve_request.py @@ -0,0 +1,105 @@ +from typing import Any, Dict, List, Type, TypeVar, Union + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="RetrieveRequest") + + +@_attrs_define +class RetrieveRequest: + """A request for retrieving chunks from a collection. + + Attributes: + collection_id (int): + query (str): + n (Union[Unset, int]): Default: 10. 
+ include_text_chunks (Union[Unset, bool]): Default: True. + include_image_chunks (Union[Unset, bool]): Default: True. + include_summary (Union[Unset, bool]): Default: False. + """ + + collection_id: int + query: str + n: Union[Unset, int] = 10 + include_text_chunks: Union[Unset, bool] = True + include_image_chunks: Union[Unset, bool] = True + include_summary: Union[Unset, bool] = False + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + collection_id = self.collection_id + + query = self.query + + n = self.n + + include_text_chunks = self.include_text_chunks + + include_image_chunks = self.include_image_chunks + + include_summary = self.include_summary + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "collection_id": collection_id, + "query": query, + } + ) + if n is not UNSET: + field_dict["n"] = n + if include_text_chunks is not UNSET: + field_dict["include_text_chunks"] = include_text_chunks + if include_image_chunks is not UNSET: + field_dict["include_image_chunks"] = include_image_chunks + if include_summary is not UNSET: + field_dict["include_summary"] = include_summary + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + collection_id = d.pop("collection_id") + + query = d.pop("query") + + n = d.pop("n", UNSET) + + include_text_chunks = d.pop("include_text_chunks", UNSET) + + include_image_chunks = d.pop("include_image_chunks", UNSET) + + include_summary = d.pop("include_summary", UNSET) + + retrieve_request = cls( + collection_id=collection_id, + query=query, + n=n, + include_text_chunks=include_text_chunks, + include_image_chunks=include_image_chunks, + include_summary=include_summary, + ) + + retrieve_request.additional_properties = d + return retrieve_request + + @property + def additional_keys(self) -> List[str]: + return 
list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/retrieve_response.py b/dewy-client/dewy_client/models/retrieve_response.py new file mode 100644 index 0000000..514fdad --- /dev/null +++ b/dewy-client/dewy_client/models/retrieve_response.py @@ -0,0 +1,106 @@ +from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +if TYPE_CHECKING: + from ..models.image_result import ImageResult + from ..models.text_result import TextResult + + +T = TypeVar("T", bound="RetrieveResponse") + + +@_attrs_define +class RetrieveResponse: + """The response from a retrieval request. 
+ + Attributes: + summary (Union[None, str]): + text_results (List['TextResult']): + image_results (List['ImageResult']): + """ + + summary: Union[None, str] + text_results: List["TextResult"] + image_results: List["ImageResult"] + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + summary: Union[None, str] + summary = self.summary + + text_results = [] + for text_results_item_data in self.text_results: + text_results_item = text_results_item_data.to_dict() + text_results.append(text_results_item) + + image_results = [] + for image_results_item_data in self.image_results: + image_results_item = image_results_item_data.to_dict() + image_results.append(image_results_item) + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "summary": summary, + "text_results": text_results, + "image_results": image_results, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + from ..models.image_result import ImageResult + from ..models.text_result import TextResult + + d = src_dict.copy() + + def _parse_summary(data: object) -> Union[None, str]: + if data is None: + return data + return cast(Union[None, str], data) + + summary = _parse_summary(d.pop("summary")) + + text_results = [] + _text_results = d.pop("text_results") + for text_results_item_data in _text_results: + text_results_item = TextResult.from_dict(text_results_item_data) + + text_results.append(text_results_item) + + image_results = [] + _image_results = d.pop("image_results") + for image_results_item_data in _image_results: + image_results_item = ImageResult.from_dict(image_results_item_data) + + image_results.append(image_results_item) + + retrieve_response = cls( + summary=summary, + text_results=text_results, + image_results=image_results, + ) + + retrieve_response.additional_properties = d + return retrieve_response + + @property 
+ def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/text_chunk.py b/dewy-client/dewy_client/models/text_chunk.py new file mode 100644 index 0000000..b89bf0b --- /dev/null +++ b/dewy-client/dewy_client/models/text_chunk.py @@ -0,0 +1,133 @@ +from typing import Any, Dict, List, Literal, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="TextChunk") + + +@_attrs_define +class TextChunk: + """ + Attributes: + id (int): + document_id (int): + text (str): + raw (bool): + kind (Union[Literal['text'], Unset]): Default: 'text'. + start_char_idx (Union[None, Unset, int]): Start char index of the chunk. + end_char_idx (Union[None, Unset, int]): End char index of the chunk. 
+ """ + + id: int + document_id: int + text: str + raw: bool + kind: Union[Literal["text"], Unset] = "text" + start_char_idx: Union[None, Unset, int] = UNSET + end_char_idx: Union[None, Unset, int] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + id = self.id + + document_id = self.document_id + + text = self.text + + raw = self.raw + + kind = self.kind + + start_char_idx: Union[None, Unset, int] + if isinstance(self.start_char_idx, Unset): + start_char_idx = UNSET + else: + start_char_idx = self.start_char_idx + + end_char_idx: Union[None, Unset, int] + if isinstance(self.end_char_idx, Unset): + end_char_idx = UNSET + else: + end_char_idx = self.end_char_idx + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "id": id, + "document_id": document_id, + "text": text, + "raw": raw, + } + ) + if kind is not UNSET: + field_dict["kind"] = kind + if start_char_idx is not UNSET: + field_dict["start_char_idx"] = start_char_idx + if end_char_idx is not UNSET: + field_dict["end_char_idx"] = end_char_idx + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + id = d.pop("id") + + document_id = d.pop("document_id") + + text = d.pop("text") + + raw = d.pop("raw") + + kind = d.pop("kind", UNSET) + + def _parse_start_char_idx(data: object) -> Union[None, Unset, int]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, int], data) + + start_char_idx = _parse_start_char_idx(d.pop("start_char_idx", UNSET)) + + def _parse_end_char_idx(data: object) -> Union[None, Unset, int]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, int], data) + + end_char_idx = _parse_end_char_idx(d.pop("end_char_idx", UNSET)) + + text_chunk = cls( + id=id, + document_id=document_id, 
+ text=text, + raw=raw, + kind=kind, + start_char_idx=start_char_idx, + end_char_idx=end_char_idx, + ) + + text_chunk.additional_properties = d + return text_chunk + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/text_result.py b/dewy-client/dewy_client/models/text_result.py new file mode 100644 index 0000000..3cdaec7 --- /dev/null +++ b/dewy-client/dewy_client/models/text_result.py @@ -0,0 +1,132 @@ +from typing import Any, Dict, List, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +from ..types import UNSET, Unset + +T = TypeVar("T", bound="TextResult") + + +@_attrs_define +class TextResult: + """ + Attributes: + chunk_id (int): + document_id (int): + score (float): + text (str): + raw (bool): + start_char_idx (Union[None, Unset, int]): Start char index of the chunk. + end_char_idx (Union[None, Unset, int]): End char index of the chunk. 
+ """ + + chunk_id: int + document_id: int + score: float + text: str + raw: bool + start_char_idx: Union[None, Unset, int] = UNSET + end_char_idx: Union[None, Unset, int] = UNSET + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + chunk_id = self.chunk_id + + document_id = self.document_id + + score = self.score + + text = self.text + + raw = self.raw + + start_char_idx: Union[None, Unset, int] + if isinstance(self.start_char_idx, Unset): + start_char_idx = UNSET + else: + start_char_idx = self.start_char_idx + + end_char_idx: Union[None, Unset, int] + if isinstance(self.end_char_idx, Unset): + end_char_idx = UNSET + else: + end_char_idx = self.end_char_idx + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "chunk_id": chunk_id, + "document_id": document_id, + "score": score, + "text": text, + "raw": raw, + } + ) + if start_char_idx is not UNSET: + field_dict["start_char_idx"] = start_char_idx + if end_char_idx is not UNSET: + field_dict["end_char_idx"] = end_char_idx + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + chunk_id = d.pop("chunk_id") + + document_id = d.pop("document_id") + + score = d.pop("score") + + text = d.pop("text") + + raw = d.pop("raw") + + def _parse_start_char_idx(data: object) -> Union[None, Unset, int]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, int], data) + + start_char_idx = _parse_start_char_idx(d.pop("start_char_idx", UNSET)) + + def _parse_end_char_idx(data: object) -> Union[None, Unset, int]: + if data is None: + return data + if isinstance(data, Unset): + return data + return cast(Union[None, Unset, int], data) + + end_char_idx = _parse_end_char_idx(d.pop("end_char_idx", UNSET)) + + text_result = cls( + chunk_id=chunk_id, + document_id=document_id, + score=score, 
+ text=text, + raw=raw, + start_char_idx=start_char_idx, + end_char_idx=end_char_idx, + ) + + text_result.additional_properties = d + return text_result + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/models/validation_error.py b/dewy-client/dewy_client/models/validation_error.py new file mode 100644 index 0000000..5c6b181 --- /dev/null +++ b/dewy-client/dewy_client/models/validation_error.py @@ -0,0 +1,87 @@ +from typing import Any, Dict, List, Type, TypeVar, Union, cast + +from attrs import define as _attrs_define +from attrs import field as _attrs_field + +T = TypeVar("T", bound="ValidationError") + + +@_attrs_define +class ValidationError: + """ + Attributes: + loc (List[Union[int, str]]): + msg (str): + type (str): + """ + + loc: List[Union[int, str]] + msg: str + type: str + additional_properties: Dict[str, Any] = _attrs_field(init=False, factory=dict) + + def to_dict(self) -> Dict[str, Any]: + loc = [] + for loc_item_data in self.loc: + loc_item: Union[int, str] + loc_item = loc_item_data + loc.append(loc_item) + + msg = self.msg + + type = self.type + + field_dict: Dict[str, Any] = {} + field_dict.update(self.additional_properties) + field_dict.update( + { + "loc": loc, + "msg": msg, + "type": type, + } + ) + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + loc = [] + _loc = d.pop("loc") + for loc_item_data in _loc: + + def _parse_loc_item(data: object) -> Union[int, str]: + return cast(Union[int, str], data) + + loc_item = 
_parse_loc_item(loc_item_data) + + loc.append(loc_item) + + msg = d.pop("msg") + + type = d.pop("type") + + validation_error = cls( + loc=loc, + msg=msg, + type=type, + ) + + validation_error.additional_properties = d + return validation_error + + @property + def additional_keys(self) -> List[str]: + return list(self.additional_properties.keys()) + + def __getitem__(self, key: str) -> Any: + return self.additional_properties[key] + + def __setitem__(self, key: str, value: Any) -> None: + self.additional_properties[key] = value + + def __delitem__(self, key: str) -> None: + del self.additional_properties[key] + + def __contains__(self, key: str) -> bool: + return key in self.additional_properties diff --git a/dewy-client/dewy_client/py.typed b/dewy-client/dewy_client/py.typed new file mode 100644 index 0000000..1aad327 --- /dev/null +++ b/dewy-client/dewy_client/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561 \ No newline at end of file diff --git a/dewy-client/dewy_client/types.py b/dewy-client/dewy_client/types.py new file mode 100644 index 0000000..dbdcc5d --- /dev/null +++ b/dewy-client/dewy_client/types.py @@ -0,0 +1,44 @@ +""" Contains some shared types for properties """ +from http import HTTPStatus +from typing import BinaryIO, Generic, Literal, MutableMapping, Optional, Tuple, TypeVar + +from attrs import define + + +class Unset: + def __bool__(self) -> Literal[False]: + return False + + +UNSET: Unset = Unset() + +FileJsonType = Tuple[Optional[str], BinaryIO, Optional[str]] + + +@define +class File: + """Contains information for file uploads""" + + payload: BinaryIO + file_name: Optional[str] = None + mime_type: Optional[str] = None + + def to_tuple(self) -> FileJsonType: + """Return a tuple representation that httpx will accept for multipart/form-data""" + return self.file_name, self.payload, self.mime_type + + +T = TypeVar("T") + + +@define +class Response(Generic[T]): + """A response from an endpoint""" + + status_code: HTTPStatus + content: bytes + 
headers: MutableMapping[str, str] + parsed: Optional[T] + + +__all__ = ["File", "Response", "FileJsonType", "Unset", "UNSET"] diff --git a/dewy-client/pyproject.toml b/dewy-client/pyproject.toml new file mode 100644 index 0000000..0642ac6 --- /dev/null +++ b/dewy-client/pyproject.toml @@ -0,0 +1,25 @@ +[tool.poetry] +name = "dewy-client" +version = "0.1.0" +description = "A client library for accessing Dewy Knowledge Base API" +authors = [] +readme = "README.md" +packages = [ + {include = "dewy_client"}, +] +include = ["CHANGELOG.md", "dewy_client/py.typed"] + + +[tool.poetry.dependencies] +python = "^3.8" +httpx = ">=0.20.0,<0.27.0" +attrs = ">=21.3.0" +python-dateutil = "^2.8.0" + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.ruff] +select = ["F", "I", "UP"] +line-length = 120 diff --git a/dewy/config.py b/dewy/config.py index c56b022..d88415a 100644 --- a/dewy/config.py +++ b/dewy/config.py @@ -62,11 +62,11 @@ def custom_generate_unique_id_function(route: APIRoute) -> str: API_DESCRIPTION: str = """This API allows ingesting and retrieving knowledge. Knowledge comes in a variety of forms -- text, image, tables, etc. and -from a variety of sources -- documents, web pages, audio, etc. -""" +from a variety of sources -- documents, web pages, audio, etc.""" app_configs: dict[str, Any] = { "title": "Dewy Knowledge Base API", + "version": "0.1.0", "summary": "Knowledge curation for Retrieval Augmented Generation", "description": API_DESCRIPTION, "servers": [ diff --git a/openapi.yaml b/openapi.yaml new file mode 100644 index 0000000..c735646 --- /dev/null +++ b/openapi.yaml @@ -0,0 +1,641 @@ +openapi: 3.1.0 +info: + title: Dewy Knowledge Base API + summary: Knowledge curation for Retrieval Augmented Generation + description: 'This API allows ingesting and retrieving knowledge. + + + Knowledge comes in a variety of forms -- text, image, tables, etc. 
and + + from a variety of sources -- documents, web pages, audio, etc.' + version: 0.1.0 +servers: +- url: http://localhost:8000 + description: Local server +paths: + /api/collections/: + get: + summary: List Collections + description: List collections. + operationId: listCollections + responses: + '200': + description: Successful Response + content: + application/json: + schema: + items: + $ref: '#/components/schemas/Collection' + type: array + title: Response Listcollections + put: + summary: Add Collection + description: Create a collection. + operationId: addCollection + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CollectionCreate' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/Collection' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/collections/{id}: + get: + summary: Get Collection + description: Get a specific collection. + operationId: getCollection + parameters: + - name: id + in: path + required: true + schema: + type: integer + description: The collection ID. + title: Id + description: The collection ID. + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/Collection' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/documents/: + put: + summary: Add Document + description: Add a document. 
+ operationId: addDocument + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/AddDocumentRequest' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/Document' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + get: + summary: List Documents + description: List documents. + operationId: listDocuments + parameters: + - name: collection_id + in: query + required: false + schema: + anyOf: + - type: integer + - type: 'null' + description: Limit to documents associated with this collection + title: Collection Id + description: Limit to documents associated with this collection + responses: + '200': + description: Successful Response + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Document' + title: Response Listdocuments + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/documents/{id}: + get: + summary: Get Document + operationId: getDocument + parameters: + - name: id + in: path + required: true + schema: + type: integer + description: The document ID. + title: Id + description: The document ID. + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/Document' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/chunks/: + get: + summary: List Chunks + description: List chunks. 
+ operationId: listChunks + parameters: + - name: collection_id + in: query + required: false + schema: + anyOf: + - type: integer + - type: 'null' + description: Limit to chunks associated with this collection + title: Collection Id + description: Limit to chunks associated with this collection + - name: document_id + in: query + required: false + schema: + anyOf: + - type: integer + - type: 'null' + description: Limit to chunks associated with this document + title: Document Id + description: Limit to chunks associated with this document + - name: page + in: query + required: false + schema: + anyOf: + - type: integer + - type: 'null' + default: 1 + title: Page + - name: perPage + in: query + required: false + schema: + anyOf: + - type: integer + - type: 'null' + default: 10 + title: Perpage + responses: + '200': + description: Successful Response + content: + application/json: + schema: + type: array + items: + oneOf: + - $ref: '#/components/schemas/TextChunk' + - $ref: '#/components/schemas/ImageChunk' + discriminator: + propertyName: kind + mapping: + text: '#/components/schemas/TextChunk' + image: '#/components/schemas/ImageChunk' + title: Response Listchunks + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/chunks/{id}: + get: + summary: Get Chunk + operationId: getChunk + parameters: + - name: id + in: path + required: true + schema: + type: integer + description: The chunk ID. + title: Id + description: The chunk ID. 
+ responses: + '200': + description: Successful Response + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/TextChunk' + - $ref: '#/components/schemas/ImageChunk' + discriminator: + propertyName: kind + mapping: + text: '#/components/schemas/TextChunk' + image: '#/components/schemas/ImageChunk' + title: Response Getchunk + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + /api/chunks/retrieve: + post: + summary: Retrieve Chunks + description: Retrieve chunks based on a given query. + operationId: retrieveChunks + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RetrieveRequest' + required: true + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/RetrieveResponse' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' +components: + schemas: + AddDocumentRequest: + properties: + collection_id: + anyOf: + - type: integer + - type: 'null' + title: Collection Id + url: + type: string + title: Url + type: object + required: + - url + title: AddDocumentRequest + Collection: + properties: + id: + type: integer + title: Id + name: + type: string + title: Name + text_embedding_model: + type: string + title: Text Embedding Model + text_distance_metric: + allOf: + - $ref: '#/components/schemas/DistanceMetric' + default: cosine + type: object + required: + - id + - name + - text_embedding_model + title: Collection + CollectionCreate: + properties: + name: + type: string + title: Name + examples: + - my_collection + text_embedding_model: + type: string + title: Text Embedding Model + default: openai:text-embedding-ada-002 + examples: + - openai:text-embedding-ada-002 + - hf:BAAI/bge-small-en + text_distance_metric: + allOf: + - $ref: '#/components/schemas/DistanceMetric' + 
default: cosine + type: object + required: + - name + title: CollectionCreate + description: The request to create a collection. + DistanceMetric: + type: string + enum: + - cosine + - ip + - l2 + title: DistanceMetric + Document: + properties: + id: + anyOf: + - type: integer + - type: 'null' + title: Id + collection_id: + type: integer + title: Collection Id + extracted_text: + anyOf: + - type: string + - type: 'null' + title: Extracted Text + url: + type: string + title: Url + ingest_state: + anyOf: + - $ref: '#/components/schemas/IngestState' + - type: 'null' + ingest_error: + anyOf: + - type: string + - type: 'null' + title: Ingest Error + type: object + required: + - collection_id + - url + title: Document + description: Schema for documents in the SQL DB. + HTTPValidationError: + properties: + detail: + items: + $ref: '#/components/schemas/ValidationError' + type: array + title: Detail + type: object + title: HTTPValidationError + ImageChunk: + properties: + id: + type: integer + title: Id + document_id: + type: integer + title: Document Id + kind: + const: image + title: Kind + default: image + image: + anyOf: + - type: string + - type: 'null' + title: Image + description: Image of the node. + image_mimetype: + anyOf: + - type: string + - type: 'null' + title: Image Mimetype + description: Mimetype of the image. + image_path: + anyOf: + - type: string + - type: 'null' + title: Image Path + description: Path of the image. + image_url: + anyOf: + - type: string + - type: 'null' + title: Image Url + description: URL of the image. + type: object + required: + - id + - document_id + - image + - image_mimetype + - image_path + - image_url + title: ImageChunk + ImageResult: + properties: + chunk_id: + type: integer + title: Chunk Id + document_id: + type: integer + title: Document Id + score: + type: number + title: Score + image: + anyOf: + - type: string + - type: 'null' + title: Image + description: Image of the node. 
+ image_mimetype: + anyOf: + - type: string + - type: 'null' + title: Image Mimetype + description: Mimetype of the image. + image_path: + anyOf: + - type: string + - type: 'null' + title: Image Path + description: Path of the image. + image_url: + anyOf: + - type: string + - type: 'null' + title: Image Url + description: URL of the image. + type: object + required: + - chunk_id + - document_id + - score + - image + - image_mimetype + - image_path + - image_url + title: ImageResult + IngestState: + type: string + enum: + - pending + - ingested + - failed + title: IngestState + RetrieveRequest: + properties: + collection_id: + type: integer + title: Collection Id + query: + type: string + title: Query + n: + type: integer + title: N + default: 10 + include_text_chunks: + type: boolean + title: Include Text Chunks + default: true + include_image_chunks: + type: boolean + title: Include Image Chunks + default: true + include_summary: + type: boolean + title: Include Summary + default: false + type: object + required: + - collection_id + - query + title: RetrieveRequest + description: A request for retrieving chunks from a collection. + RetrieveResponse: + properties: + summary: + anyOf: + - type: string + - type: 'null' + title: Summary + text_results: + items: + $ref: '#/components/schemas/TextResult' + type: array + title: Text Results + image_results: + items: + $ref: '#/components/schemas/ImageResult' + type: array + title: Image Results + type: object + required: + - summary + - text_results + - image_results + title: RetrieveResponse + description: The response from a retrieval request. + TextChunk: + properties: + id: + type: integer + title: Id + document_id: + type: integer + title: Document Id + kind: + const: text + title: Kind + default: text + text: + type: string + title: Text + raw: + type: boolean + title: Raw + start_char_idx: + anyOf: + - type: integer + - type: 'null' + title: Start Char Idx + description: Start char index of the chunk. 
+ end_char_idx: + anyOf: + - type: integer + - type: 'null' + title: End Char Idx + description: End char index of the chunk. + type: object + required: + - id + - document_id + - text + - raw + title: TextChunk + TextResult: + properties: + chunk_id: + type: integer + title: Chunk Id + document_id: + type: integer + title: Document Id + score: + type: number + title: Score + text: + type: string + title: Text + raw: + type: boolean + title: Raw + start_char_idx: + anyOf: + - type: integer + - type: 'null' + title: Start Char Idx + description: Start char index of the chunk. + end_char_idx: + anyOf: + - type: integer + - type: 'null' + title: End Char Idx + description: End char index of the chunk. + type: object + required: + - chunk_id + - document_id + - score + - text + - raw + title: TextResult + ValidationError: + properties: + loc: + items: + anyOf: + - type: string + - type: integer + type: array + title: Location + msg: + type: string + title: Message + type: + type: string + title: Error Type + type: object + required: + - loc + - msg + - type + title: ValidationError diff --git a/openapi_client_config.yaml b/openapi_client_config.yaml new file mode 100644 index 0000000..daf4794 --- /dev/null +++ b/openapi_client_config.yaml @@ -0,0 +1,2 @@ +project_name_override: dewy-client +package_name_override: dewy_client \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 91574a8..1409350 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1728,6 +1728,30 @@ typing-extensions = ">=4.7,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +[[package]] +name = "openapi-python-client" +version = "0.17.2" +description = "Generate modern Python clients from OpenAPI" +optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "openapi_python_client-0.17.2-py3-none-any.whl", hash = "sha256:40577b053fdeb1c47b559bcfacde7645dba3fc4e117ca875b405e94ab841792a"}, + {file = "openapi_python_client-0.17.2.tar.gz", 
hash = "sha256:4d897a4acb921d22a3d300cb037fb665f4a8be642290a38b57aa977593c008bb"}, +] + +[package.dependencies] +attrs = ">=21.3.0" +colorama = {version = ">=0.4.3", markers = "sys_platform == \"win32\""} +httpx = ">=0.20.0,<0.27.0" +jinja2 = ">=3.0.0,<4.0.0" +pydantic = ">=2.1.1,<3.0.0" +python-dateutil = ">=2.8.1,<3.0.0" +pyyaml = ">=6.0,<7.0" +ruff = ">=0.1.2,<1.0.0" +shellingham = ">=1.3.2,<2.0.0" +typer = ">0.6,<0.10" +typing-extensions = ">=4.8.0,<5.0.0" + [[package]] name = "packaging" version = "23.2" @@ -2835,6 +2859,17 @@ tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] testing = ["h5py (>=3.7.0)", "huggingface_hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools_rust (>=1.5.2)"] torch = ["safetensors[numpy]", "torch (>=1.10)"] +[[package]] +name = "shellingham" +version = "1.5.4" +description = "Tool to Detect Surrounding Shell" +optional = false +python-versions = ">=3.7" +files = [ + {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, + {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, +] + [[package]] name = "six" version = "1.16.0" @@ -3410,6 +3445,27 @@ build = ["cmake (>=3.18)", "lit"] tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)"] tutorials = ["matplotlib", "pandas", "tabulate"] +[[package]] +name = "typer" +version = "0.9.0" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." 
+optional = false +python-versions = ">=3.6" +files = [ + {file = "typer-0.9.0-py3-none-any.whl", hash = "sha256:5d96d986a21493606a358cae4461bd8cdf83cbf33a5aa950ae629ca3b51467ee"}, + {file = "typer-0.9.0.tar.gz", hash = "sha256:50922fd79aea2f4751a8e0408ff10d2662bd0c8bbfa84755a699f3bada2978b2"}, +] + +[package.dependencies] +click = ">=7.1.1,<9.0.0" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + [[package]] name = "typing-extensions" version = "4.9.0" @@ -3691,4 +3747,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "d093075ba2c0fada13b8045f39f45e504eabbc653b2f337873dfaac787c52e68" +content-hash = "8088f8f17854f8b921f924d82a24a68ddcc4614bd128fa0c167ac63b9bcacf57" diff --git a/pyproject.toml b/pyproject.toml index 5bfa50b..bc3f2d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,14 @@ cmd = "pytest" help = "Update openapi.toml from the swagger docs" cmd = "python scripts/extract_openapi.py dewy.main:app" +[tool.poe.tasks.generate-client] +help = "Generate the openapi client" +cmd = "openapi-python-client generate --path=openapi.yaml --config=openapi_client_config.yaml" + +[tool.poe.tasks.update-client] +help = "Update the openapi client" +cmd = "openapi-python-client update --path=openapi.yaml --config=openapi_client_config.yaml" + [tool.poetry.dependencies] 
python = "^3.11" pydantic = "^2.5.3" @@ -67,6 +75,9 @@ pytest = "^7.4.4" pytest-asyncio = "^0.21.1" pytest-docker-fixtures = {extras = ["pg"], version = "^1.3.18"} asgi-lifespan = "^2.1.0" +openapi-python-client = "^0.17.2" +poethepoet = "^0.24.4" +dewy-client = { path = "./dewy-client" } [build-system] requires = ["poetry-core"] diff --git a/tests/conftest.py b/tests/conftest.py index 1e7c4ba..2b14b01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import pytest from asgi_lifespan import LifespanManager +from dewy_client import Client from httpx import AsyncClient pytest_plugins = ["pytest_docker_fixtures"] @@ -37,8 +38,10 @@ async def app(pg, event_loop): @pytest.fixture(scope="session") -async def client(app) -> AsyncClient: - async with AsyncClient(app=app, base_url="http://test") as client: +async def client(app) -> Client: + async with AsyncClient(app=app, base_url="http://test") as httpx_client: + client = Client(base_url="http://test") + client.set_async_httpx_client(httpx_client) yield client diff --git a/tests/test_collection.py b/tests/test_collection.py index 7debc09..38fa845 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -1,31 +1,30 @@ import random import string +from dewy_client.api.default import add_collection, get_collection, list_collections +from dewy_client.models import CollectionCreate + async def test_create_collection(client): name = "".join(random.choices(string.ascii_lowercase, k=5)) - create_response = await client.put("/api/collections/", json={"name": name}) - assert create_response.status_code == 200 + collection = await add_collection.asyncio( + client=client, body=CollectionCreate(name=name) + ) - json = create_response.json() - assert json["name"] == name - assert json["text_embedding_model"] == "openai:text-embedding-ada-002" - assert json["text_distance_metric"] == "cosine" + assert collection.name == name + assert collection.text_embedding_model == "openai:text-embedding-ada-002" + 
assert collection.text_distance_metric == "cosine" - collection_id = json["id"] + collection_id = collection.id - list_response = await client.get("/api/collections/") - assert list_response.status_code == 200 + list_response = await list_collections.asyncio(client=client) # "find" the collection with the new collection ID, since # other tests may have created other collections - json = list_response.json() - collection_row = next(x for x in list_response.json() if x["id"] == collection_id) + collection_row = next(x for x in list_response if x.id == collection_id) assert collection_row is not None - assert collection_row["name"] == name + assert collection_row.name == name - get_response = await client.get(f"/api/collections/{collection_id}") - assert get_response.status_code == 200 + get_response = await get_collection.asyncio(collection_id, client=client) - json = get_response.json() - assert collection_row["name"] == name + assert get_response.name == name diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 1743dde..a0d875d 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -1,108 +1,62 @@ import random import string import time -from typing import List -from pydantic import TypeAdapter - -from dewy.chunk.models import Chunk, RetrieveRequest, RetrieveResponse -from dewy.document.models import AddDocumentRequest, Document +import pytest +from dewy_client.api.default import ( + add_collection, + add_document, + get_document, + list_chunks, + retrieve_chunks, +) +from dewy_client.models import ( + AddDocumentRequest, + CollectionCreate, + IngestState, + RetrieveRequest, +) SKELETON_OF_THOUGHT_PDF = "https://arxiv.org/pdf/2307.15337.pdf" -async def create_collection(client, text_embedding_model: str) -> int: +@pytest.mark.parametrize( + "embedding_model", ["openai:text-embedding-ada-002", "hf:BAAI/bge-small-en"] +) +async def test_index_retrieval(client, embedding_model): name = "".join(random.choices(string.ascii_lowercase, k=5)) - create_response = 
await client.put("/api/collections/", json={"name": name}) - assert create_response.status_code == 200 - - return create_response.json()["id"] - -async def ingest(client, collection: int, url: str) -> int: - add_request = AddDocumentRequest(collection_id=collection, url=url) - add_response = await client.put( - "/api/documents/", content=add_request.model_dump_json() + collection = await add_collection.asyncio( + client=client, + body=CollectionCreate(name=name, text_embedding_model=embedding_model), ) - assert add_response.status_code == 200 - - document_id = add_response.json()["id"] - - # TODO(https://github.com/DewyKB/dewy/issues/34): Move waiting to the server - # and eliminate need to poll. - status = await client.get(f"/api/documents/{document_id}") - while status.json()["ingest_state"] != "ingested": - time.sleep(1) - status = await client.get(f"/api/documents/{document_id}") - - return document_id - -async def list_chunks(client, collection: int, document: int): - response = await client.get( - "/api/chunks/", params={"collection_id": collection, "document_id": document} + document = await add_document.asyncio( + client=client, + body=AddDocumentRequest( + url=SKELETON_OF_THOUGHT_PDF, collection_id=collection.id + ), ) - assert response.status_code == 200 - ta = TypeAdapter(List[Chunk]) - return ta.validate_json(response.content) - -async def get_document(client, document_id: int) -> Document: - response = await client.get(f"/api/documents/{document_id}") - assert response.status_code == 200 - assert response - return Document.model_validate_json(response.content) - - -async def retrieve(client, collection: int, query: str) -> RetrieveResponse: - request = RetrieveRequest( - collection_id=collection, query=query, include_image_chunks=False - ) - - response = await client.post( - "/api/chunks/retrieve", content=request.model_dump_json() - ) - assert response.status_code == 200 - return RetrieveResponse.model_validate_json(response.content) - - -async def 
test_e2e_openai_ada002(client): - collection = await create_collection(client, "openai:text-embedding-ada-002") - document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF) - - document = await get_document(client, document_id) + while document.ingest_state != IngestState.INGESTED: + time.sleep(0.2) + document = await get_document.asyncio(document.id, client=client) assert document.extracted_text.startswith("Skeleton-of-Thought") - chunks = await list_chunks(client, collection, document_id) - assert len(chunks) > 0 - assert chunks[0].document_id == document_id - - results = await retrieve( - client, collection, "outline the steps to using skeleton-of-thought prompting" + chunks = await list_chunks.asyncio( + client=client, collection_id=collection.id, document_id=document.id ) - assert len(results.text_results) > 0 - print(results.text_results) - - assert results.text_results[0].document_id == document_id - assert "skeleton" in results.text_results[0].text.lower() - - -async def test_e2e_hf_bge_small(client): - collection = await create_collection(client, "hf:BAAI/bge-small-en") - document_id = await ingest(client, collection, SKELETON_OF_THOUGHT_PDF) - - document = await get_document(client, document_id) - assert document.extracted_text.startswith("Skeleton-of-Thought") - - chunks = await list_chunks(client, collection, document_id) assert len(chunks) > 0 - assert chunks[0].document_id == document_id - - results = await retrieve( - client, collection, "outline the steps to using skeleton-of-thought prompting" + assert chunks[0].document_id == document.id + + retrieved = await retrieve_chunks.asyncio( + client=client, + body=RetrieveRequest( + collection_id=collection.id, + query="outline the steps to using skeleton-of-thought prompting", + ), ) - assert len(results.text_results) > 0 - print(results.text_results) + assert len(retrieved.text_results) > 0 - assert results.text_results[0].document_id == document_id - assert "skeleton" in 
results.text_results[0].text.lower() + assert retrieved.text_results[0].document_id == document.id + assert "skeleton" in retrieved.text_results[0].text.lower()