Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: preventing KeyError("_highlightResult") #550

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions algoliasearch/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __iter__(self):


class PaginatorIterator(Iterator):
nbHits = 0

def __init__(self, transporter, index_name, request_options=None):
# type: (Transporter, str, Optional[Union[dict, RequestOptions]]) -> None # noqa: E501

Expand All @@ -48,16 +50,16 @@ def __init__(self, transporter, index_name, request_options=None):

def __next__(self):
# type: () -> dict

if self._raw_response:

if len(self._raw_response["hits"]):
hit = self._raw_response["hits"].pop(0)

hit.pop("_highlightResult")

return hit

if self._raw_response["nbHits"] < self._data["hitsPerPage"]:
if self.nbHits < self._data["hitsPerPage"]:
self._raw_response = {}
self._data = {
"hitsPerPage": 1000,
Expand All @@ -68,6 +70,7 @@ def __next__(self):
self._raw_response = self._transporter.read(
Verb.POST, self.get_endpoint(), self._data, self._request_options
)
self.nbHits = len(self._raw_response["hits"])

self._data["page"] += 1

Expand Down
5 changes: 4 additions & 1 deletion algoliasearch/iterators_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@


class PaginatorIteratorAsync(Iterator):
nbHits = 0

def __init__(self, transporter, index_name, request_options=None):
# type: (Transporter, str, Optional[Union[dict, RequestOptions]]) -> None # noqa: E501

Expand Down Expand Up @@ -38,7 +40,7 @@ def __anext__(self): # type: ignore

return hit

if self._raw_response["nbHits"] < self._data["hitsPerPage"]:
if self.nbHits < self._data["hitsPerPage"]:
self._raw_response = {}
self._data = {
"hitsPerPage": 1000,
Expand All @@ -49,6 +51,7 @@ def __anext__(self): # type: ignore
self._raw_response = yield from self._transporter.read(
Verb.POST, self.get_endpoint(), self._data, self._request_options
)
self.nbHits = len(self._raw_response["hits"])

self._data["page"] += 1

Expand Down
41 changes: 41 additions & 0 deletions tests/features/test_search_index.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# -*- coding: utf-8 -*-
import sys
import unittest
import json
import requests

from requests.models import Response

from algoliasearch.exceptions import RequestException, ObjectNotFoundException
from algoliasearch.responses import MultipleResponse
from algoliasearch.search_client import SearchClient
from algoliasearch.search_index import SearchIndex
from tests.helpers.factory import Factory as F
from tests.helpers.misc import Unicode, rule_without_metadata
from unittest.mock import MagicMock


class TestSearchIndex(unittest.TestCase):
Expand Down Expand Up @@ -450,6 +456,41 @@ def test_synonyms(self):
# and check that the number of returned synonyms is equal to 0
self.assertEqual(self.index.search_synonyms("")["nbHits"], 0)

def test_browse_rules(self):
def side_effect(req, **kwargs):
hits = [{"objectID": i, "_highlightResult": None} for i in range(0, 1000)]
page = json.loads(req.body)["page"]

if page == 3:
hits = hits[0:800]

response = Response()
response.status_code = 200
response._content = str.encode(
json.dumps(
{
"hits": hits,
"nbHits": 3800,
"page": page,
"nbPages": 3,
}
)
)

return response

client = SearchClient.create("foo", "bar")
client._transporter._requester._session = requests.Session()
client._transporter._requester._session.send = MagicMock(name="send")
client._transporter._requester._session.send.side_effect = side_effect
index = F.index(client, "test")

rules = index.browse_rules()

len_rules = len(list(rules))

self.assertEqual(len_rules, 3800)

def test_rules(self):
responses = MultipleResponse()

Expand Down