diff --git a/api/src/adapters/search/opensearch_client.py b/api/src/adapters/search/opensearch_client.py
index cb97a9c8c..24d3bdb1d 100644
--- a/api/src/adapters/search/opensearch_client.py
+++ b/api/src/adapters/search/opensearch_client.py
@@ -1,5 +1,5 @@
 import logging
-from typing import Any, Sequence
+from typing import Any, Generator, Iterable
 
 import opensearchpy
 
@@ -75,7 +75,7 @@ def delete_index(self, index_name: str) -> None:
     def bulk_upsert(
         self,
         index_name: str,
-        records: Sequence[dict[str, Any]],
+        records: Iterable[dict[str, Any]],
         primary_key_field: str,
         *,
         refresh: bool = True
@@ -103,10 +103,51 @@ def bulk_upsert(
         logger.info(
             "Upserting records to %s",
             index_name,
-            extra={"index_name": index_name, "record_count": int(len(bulk_operations) / 2)},
+            extra={
+                "index_name": index_name,
+                "record_count": int(len(bulk_operations) / 2),
+                "operation": "update",
+            },
         )
         self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)
 
+    def bulk_delete(self, index_name: str, ids: Iterable[Any], *, refresh: bool = True) -> None:
+        """
+        Bulk delete records from an index
+
+        See: https://opensearch.org/docs/latest/api-reference/document-apis/bulk/ for details.
+        In this method, we delete records based on the IDs passed in.
+        """
+        bulk_operations = []
+
+        for _id in ids:
+            # { "delete": { "_id": "tt2229499" } }
+            bulk_operations.append({"delete": {"_id": _id}})
+
+        logger.info(
+            "Deleting records from %s",
+            index_name,
+            extra={
+                "index_name": index_name,
+                "record_count": len(bulk_operations),
+                "operation": "delete",
+            },
+        )
+        self._client.bulk(index=index_name, body=bulk_operations, refresh=refresh)
+
+    def index_exists(self, index_name: str) -> bool:
+        """
+        Check if an index OR alias exists by a given name
+        """
+        return self._client.indices.exists(index_name)
+
+    def alias_exists(self, alias_name: str) -> bool:
+        """
+        Check if an alias exists
+        """
+        existing_index_mapping = self._client.cat.aliases(alias_name, format="json")
+        return len(existing_index_mapping) > 0
+
     def swap_alias_index(
         self, index_name: str, alias_name: str, *, delete_prior_indexes: bool = False
     ) -> None:
@@ -144,11 +185,71 @@ def search_raw(self, index_name: str, search_query: dict) -> dict:
         return self._client.search(index=index_name, body=search_query)
 
     def search(
-        self, index_name: str, search_query: dict, include_scores: bool = True
+        self,
+        index_name: str,
+        search_query: dict,
+        include_scores: bool = True,
+        params: dict | None = None,
     ) -> SearchResponse:
-        response = self._client.search(index=index_name, body=search_query)
+        if params is None:
+            params = {}
+
+        response = self._client.search(index=index_name, body=search_query, params=params)
         return SearchResponse.from_opensearch_response(response, include_scores)
 
+    def scroll(
+        self,
+        index_name: str,
+        search_query: dict,
+        include_scores: bool = True,
+        duration: str = "10m",
+    ) -> Generator[SearchResponse, None, None]:
+        """
+        Scroll (iterate) over a large result set for a given search query.
+
+        Scrolling uses additional resources to keep the search context open, but
+        it returns a consistent set of results and is useful for backend processes
+        that need to fetch a large amount of search data. After all results have
+        been processed, the scroll context is closed for you.
+
+        This method is set up as a generator and the results can be iterated over::
+
+            for response in search_client.scroll("my_index", {"size": 10000}):
+                for record in response.records:
+                    process_record(record)
+
+
+        See: https://opensearch.org/docs/latest/api-reference/scroll/
+        """
+
+        # start scroll
+        response = self.search(
+            index_name=index_name,
+            search_query=search_query,
+            include_scores=include_scores,
+            params={"scroll": duration},
+        )
+        scroll_id = response.scroll_id
+
+        yield response
+
+        # iterate
+        while True:
+            raw_response = self._client.scroll({"scroll_id": scroll_id, "scroll": duration})
+            response = SearchResponse.from_opensearch_response(raw_response, include_scores)
+
+            # The scroll ID can change between queries according to the docs, so we
+            # keep updating the value while iterating in case it changes.
+            scroll_id = response.scroll_id
+
+            if len(response.records) == 0:
+                break
+
+            yield response
+
+        # close scroll
+        self._client.clear_scroll(scroll_id=scroll_id)
+
 
 def _get_connection_parameters(opensearch_config: OpensearchConfig) -> dict[str, Any]:
     # TODO - we'll want to add the AWS connection params here when we set that up
diff --git a/api/src/adapters/search/opensearch_response.py b/api/src/adapters/search/opensearch_response.py
index c8bb16cb6..a54c6ecc7 100644
--- a/api/src/adapters/search/opensearch_response.py
+++ b/api/src/adapters/search/opensearch_response.py
@@ -10,6 +10,8 @@ class SearchResponse:
 
     aggregations: dict[str, dict[str, int]]
 
+    scroll_id: str | None
+
     @classmethod
     def from_opensearch_response(
         cls, raw_json: dict[str, typing.Any], include_scores: bool = True
@@ -40,6 +42,8 @@ def from_opensearch_response(
             ]
         }
         """
+        scroll_id = raw_json.get("_scroll_id", None)
+
         hits = raw_json.get("hits", {})
         hits_total = hits.get("total", {})
         total_records = hits_total.get("value", 0)
@@ -59,7 +63,7 @@ def from_opensearch_response(
         raw_aggs: dict[str, dict[str, typing.Any]] = raw_json.get("aggregations", {})
         aggregations = _parse_aggregations(raw_aggs)
 
-        return cls(total_records, records, aggregations)
+        return cls(total_records, records, aggregations, scroll_id)
 
 
 def _parse_aggregations(
diff --git a/api/src/search/backend/load_opportunities_to_index.py b/api/src/search/backend/load_opportunities_to_index.py
index 630ecf616..dcf778037 100644
--- a/api/src/search/backend/load_opportunities_to_index.py
+++ b/api/src/search/backend/load_opportunities_to_index.py
@@ -38,21 +38,52 @@ def __init__(
         self,
         db_session: db.Session,
         search_client: search.SearchClient,
+        is_full_refresh: bool = True,
         config: LoadOpportunitiesToIndexConfig | None = None,
     ) -> None:
         super().__init__(db_session)
 
         self.search_client = search_client
+        self.is_full_refresh = is_full_refresh
 
         if config is None:
             config = LoadOpportunitiesToIndexConfig()
         self.config = config
 
-        current_timestamp = get_now_us_eastern_datetime().strftime("%Y-%m-%d_%H-%M-%S")
-        self.index_name = f"{self.config.index_prefix}-{current_timestamp}"
+        if is_full_refresh:
+            current_timestamp = get_now_us_eastern_datetime().strftime("%Y-%m-%d_%H-%M-%S")
+            self.index_name = f"{self.config.index_prefix}-{current_timestamp}"
+        else:
+            self.index_name = self.config.alias_name
+
         self.set_metrics({"index_name": self.index_name})
 
     def run_task(self) -> None:
+        if self.is_full_refresh:
+            logger.info("Running full refresh")
+            self.full_refresh()
+        else:
+            logger.info("Running incremental load")
+            self.incremental_updates_and_deletes()
+
+    def incremental_updates_and_deletes(self) -> None:
+        existing_opportunity_ids = self.fetch_existing_opportunity_ids_in_index()
+
+        # load the records incrementally
+        # TODO - The point of this incremental load is to support upcoming work
+        # to load only opportunities that have changed, as we'll eventually be indexing
+        # files, which will take longer. However - the structure of the data isn't yet
+        # known, so I want to hold off on actually setting up any change-detection logic
+        loaded_opportunity_ids = set()
+        for opp_batch in self.fetch_opportunities():
+            loaded_opportunity_ids.update(self.load_records(opp_batch))
+
+        # Delete anything that was in the index but no longer exists in the DB
+        opportunity_ids_to_delete = existing_opportunity_ids - loaded_opportunity_ids
+
+        if len(opportunity_ids_to_delete) > 0:
+            self.search_client.bulk_delete(self.index_name, opportunity_ids_to_delete)
+
+    def full_refresh(self) -> None:
         # create the index
         self.search_client.create_index(
             self.index_name,
@@ -93,11 +124,32 @@ def fetch_opportunities(self) -> Iterator[Sequence[Opportunity]]:
             .partitions()
         )
 
-    def load_records(self, records: Sequence[Opportunity]) -> None:
+    def fetch_existing_opportunity_ids_in_index(self) -> set[int]:
+        if not self.search_client.alias_exists(self.index_name):
+            raise RuntimeError(
+                "Alias %s does not exist, please run the full refresh job before the incremental job"
+                % self.index_name
+            )
+
+        opportunity_ids: set[int] = set()
+
+        for response in self.search_client.scroll(
+            self.config.alias_name,
+            {"size": 10000, "_source": ["opportunity_id"]},
+            include_scores=False,
+        ):
+            for record in response.records:
+                opportunity_ids.add(record["opportunity_id"])
+
+        return opportunity_ids
+
+    def load_records(self, records: Sequence[Opportunity]) -> set[int]:
         logger.info("Loading batch of opportunities...")
         schema = OpportunityV1Schema()
         json_records = []
 
+        loaded_opportunity_ids = set()
+
         for record in records:
             logger.info(
                 "Preparing opportunity for upload to search index",
@@ -109,4 +161,8 @@ def load_records(self, records: Sequence[Opportunity]) -> None:
             json_records.append(schema.dump(record))
             self.increment(self.Metrics.RECORDS_LOADED)
 
+            loaded_opportunity_ids.add(record.opportunity_id)
+
         self.search_client.bulk_upsert(self.index_name, json_records, "opportunity_id")
+
+        return loaded_opportunity_ids
diff --git a/api/src/search/backend/load_search_data.py b/api/src/search/backend/load_search_data.py
index cf6f0445f..5b82e5a6d 100644
--- a/api/src/search/backend/load_search_data.py
+++ b/api/src/search/backend/load_search_data.py
@@ -1,3 +1,5 @@
+import click
+
 import src.adapters.db as db
 import src.adapters.search as search
 from src.adapters.db import flask_db
@@ -8,8 +10,13 @@
 @load_search_data_blueprint.cli.command(
     "load-opportunity-data", help="Load opportunity data from our database to the search index"
 )
+@click.option(
+    "--full-refresh/--incremental",
+    default=True,
+    help="Whether to run a full refresh, or only incrementally update opportunities",
+)
 @flask_db.with_db_session()
-def load_opportunity_data(db_session: db.Session) -> None:
+def load_opportunity_data(db_session: db.Session, full_refresh: bool) -> None:
     search_client = search.SearchClient()
 
-    LoadOpportunitiesToIndex(db_session, search_client).run()
+    LoadOpportunitiesToIndex(db_session, search_client, full_refresh).run()
diff --git a/api/tests/src/adapters/search/test_opensearch_client.py b/api/tests/src/adapters/search/test_opensearch_client.py
index 916c6effd..8de2c2cc9 100644
--- a/api/tests/src/adapters/search/test_opensearch_client.py
+++ b/api/tests/src/adapters/search/test_opensearch_client.py
@@ -65,6 +65,25 @@ def
test_bulk_upsert(search_client, generic_index): assert search_client._client.get(generic_index, record["id"])["_source"] == record +def test_bulk_delete(search_client, generic_index): + records = [ + {"id": 1, "title": "Green Eggs & Ham", "notes": "why are the eggs green?"}, + {"id": 2, "title": "The Cat in the Hat", "notes": "silly cat wears a hat"}, + {"id": 3, "title": "One Fish, Two Fish, Red Fish, Blue Fish", "notes": "fish"}, + ] + + search_client.bulk_upsert(generic_index, records, primary_key_field="id") + + search_client.bulk_delete(generic_index, [1]) + + resp = search_client.search(generic_index, {}, include_scores=False) + assert resp.records == records[1:] + + search_client.bulk_delete(generic_index, [2, 3]) + resp = search_client.search(generic_index, {}, include_scores=False) + assert resp.records == [] + + def test_swap_alias_index(search_client, generic_index): alias_name = f"tmp-alias-{uuid.uuid4().int}" @@ -101,3 +120,76 @@ def test_swap_alias_index(search_client, generic_index): # Verify the tmp one was deleted assert search_client._client.indices.exists(tmp_index) is False + + +def test_index_or_alias_exists(search_client, generic_index): + # Create a few aliased indexes + index_a = f"test-index-a-{uuid.uuid4().int}" + index_b = f"test-index-b-{uuid.uuid4().int}" + index_c = f"test-index-c-{uuid.uuid4().int}" + + search_client.create_index(index_a) + search_client.create_index(index_b) + search_client.create_index(index_c) + + alias_index_a = f"test-alias-a-{uuid.uuid4().int}" + alias_index_b = f"test-alias-b-{uuid.uuid4().int}" + alias_index_c = f"test-alias-c-{uuid.uuid4().int}" + + search_client.swap_alias_index(index_a, alias_index_a) + search_client.swap_alias_index(index_b, alias_index_b) + search_client.swap_alias_index(index_c, alias_index_c) + + # Checking the indexes directly - we expect the index method to return true + # and the alias method to not + assert search_client.index_exists(index_a) is True + assert search_client.index_exists(index_b) is True + assert search_client.index_exists(index_c) is True + + assert search_client.alias_exists(index_a) is False + assert search_client.alias_exists(index_b) is False + assert search_client.alias_exists(index_c) is False + + # We just created these aliases, they should exist + assert search_client.index_exists(alias_index_a) is True + assert search_client.index_exists(alias_index_b) is True + assert search_client.index_exists(alias_index_c) is True + + assert search_client.alias_exists(alias_index_a) is True + assert search_client.alias_exists(alias_index_b) is True + assert search_client.alias_exists(alias_index_c) is True + + # Other random things won't be found for either case + assert search_client.index_exists("test-index-a") is False + assert search_client.index_exists("asdasdasd") is False + assert search_client.index_exists(alias_index_a + "-other") is False + + assert search_client.alias_exists("test-index-a") is False + assert search_client.alias_exists("asdasdasd") is False + assert search_client.alias_exists(alias_index_a + "-other") is False + + +def test_scroll(search_client, generic_index): + records = [ + {"id": 1, "title": "Green Eggs & Ham", "notes": "why are the eggs green?"}, + {"id": 2, "title": "The Cat in the Hat", "notes": "silly cat wears a hat"}, + {"id": 3, "title": "One Fish, Two Fish, Red Fish, Blue Fish", "notes": "fish"}, + {"id": 4, "title": "Fox in Socks", "notes": "why he wearing socks?"}, + {"id": 5, "title": "The Lorax", "notes": "trees"}, + {"id": 6, "title": "Oh, the Places 
You'll Go", "notes": "graduation gift"}, + {"id": 7, "title": "Hop on Pop", "notes": "Let him sleep"}, + {"id": 8, "title": "How the Grinch Stole Christmas", "notes": "who"}, + ] + + search_client.bulk_upsert(generic_index, records, primary_key_field="id") + + results = [] + + for response in search_client.scroll(generic_index, {"size": 3}): + assert response.total_records == 8 + results.append(response) + + assert len(results) == 3 + assert len(results[0].records) == 3 + assert len(results[1].records) == 3 + assert len(results[2].records) == 2 diff --git a/api/tests/src/search/backend/test_load_opportunities_to_index.py b/api/tests/src/search/backend/test_load_opportunities_to_index.py index e939f569f..9a3961f2b 100644 --- a/api/tests/src/search/backend/test_load_opportunities_to_index.py +++ b/api/tests/src/search/backend/test_load_opportunities_to_index.py @@ -4,17 +4,18 @@ LoadOpportunitiesToIndex, LoadOpportunitiesToIndexConfig, ) +from src.util.datetime_util import get_now_us_eastern_datetime from tests.conftest import BaseTestClass from tests.src.db.models.factories import OpportunityFactory -class TestLoadOpportunitiesToIndex(BaseTestClass): +class TestLoadOpportunitiesToIndexFullRefresh(BaseTestClass): @pytest.fixture(scope="class") def load_opportunities_to_index(self, db_session, search_client, opportunity_index_alias): config = LoadOpportunitiesToIndexConfig( alias_name=opportunity_index_alias, index_prefix="test-load-opps" ) - return LoadOpportunitiesToIndex(db_session, search_client, config) + return LoadOpportunitiesToIndex(db_session, search_client, True, config) def test_load_opportunities_to_index( self, @@ -83,3 +84,70 @@ def test_load_opportunities_to_index( assert set([opp.opportunity_id for opp in opportunities]) == set( [record["opportunity_id"] for record in resp.records] ) + + +class TestLoadOpportunitiesToIndexPartialRefresh(BaseTestClass): + @pytest.fixture(scope="class") + def load_opportunities_to_index(self, db_session, search_client, opportunity_index_alias): + config = LoadOpportunitiesToIndexConfig( + alias_name=opportunity_index_alias, index_prefix="test-load-opps" + ) + return LoadOpportunitiesToIndex(db_session, search_client, False, config) + + def test_load_opportunities_to_index( + self, + truncate_opportunities, + enable_factory_create, + db_session, + search_client, + opportunity_index_alias, + load_opportunities_to_index, + ): + index_name = "partial-refresh-index-" + get_now_us_eastern_datetime().strftime( + "%Y-%m-%d_%H-%M-%S" + ) + search_client.create_index(index_name) + search_client.swap_alias_index( + index_name, load_opportunities_to_index.config.alias_name, delete_prior_indexes=True + ) + + # Load a bunch of records into the DB + opportunities = [] + opportunities.extend(OpportunityFactory.create_batch(size=6, is_posted_summary=True)) + opportunities.extend(OpportunityFactory.create_batch(size=3, is_forecasted_summary=True)) + opportunities.extend(OpportunityFactory.create_batch(size=2, is_closed_summary=True)) + opportunities.extend( + OpportunityFactory.create_batch(size=8, is_archived_non_forecast_summary=True) + ) + opportunities.extend( + OpportunityFactory.create_batch(size=6, is_archived_forecast_summary=True) + ) + + load_opportunities_to_index.run() + + resp = search_client.search(opportunity_index_alias, {"size": 100}) + assert resp.total_records == len(opportunities) + + # Add a few more opportunities that will be created + opportunities.extend(OpportunityFactory.create_batch(size=3, is_posted_summary=True)) + + # Delete some 
opportunities + opportunities_to_delete = [opportunities.pop(), opportunities.pop(), opportunities.pop()] + for opportunity in opportunities_to_delete: + db_session.delete(opportunity) + + load_opportunities_to_index.run() + + resp = search_client.search(opportunity_index_alias, {"size": 100}) + assert resp.total_records == len(opportunities) + + def test_load_opportunities_to_index_index_does_not_exist(self, db_session, search_client): + config = LoadOpportunitiesToIndexConfig( + alias_name="fake-index-that-will-not-exist", index_prefix="test-load-opps" + ) + load_opportunities_to_index = LoadOpportunitiesToIndex( + db_session, search_client, False, config + ) + + with pytest.raises(RuntimeError, match="please run the full refresh job"): + load_opportunities_to_index.run()
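
Usage sketch (not part of the patch): how the client methods added in this diff (alias_exists, scroll, bulk_delete) compose for an incremental sync. The alias name "my-opportunity-alias" and the sync_ids helper are hypothetical stand-ins, and a reachable OpenSearch cluster configured for search.SearchClient() is assumed. This mirrors what LoadOpportunitiesToIndex.incremental_updates_and_deletes does against the configured alias.

    # Hypothetical helper: reconcile the search index against the current set of IDs.
    import src.adapters.search as search


    def sync_ids(current_ids: set[int]) -> None:
        client = search.SearchClient()

        # The incremental flow requires the alias from a prior full refresh to exist
        if not client.alias_exists("my-opportunity-alias"):
            raise RuntimeError("Run the full refresh job first to create the alias")

        # Scroll over the whole index, fetching only the opportunity_id field,
        # 10,000 records per page
        indexed_ids: set[int] = set()
        for response in client.scroll(
            "my-opportunity-alias",
            {"size": 10000, "_source": ["opportunity_id"]},
            include_scores=False,
        ):
            for record in response.records:
                indexed_ids.add(record["opportunity_id"])

        # Anything indexed but no longer present in the source data gets bulk deleted
        stale_ids = indexed_ids - current_ids
        if stale_ids:
            client.bulk_delete("my-opportunity-alias", stale_ids)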