From ccca1fb58dd828af6ddce051f53df7330bb10ae6 Mon Sep 17 00:00:00 2001
From: Miguel Grinberg
Date: Thu, 23 May 2024 14:13:37 +0100
Subject: [PATCH] Added point in time support and the Search.iterate() method

---
 elasticsearch_dsl/_async/search.py            | 55 ++++++++++++++++++++
 elasticsearch_dsl/_sync/search.py             | 53 +++++++++++++++++++
 tests/test_integration/_async/test_search.py  | 31 +++++++++++
 tests/test_integration/_sync/test_search.py   | 29 +++++++++++
 utils/run-unasync.py                          |  1 +
 5 files changed, 169 insertions(+)

diff --git a/elasticsearch_dsl/_async/search.py b/elasticsearch_dsl/_async/search.py
index faffb17f..1dc51477 100644
--- a/elasticsearch_dsl/_async/search.py
+++ b/elasticsearch_dsl/_async/search.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import contextlib
+
 from elasticsearch.exceptions import ApiError
 from elasticsearch.helpers import async_scan
 
@@ -92,6 +94,8 @@ async def scan(self):
         pass to the underlying ``scan`` helper from ``elasticsearch-py`` -
         https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan
 
+        The ``iterate()`` method should be preferred, as it provides similar
+        functionality using a point in time.
         """
         es = get_connection(self._using)
 
@@ -113,6 +117,57 @@ async def delete(self):
             )
         )
 
+    @contextlib.asynccontextmanager
+    async def point_in_time(self, keep_alive="1m"):
+        """
+        Open a point in time (pit) that can be used across several searches.
+
+        This method implements a context manager that returns a search object
+        configured to operate within the created pit.
+
+        :arg keep_alive: the time to live for the point in time, renewed with each search request
+
+        The following example shows how to paginate through all the documents of an index::
+
+            page_size = 10
+            with Search(index="my-index")[:page_size].point_in_time() as s:
+                while True:
+                    r = s.execute()  # get a page of results
+                    # ... do something with r.hits
+
+                    if len(r.hits) < page_size:
+                        break  # we reached the end
+                    s = r.search_after()
+        """
+        es = get_connection(self._using)
+
+        pit = await es.open_point_in_time(
+            index=self._index or "*", keep_alive=keep_alive
+        )
+        search = self.index().extra(pit={"id": pit["id"], "keep_alive": keep_alive})
+        if not search._sort:
+            search = search.sort("_shard_doc")
+        yield search
+        await es.close_point_in_time(id=pit["id"])
+
+    async def iterate(self, keep_alive="1m"):
+        """
+        Return a generator that iterates over all the documents matching the query.
+
+        This method uses a point in time to provide consistent results even when
+        the index is changing. It should be preferred over ``scan()``.
+
+        :arg keep_alive: the time to live for the point in time, renewed with each new search request
+        """
+        async with self.point_in_time(keep_alive=keep_alive) as s:
+            while True:
+                r = await s.execute()
+                for hit in r:
+                    yield hit
+                if len(r.hits) == 0:
+                    break
+                s = r.search_after()
+
 
 class AsyncMultiSearch(MultiSearchBase):
     """
diff --git a/elasticsearch_dsl/_sync/search.py b/elasticsearch_dsl/_sync/search.py
index ae379237..38e1ae59 100644
--- a/elasticsearch_dsl/_sync/search.py
+++ b/elasticsearch_dsl/_sync/search.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
+import contextlib
+
 from elasticsearch.exceptions import ApiError
 from elasticsearch.helpers import scan
 
@@ -88,6 +90,8 @@ def scan(self):
         pass to the underlying ``scan`` helper from ``elasticsearch-py`` -
         https://elasticsearch-py.readthedocs.io/en/master/helpers.html#elasticsearch.helpers.scan
 
+        The ``iterate()`` method should be preferred, as it provides similar
+        functionality using a point in time.
         """
         es = get_connection(self._using)
 
@@ -105,6 +109,55 @@ def delete(self):
             es.delete_by_query(index=self._index, body=self.to_dict(), **self._params)
         )
 
+    @contextlib.contextmanager
+    def point_in_time(self, keep_alive="1m"):
+        """
+        Open a point in time (pit) that can be used across several searches.
+
+        This method implements a context manager that returns a search object
+        configured to operate within the created pit.
+
+        :arg keep_alive: the time to live for the point in time, renewed with each search request
+
+        The following example shows how to paginate through all the documents of an index::
+
+            page_size = 10
+            with Search(index="my-index")[:page_size].point_in_time() as s:
+                while True:
+                    r = s.execute()  # get a page of results
+                    # ... do something with r.hits
+
+                    if len(r.hits) < page_size:
+                        break  # we reached the end
+                    s = r.search_after()
+        """
+        es = get_connection(self._using)
+
+        pit = es.open_point_in_time(index=self._index or "*", keep_alive=keep_alive)
+        search = self.index().extra(pit={"id": pit["id"], "keep_alive": keep_alive})
+        if not search._sort:
+            search = search.sort("_shard_doc")
+        yield search
+        es.close_point_in_time(id=pit["id"])
+
+    def iterate(self, keep_alive="1m"):
+        """
+        Return a generator that iterates over all the documents matching the query.
+
+        This method uses a point in time to provide consistent results even when
+        the index is changing. It should be preferred over ``scan()``.
+ + :arg keep_alive: the time to live for the point in time, renewed with each new search request + """ + with self.point_in_time(keep_alive=keep_alive) as s: + while True: + r = s.execute() + for hit in r: + yield hit + if len(r.hits) == 0: + break + s = r.search_after() + class MultiSearch(MultiSearchBase): """ diff --git a/tests/test_integration/_async/test_search.py b/tests/test_integration/_async/test_search.py index 6d6a5ab9..2c329ee8 100644 --- a/tests/test_integration/_async/test_search.py +++ b/tests/test_integration/_async/test_search.py @@ -179,6 +179,37 @@ async def test_search_after_no_results(async_data_client): await r.search_after() +@pytest.mark.asyncio +async def test_point_in_time(async_data_client): + page_size = 7 + commits = [] + async with AsyncSearch(index="flat-git")[:page_size].point_in_time( + keep_alive="30s" + ) as s: + pit_id = s._extra["pit"]["id"] + while True: + r = await s.execute() + commits += r.hits + if len(r.hits) < page_size: + break + s = r.search_after() + assert pit_id == s._extra["pit"]["id"] + assert "30s" == s._extra["pit"]["keep_alive"] + + assert 52 == len(commits) + assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} + + +@pytest.mark.asyncio +async def test_iterate(async_data_client): + s = AsyncSearch(index="flat-git") + + commits = [commit async for commit in s.iterate()] + + assert 52 == len(commits) + assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} + + @pytest.mark.asyncio async def test_response_is_cached(async_data_client): s = Repository.search() diff --git a/tests/test_integration/_sync/test_search.py b/tests/test_integration/_sync/test_search.py index 09c31836..db1f23bf 100644 --- a/tests/test_integration/_sync/test_search.py +++ b/tests/test_integration/_sync/test_search.py @@ -171,6 +171,35 @@ def test_search_after_no_results(data_client): r.search_after() +@pytest.mark.sync +def test_point_in_time(data_client): + page_size = 7 + commits = [] + with Search(index="flat-git")[:page_size].point_in_time(keep_alive="30s") as s: + pit_id = s._extra["pit"]["id"] + while True: + r = s.execute() + commits += r.hits + if len(r.hits) < page_size: + break + s = r.search_after() + assert pit_id == s._extra["pit"]["id"] + assert "30s" == s._extra["pit"]["keep_alive"] + + assert 52 == len(commits) + assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} + + +@pytest.mark.sync +def test_iterate(data_client): + s = Search(index="flat-git") + + commits = [commit for commit in s.iterate()] + + assert 52 == len(commits) + assert {d["_id"] for d in FLAT_DATA} == {c.meta.id for c in commits} + + @pytest.mark.sync def test_response_is_cached(data_client): s = Repository.search() diff --git a/utils/run-unasync.py b/utils/run-unasync.py index aaaba069..797994fc 100644 --- a/utils/run-unasync.py +++ b/utils/run-unasync.py @@ -72,6 +72,7 @@ def main(check=False): "async_sleep": "sleep", "assert_awaited_once_with": "assert_called_once_with", "pytest_asyncio": "pytest", + "asynccontextmanager": "contextmanager", } rules = [ unasync.Rule(
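
A usage sketch of the new ``iterate()`` method (not part of the diff itself).
It assumes a reachable cluster and an existing index; the host, index and
field names below are illustrative::

    from elasticsearch_dsl import Search, connections

    # illustrative host; any configured default connection works
    connections.create_connection(hosts=["http://localhost:9200"])

    s = Search(index="my-index").query("match", message="fix")
    for hit in s.iterate(keep_alive="30s"):  # a pit is opened and closed internally
        print(hit.meta.id)

The async variant mirrors this with ``AsyncSearch`` and ``async for``, as
exercised by the ``test_iterate`` integration tests above.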
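For context on the mechanism: ``point_in_time()`` clears the index from the
search with ``self.index()`` (a request that carries a pit must not also name
an index), attaches the pit via ``extra()``, and adds a ``_shard_doc``
tiebreaker sort when none is set, which keeps ``search_after()`` pagination
stable. Alongside whatever query is set, the serialized request body then
looks roughly like this (the pit id is a placeholder)::

    {
        "pit": {"id": "<pit-id>", "keep_alive": "1m"},
        "sort": ["_shard_doc"]
    }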