diff --git a/astropylibrarian/algolia/client.py b/astropylibrarian/algolia/client.py index d5b0791..477b656 100644 --- a/astropylibrarian/algolia/client.py +++ b/astropylibrarian/algolia/client.py @@ -6,22 +6,16 @@ import logging import uuid from copy import deepcopy -from typing import ( - TYPE_CHECKING, - Any, - AsyncIterator, - Iterator, - Type, -) +from types import TracebackType +from typing import Any, AsyncIterator, Iterator, Type, Union -from algoliasearch.search_client import SearchClient +from algoliasearch.search.client import SearchClient +from algoliasearch.search.models.batch_response import BatchResponse +from algoliasearch.search.models.browse_params_object import BrowseParamsObject +from algoliasearch.search.models.browse_response import BrowseResponse +from algoliasearch.search.models.deleted_at_response import DeletedAtResponse -if TYPE_CHECKING: - from types import TracebackType - - from algoliasearch.search_index_async import SearchIndexAsync - -AlgoliaIndexType = SearchIndexAsync | "MockAlgoliaIndex" +AlgoliaIndexType = Union["AlgoliaIndex", "MockAlgoliaIndex"] """Type annotation alias supporting the return types of the `AlgoliaIndex` and `MockAlgoliaIndex` context managers. """ @@ -73,12 +67,10 @@ class AlgoliaIndex(BaseAlgoliaIndex): Name of the Algolia index. """ - async def __aenter__(self) -> SearchIndexAsync: + async def __aenter__(self) -> SearchClient: self._logger.debug("Opening algolia client") - self.algolia_client = SearchClient.create(self.app_id, self._key) - self._logger.debug("Initializing algolia index") - self.index = self.algolia_client.init_index(self.name) - return self.index + self.algolia_client = SearchClient(self.app_id, self._key) + return self.algolia_client async def __aexit__( self, @@ -87,9 +79,24 @@ async def __aexit__( tb: TracebackType | None, ) -> None: self._logger.debug("Closing algolia client") - await self.algolia_client.close_async() + await self.algolia_client.close() self._logger.debug("Finished closing algolia client") + async def browse_objects_async( + self, browse_params: BrowseParamsObject + ) -> BrowseResponse: + return await self.algolia_client.browse_objects( + index_name=self.name, aggregator=None, browse_params=browse_params + ) + + async def save_objects_async( + self, objects: list[dict[str, Any]] + ) -> list[BatchResponse]: + return self.algolia_client.save_objects(self.name, objects) + + async def delete_objects_async(self, objectids: list[str]) -> list[BatchResponse]: + return self.algolia_client.delete_objects(self.name, objectids) + class MockAlgoliaIndex(BaseAlgoliaIndex): """A mock Algolia index client. @@ -141,8 +148,10 @@ async def browse_objects_async( for _ in range(5): yield {} - async def delete_objects_async(self, objectids: list[str]) -> list[str]: - return objectids + async def delete_objects_async( + self, objectids: list[str] + ) -> list[DeletedAtResponse]: + return [DeletedAtResponse(task_id=0, deleted_at="") for _ in objectids] class MockMultiResponse: diff --git a/astropylibrarian/workflows/deleterooturl.py b/astropylibrarian/workflows/deleterooturl.py index 7db3f67..3df7e1d 100644 --- a/astropylibrarian/workflows/deleterooturl.py +++ b/astropylibrarian/workflows/deleterooturl.py @@ -1,28 +1,23 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst """Workflow for deleting all Algolia records associated with a root URL.""" -from __future__ import annotations - import logging -from typing import TYPE_CHECKING - -from astropylibrarian.algolia.client import escape_facet_value +from typing import Any, AsyncIterator -if TYPE_CHECKING: - from typing import Any, AsyncIterator, Dict, List +from algoliasearch.search.models.browse_params_object import BrowseParamsObject - from astropylibrarian.algolia.client import AlgoliaIndexType +from astropylibrarian.algolia.client import AlgoliaIndexType, escape_facet_value logger = logging.getLogger(__name__) async def delete_root_url( *, root_url: str, algolia_index: AlgoliaIndexType -) -> List[str]: +) -> list[str]: """Delete all Algolia records associated with a ``root_url``.""" - object_ids: List[str] = [] + object_ids: list[str] = [] async for record in search_for_records( - index=algolia_index, root_url=root_url + algolia_index=algolia_index, root_url=root_url ): if record["root_url"] != root_url: logger.warning( @@ -35,8 +30,8 @@ async def delete_root_url( logger.debug("Found %d objects for deletion", len(object_ids)) - response = await algolia_index.delete_objects_async(object_ids) - logger.debug("Algolia response:\n%s", response.raw_responses) + responses = await algolia_index.delete_objects_async(object_ids) + logger.debug("Algolia response:\n%s", responses) logger.info("Deleted %d objects", len(object_ids)) @@ -44,16 +39,13 @@ async def delete_root_url( async def search_for_records( - *, index: AlgoliaIndexType, root_url: str -) -> AsyncIterator[Dict[str, Any]]: + *, algolia_index: AlgoliaIndexType, root_url: str +) -> AsyncIterator[dict[str, Any]]: filters = f"root_url:{escape_facet_value(root_url)}" logger.debug("Filter:\n%s", filters) - async for result in index.browse_objects_async( - { - "filters": filters, - "attributesToRetrieve": ["root_url"], - "attributesToHighlight": [], - } - ): + obj = BrowseParamsObject( + filters=filters, attributes_to_retrieve=["root_url"], attributes_to_highlight=[] + ) + async for result in algolia_index.browse_objects_async(obj): yield result diff --git a/astropylibrarian/workflows/expirerecords.py b/astropylibrarian/workflows/expirerecords.py index 19877c9..a264a7e 100644 --- a/astropylibrarian/workflows/expirerecords.py +++ b/astropylibrarian/workflows/expirerecords.py @@ -6,6 +6,8 @@ import logging from typing import TYPE_CHECKING +from algoliasearch.search.models.browse_params_object import BrowseParamsObject + from astropylibrarian.algolia.client import escape_facet_value if TYPE_CHECKING: @@ -27,21 +29,20 @@ async def expire_old_records( " AND NOT " f"root_url:{escape_facet_value(root_url)}" ) - search_settings = { - "filters": filters, - "attributesToRetrieve": ["root_url", "index_epoch"], - "attributesToHighlight": [], - } + + obj = BrowseParamsObject( + filters=filters, + attributes_to_retrieve=["root_url", "index_epoch"], + attributes_to_highlight=[], + ) old_object_ids: List[str] = [] - async for r in algolia_index.browse_objects_async(search_settings): + async for r in algolia_index.browse_objects_async(obj): # Double check that we're deleting the right things. if r["root_url"] != root_url: logger.warning("root_url does not match: %s", r["baseUrl"]) continue if r["surrogateKey"] == index_epoch: - logger.warning( - "index_epoch matches current epoch: %s", r["index_epoch"] - ) + logger.warning("index_epoch matches current epoch: %s", r["index_epoch"]) continue old_object_ids.append(r["objectID"]) diff --git a/astropylibrarian/workflows/indextutorial.py b/astropylibrarian/workflows/indextutorial.py index c8a066b..977d11f 100644 --- a/astropylibrarian/workflows/indextutorial.py +++ b/astropylibrarian/workflows/indextutorial.py @@ -13,7 +13,7 @@ from pathlib import Path from typing import TYPE_CHECKING, List -import algoliasearch.exceptions +from algoliasearch.http.exceptions import RequestException from astropylibrarian.algolia.client import generate_index_epoch from astropylibrarian.reducers.tutorial import get_tutorial_reducer @@ -171,7 +171,7 @@ async def index_tutorial( saved_object_ids: List[str] = [] try: response = await algolia_index.save_objects_async(records) - except algoliasearch.exceptions.RequestException as e: + except RequestException as e: logger.error( "Error saving objects for tutorial %s:\n%s", tutorial_html.url, diff --git a/pyproject.toml b/pyproject.toml index e9c173a..8efd5af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ classifiers = [ dependencies = [ "lxml", "cssselect", - "algoliasearch", + "algoliasearch>=4,<5", "aiohttp", "async_timeout", "PyYAML",