From 18e30325f19241feb4128a6b280a2395a298d824 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Thu, 6 Jul 2023 17:35:13 +0200 Subject: [PATCH 1/9] feat: add filtering using several response status values --- .../server/apis/v1/handlers/datasets.py | 12 ++++++-- src/argilla/server/contexts/datasets.py | 29 +++++++++++++----- tests/server/api/v1/test_datasets.py | 30 +++++++++++-------- 3 files changed, 49 insertions(+), 22 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 3ec36ff6f1..c5f417c27f 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -119,7 +119,7 @@ async def list_current_user_dataset_records( db: AsyncSession = Depends(get_async_db), dataset_id: UUID, include: List[RecordInclude] = Query([], description="Relationships to include in the response"), - response_status: Optional[ResponseStatusFilter] = Query(None), + response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), current_user: User = Security(auth.get_current_user), @@ -129,7 +129,13 @@ async def list_current_user_dataset_records( await authorize(current_user, DatasetPolicyV1.get(dataset)) records = await datasets.list_records_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, include=include, response_status=response_status, offset=offset, limit=limit + db, + dataset_id, + current_user.id, + include=include, + response_statuses=response_statuses, + offset=offset, + limit=limit, ) return Records(items=records) @@ -314,7 +320,7 @@ async def search_dataset_records( dataset_id: UUID, query: SearchRecordsQuery, include: List[RecordInclude] = Query([]), - response_status: Optional[ResponseStatusFilter] = Query(None), + response_status: List[ResponseStatusFilter] = Query([]), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), current_user: User = Security(auth.get_current_user), diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 183ee02590..dd847b0a0a 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -16,7 +16,7 @@ from uuid import UUID from fastapi.encoders import jsonable_encoder -from sqlalchemy import and_, func, select +from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import contains_eager, joinedload, selectinload from argilla.server.contexts import accounts @@ -267,20 +267,35 @@ async def list_records_by_dataset_id_and_user_id( dataset_id: UUID, user_id: UUID, include: List[RecordInclude] = [], - response_status: Optional[ResponseStatusFilter] = None, + response_statuses: List[ResponseStatusFilter] = [], offset: int = 0, limit: int = LIST_RECORDS_LIMIT, ) -> List[Record]: + response_statuses_ = [ + ResponseStatus(response_status) + for response_status in response_statuses + if response_status != ResponseStatusFilter.missing + ] + + response_status_filter_expressions = [] + + if response_statuses_: + response_status_filter_expressions.append(Response.status.in_(response_statuses_)) + + if ResponseStatusFilter.missing in response_statuses: + response_status_filter_expressions.append(Response.status.is_(None)) + query = ( select(Record) .filter(Record.dataset_id == dataset_id) - .outerjoin(Response, and_(Response.record_id == Record.id, Response.user_id == user_id)) + .outerjoin( + Response, + and_(Response.record_id == Record.id, Response.user_id == user_id), + ) ) - if response_status == ResponseStatusFilter.missing: - query = query.filter(Response.status == None) # noqa: E711 - elif response_status is not None: - query = query.filter(Response.status == ResponseStatus(response_status)) + if response_status_filter_expressions: + query = query.filter(or_(*response_status_filter_expressions)) if RecordInclude.responses in include: query = query.options(contains_eager(Record.responses)) diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 8ba271204c..231c71132e 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -20,6 +20,7 @@ import pytest from argilla._constants import API_KEY_HEADER_NAME from argilla.server.apis.v1.handlers.datasets import LIST_DATASET_RECORDS_LIMIT_DEFAULT +from argilla.server.enums import ResponseStatusFilter from argilla.server.models import ( Dataset, DatasetStatus, @@ -934,10 +935,21 @@ async def create_records_with_response( await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status) -@pytest.mark.parametrize("response_status_filter", ["missing", "discarded", "submitted", "draft"]) +@pytest.mark.parametrize( + "response_status_filters", + [ + [ResponseStatusFilter.missing], + [ResponseStatusFilter.draft], + [ResponseStatusFilter.submitted], + [ResponseStatusFilter.discarded], + [ResponseStatusFilter.missing, ResponseStatusFilter.draft], + [ResponseStatusFilter.submitted, ResponseStatusFilter.discarded], + [ResponseStatusFilter.missing, ResponseStatus.draft, ResponseStatus.discarded], + ], +) @pytest.mark.asyncio async def test_list_current_user_dataset_records_with_response_status_filter( - client: TestClient, owner: "User", owner_auth_header: dict, response_status_filter: str + client: TestClient, owner: "User", owner_auth_header: dict, response_status_filters: List[ResponseStatusFilter] ): num_responses_per_status = 10 response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}} @@ -957,20 +969,14 @@ async def test_list_current_user_dataset_records_with_response_status_filter( other_dataset = await DatasetFactory.create() await RecordFactory.create_batch(size=2, dataset=other_dataset) - response = client.get( - f"/api/v1/me/datasets/{dataset.id}/records?response_status={response_status_filter}&include=responses", - headers=owner_auth_header, - ) + params = [("include", RecordInclude.responses.value)] + params.extend(("response_status", status_filter.value) for status_filter in response_status_filters) + response = client.get(f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params=params) assert response.status_code == 200 response_json = response.json() - assert len(response_json["items"]) == num_responses_per_status - - if response_status_filter == "missing": - assert all([len(record["responses"]) == 0 for record in response_json["items"]]) - else: - assert all([record["responses"][0]["status"] == response_status_filter for record in response_json["items"]]) + assert len(response_json["items"]) == num_responses_per_status * len(response_status_filters) @pytest.mark.asyncio From a64bdcb20c681a8e97b9d9992d57d5c5ae86c810 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 7 Jul 2023 13:03:32 +0200 Subject: [PATCH 2/9] feat: several response status filters in search endpoint --- .../server/apis/v1/handlers/datasets.py | 8 ++--- src/argilla/server/search_engine.py | 33 +++++++++++-------- 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index c5f417c27f..fb284b3552 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import List from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Security, status @@ -320,7 +320,7 @@ async def search_dataset_records( dataset_id: UUID, query: SearchRecordsQuery, include: List[RecordInclude] = Query([]), - response_status: List[ResponseStatusFilter] = Query([]), + response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), current_user: User = Security(auth.get_current_user), @@ -338,8 +338,8 @@ async def search_dataset_records( ) user_response_status_filter = None - if response_status: - user_response_status_filter = UserResponseStatusFilter(user=current_user, status=response_status) + if response_statuses: + user_response_status_filter = UserResponseStatusFilter(user=current_user, statuses=response_statuses) search_responses = await search_engine.search( dataset=dataset, diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py index 6fb96a5dee..fa2e3a66da 100644 --- a/src/argilla/server/search_engine.py +++ b/src/argilla/server/search_engine.py @@ -78,7 +78,7 @@ class Query: @dataclasses.dataclass class UserResponseStatusFilter: user: User - status: ResponseStatusFilter + statuses: List[ResponseStatusFilter] @dataclasses.dataclass @@ -185,15 +185,12 @@ async def search( text_query = self._text_query_builder(dataset, text=query.text) - bool_query: dict = {"must": [text_query]} + bool_query = {"must": [text_query]} if user_response_status_filter: - bool_query["filter"] = self._response_status_filter_builder(user_response_status_filter) + bool_query["should"] = self._response_status_filter_builder(user_response_status_filter) + bool_query["minimum_should_match"] = 1 - body = { - "_source": False, - "query": {"bool": bool_query}, - # "sort": [{"_score": "desc"}, {"id": "asc"}], - } + body = {"_source": False, "query": {"bool": bool_query}} response = await self.client.search( index=self._index_name_for_dataset(dataset), @@ -230,7 +227,7 @@ def _mapping_for_fields(self, fields: List[Field]): def _dynamic_templates_for_question_responses(self, questions: List[Question]) -> List[dict]: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html return [ - {"status_responses": {"path_match": f"responses.*.status", "mapping": {"type": "keyword"}}}, + {"status_responses": {"path_match": "responses.*.status", "mapping": {"type": "keyword"}}}, *[ { f"{question.name}_responses": { @@ -280,15 +277,23 @@ async def _get_index_or_raise(self, dataset: Dataset): def _index_name_for_dataset(dataset: Dataset): return f"rg.{dataset.id}" - def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter): + def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> Dict[str, Any]: user_response_field = f"responses.{status_filter.user.username}" - if status_filter.status == ResponseStatusFilter.missing: + statuses = [ + ResponseStatus(status).value for status in status_filter.statuses if status != ResponseStatusFilter.missing + ] + + filters = [] + if ResponseStatusFilter.missing in status_filter.statuses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html - return [{"bool": {"must_not": {"exists": {"field": user_response_field}}}}] - else: + filters.append({"bool": {"must_not": {"exists": {"field": user_response_field}}}}) + + if statuses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html - return [{"term": {f"{user_response_field}.status": status_filter.status}}] + filters.append({"terms": {f"{user_response_field}.status": statuses}}) + + return filters async def get_search_engine() -> AsyncGenerator[SearchEngine, None]: From c4e1f602b68f815b8fecbdaa1a54cc629350499e Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 7 Jul 2023 13:51:41 +0200 Subject: [PATCH 3/9] feat: update `SearchEngine.search` with status unit test --- tests/server/api/v1/test_datasets.py | 4 +- tests/server/test_search_engine.py | 80 ++++++++++++++++++---------- 2 files changed, 55 insertions(+), 29 deletions(-) diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 231c71132e..16b1f016bf 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -944,7 +944,7 @@ async def create_records_with_response( [ResponseStatusFilter.discarded], [ResponseStatusFilter.missing, ResponseStatusFilter.draft], [ResponseStatusFilter.submitted, ResponseStatusFilter.discarded], - [ResponseStatusFilter.missing, ResponseStatus.draft, ResponseStatus.discarded], + [ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded], ], ) @pytest.mark.asyncio @@ -3011,7 +3011,7 @@ async def test_search_dataset_records_with_response_status_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query(text=TextQuery(q="Hello", field="input")), - user_response_status_filter=UserResponseStatusFilter(user=owner, status=ResponseStatus.submitted), + user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=ResponseStatus.submitted), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ) diff --git a/tests/server/test_search_engine.py b/tests/server/test_search_engine.py index 22183e8457..d896bd837e 100644 --- a/tests/server/test_search_engine.py +++ b/tests/server/test_search_engine.py @@ -13,7 +13,7 @@ # limitations under the License. import random -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List import pytest import pytest_asyncio @@ -109,15 +109,18 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> ), await RecordFactory.create( dataset=dataset, - fields={"textId": "00002", "text": "Why was I charged for getting cash?", "label": "neutral"}, + fields={ + "textId": "00002", + "text": "I tried to make a payment with my card and it was declined.", + "label": "negative", + }, responses=[], ), await RecordFactory.create( dataset=dataset, fields={ "textId": "00003", - "text": "I deposited cash into my account a week ago and it is still not available," - " please tell me why? I need the cash back now.", + "text": "My credit card was declined when I tried to make a payment.", "label": "negative", }, responses=[], @@ -126,8 +129,8 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> dataset=dataset, fields={ "textId": "00004", - "text": "Why was I charged for getting cash?", - "label": "neutral", + "text": "I made a successful payment towards my mortgage loan earlier today.", + "label": "positive", }, responses=[], ), @@ -135,8 +138,8 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> dataset=dataset, fields={ "textId": "00005", - "text": "I tried to make a payment with my card and it was declined.", - "label": "negative", + "text": "Please confirm the receipt of my payment for the credit card bill due on the 15th.", + "label": "neutral", }, responses=[], ), @@ -144,7 +147,22 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> dataset=dataset, fields={ "textId": "00006", - "text": "My credit card was declined when I tried to make a payment.", + "text": "Why was I charged for getting cash?", + "label": "neutral", + }, + responses=[], + ), + await RecordFactory.create( + dataset=dataset, + fields={"textId": "00007", "text": "Why was I charged for getting cash?", "label": "neutral"}, + responses=[], + ), + await RecordFactory.create( + dataset=dataset, + fields={ + "textId": "00008", + "text": "I deposited cash into my account a week ago and it is still not available," + " please tell me why? I need the cash back now.", "label": "negative", }, responses=[], @@ -309,23 +327,23 @@ async def test_create_index_with_existing_index( @pytest.mark.parametrize( ("query", "expected_items"), [ - ("card", 4), + ("card", 5), ("account", 1), - ("payment", 4), + ("payment", 6), ("cash", 3), ("negative", 4), ("00000", 1), - ("card payment", 4), + ("card payment", 5), ("nothing", 0), - (SearchQuery(text=TextQuery(q="card")), 4), + (SearchQuery(text=TextQuery(q="card")), 5), (SearchQuery(text=TextQuery(q="account")), 1), - (SearchQuery(text=TextQuery(q="payment")), 4), + (SearchQuery(text=TextQuery(q="payment")), 6), (SearchQuery(text=TextQuery(q="cash")), 3), - (SearchQuery(text=TextQuery(q="card payment")), 4), + (SearchQuery(text=TextQuery(q="card payment")), 5), (SearchQuery(text=TextQuery(q="nothing")), 0), (SearchQuery(text=TextQuery(q="negative", field="label")), 4), (SearchQuery(text=TextQuery(q="00000", field="textId")), 1), - (SearchQuery(text=TextQuery(q="card payment", field="text")), 4), + (SearchQuery(text=TextQuery(q="card payment", field="text")), 5), ], ) async def test_search_with_query_string( @@ -352,12 +370,15 @@ async def test_search_with_query_string( assert scores == sorted_scores @pytest.mark.parametrize( - "status", + "statuses, expected_items", [ - ResponseStatusFilter.discarded, - ResponseStatusFilter.submitted, - ResponseStatusFilter.draft, - ResponseStatusFilter.missing, + ([ResponseStatusFilter.missing], 6), + ([ResponseStatusFilter.draft], 2), + ([ResponseStatusFilter.submitted], 2), + ([ResponseStatusFilter.discarded], 2), + ([ResponseStatusFilter.missing, ResponseStatusFilter.draft], 6), + ([ResponseStatusFilter.submitted, ResponseStatusFilter.discarded], 4), + ([ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded], 6), ], ) async def test_search_with_response_status_filter( @@ -365,14 +386,19 @@ async def test_search_with_response_status_filter( elastic_search_engine: SearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - status: ResponseStatusFilter, + statuses: List[ResponseStatusFilter], + expected_items: int, ): index_name = f"rg.{test_banking_sentiment_dataset.id}" user = await UserFactory.create() another_user = await UserFactory.create() - if status != ResponseStatusFilter.missing: - for record in test_banking_sentiment_dataset.records: + # Create two responses with the same status (one in each record) + for i, status in enumerate(statuses): + if status == ResponseStatusFilter.missing: + continue + offset = i * 2 + for record in test_banking_sentiment_dataset.records[offset : offset + 2]: users_responses = { f"{user.username}.status": status.value, f"{another_user.username}.status": status.value, @@ -383,10 +409,10 @@ async def test_search_with_response_status_filter( result = await elastic_search_engine.search( test_banking_sentiment_dataset, query=SearchQuery(text=TextQuery(q="payment")), - user_response_status_filter=UserResponseStatusFilter(user=user, status=status), + user_response_status_filter=UserResponseStatusFilter(user=user, statuses=statuses), ) - assert len(result.items) == 4 - assert result.total == 4 + assert len(result.items) == expected_items + assert result.total == expected_items @pytest.mark.parametrize(("offset", "limit"), [(0, 50), (10, 5), (0, 0), (90, 100)]) async def test_search_with_pagination( From 5eb15e20d30317be7fbd032c459491f984455982 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 7 Jul 2023 14:17:51 +0200 Subject: [PATCH 4/9] fix: wrong return type hint --- src/argilla/server/search_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py index fa2e3a66da..9434c5120a 100644 --- a/src/argilla/server/search_engine.py +++ b/src/argilla/server/search_engine.py @@ -277,7 +277,7 @@ async def _get_index_or_raise(self, dataset: Dataset): def _index_name_for_dataset(dataset: Dataset): return f"rg.{dataset.id}" - def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> Dict[str, Any]: + def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> List[Dict[str, Any]]: user_response_field = f"responses.{status_filter.user.username}" statuses = [ From ae4e42a59dcce87272548202cb755afb7492662d Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 7 Jul 2023 14:23:09 +0200 Subject: [PATCH 5/9] docs: update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b44cec7b52..bf987534d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ These are the section headers that we use: - `User.workspaces` is no longer an attribute but a property, and is calling `list_user_workspaces` to list all the workspace names for a given user ID ([#3334](https://github.com/argilla-io/argilla/pull/3334)) - Renamed `FeedbackDatasetConfig` to `DatasetConfig` and export/import from YAML as default instead of JSON (just used internally on `push_to_huggingface` and `from_huggingface` methods of `FeedbackDataset`) ([#3326](https://github.com/argilla-io/argilla/pull/3326)). - The protected metadata fields support other than textual info - existing datasets must be reindex. See [docs](https://docs.argilla.io/en/latest/getting_started/installation/configurations/database_migrations.html#elasticsearch) for more detail (Closes [#3332](https://github.com/argilla-io/argilla/issues/3332)). +- Updated `GET /api/v1/me/datasets/{dataset_id}/records` endpoint to allow getting records matching one of the response statuses provided via query param.([#3359](https://github.com/argilla-io/argilla/pull/3359)). +- Updated `POST /api/v1/me/datasets/{dataset_id}/records` endpoint to allow searching records matching one of the response statuses provided via query param.([#3359](https://github.com/argilla-io/argilla/pull/3359)). +- Updated `SearchEngine.search` method to allow searching records matching one of the response statuses provided ([#3359](https://github.com/argilla-io/argilla/pull/3359)). ### Fixed From ecca4ea98696054cf07ea5c5983cc6a7575f3d9c Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 10 Jul 2023 11:20:30 +0200 Subject: [PATCH 6/9] fix: wrong value `user_response_status_filter` in assertion call --- tests/server/api/v1/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 16b1f016bf..84dd4ddf93 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -3011,7 +3011,7 @@ async def test_search_dataset_records_with_response_status_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query(text=TextQuery(q="Hello", field="input")), - user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=ResponseStatus.submitted), + user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatus.submitted]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ) From 9ffd7528e17e66e2fe6848fa81c89049caab09d0 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 10 Jul 2023 11:42:56 +0200 Subject: [PATCH 7/9] fix: `UserFactory.create` not awaited --- tests/client/sdk/users/test_api.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/client/sdk/users/test_api.py b/tests/client/sdk/users/test_api.py index ac972ba1b9..8b5e724c62 100644 --- a/tests/client/sdk/users/test_api.py +++ b/tests/client/sdk/users/test_api.py @@ -52,9 +52,10 @@ def test_whoami_errors() -> None: whoami(AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey")) -def test_list_users(owner: "ServerUser") -> None: - UserFactory.create(username="user_1") - UserFactory.create(username="user_2") +@pytest.mark.asyncio +async def test_list_users(owner: "ServerUser") -> None: + await UserFactory.create(username="user_1") + await UserFactory.create(username="user_2") httpx_client = ArgillaSingleton.init(api_key=owner.api_key).http_client.httpx response = list_users(client=httpx_client) From 0036ee0199492e57393cdea475f3c9417ccb8b8d Mon Sep 17 00:00:00 2001 From: Paco Aranda Date: Wed, 2 Aug 2023 13:03:49 +0200 Subject: [PATCH 8/9] fix: exclude response filters from search score --- src/argilla/server/search_engine.py | 11 ++++-- tests/server/test_search_engine.py | 58 +++++++++++++++++++++-------- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py index 9434c5120a..44e573b97c 100644 --- a/src/argilla/server/search_engine.py +++ b/src/argilla/server/search_engine.py @@ -187,8 +187,7 @@ async def search( bool_query = {"must": [text_query]} if user_response_status_filter: - bool_query["should"] = self._response_status_filter_builder(user_response_status_filter) - bool_query["minimum_should_match"] = 1 + bool_query["filter"] = self._response_status_filter_builder(user_response_status_filter) body = {"_source": False, "query": {"bool": bool_query}} @@ -277,7 +276,11 @@ async def _get_index_or_raise(self, dataset: Dataset): def _index_name_for_dataset(dataset: Dataset): return f"rg.{dataset.id}" - def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> List[Dict[str, Any]]: + def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> Optional[Dict[str, Any]]: + + if not status_filter.statuses: + return None + user_response_field = f"responses.{status_filter.user.username}" statuses = [ @@ -293,7 +296,7 @@ def _response_status_filter_builder(self, status_filter: UserResponseStatusFilte # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html filters.append({"terms": {f"{user_response_field}.status": statuses}}) - return filters + return {"bool": {"should": filters, "minimum_should_match": 1}} async def get_search_engine() -> AsyncGenerator[SearchEngine, None]: diff --git a/tests/server/test_search_engine.py b/tests/server/test_search_engine.py index d896bd837e..b9c76d0a1d 100644 --- a/tests/server/test_search_engine.py +++ b/tests/server/test_search_engine.py @@ -18,7 +18,7 @@ import pytest import pytest_asyncio from argilla.server.enums import ResponseStatusFilter -from argilla.server.models import Dataset +from argilla.server.models import Dataset, User from argilla.server.search_engine import Query as SearchQuery from argilla.server.search_engine import ( SearchEngine, @@ -372,6 +372,7 @@ async def test_search_with_query_string( @pytest.mark.parametrize( "statuses, expected_items", [ + ([], 6), ([ResponseStatusFilter.missing], 6), ([ResponseStatusFilter.draft], 2), ([ResponseStatusFilter.submitted], 2), @@ -389,23 +390,10 @@ async def test_search_with_response_status_filter( statuses: List[ResponseStatusFilter], expected_items: int, ): - index_name = f"rg.{test_banking_sentiment_dataset.id}" user = await UserFactory.create() - another_user = await UserFactory.create() - # Create two responses with the same status (one in each record) - for i, status in enumerate(statuses): - if status == ResponseStatusFilter.missing: - continue - offset = i * 2 - for record in test_banking_sentiment_dataset.records[offset : offset + 2]: - users_responses = { - f"{user.username}.status": status.value, - f"{another_user.username}.status": status.value, - } - opensearch.update(index_name, id=record.id, body={"doc": {"responses": users_responses}}) + await self._configure_record_responses(opensearch, test_banking_sentiment_dataset, statuses, user) - opensearch.indices.refresh(index=index_name) result = await elastic_search_engine.search( test_banking_sentiment_dataset, query=SearchQuery(text=TextQuery(q="payment")), @@ -414,6 +402,27 @@ async def test_search_with_response_status_filter( assert len(result.items) == expected_items assert result.total == expected_items + async def test_search_with_response_status_filter_does_not_affect_the_result_scores( + self, elastic_search_engine: SearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset + ): + user = await UserFactory.create() + + all_statuses = [ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded] + await self._configure_record_responses(opensearch, test_banking_sentiment_dataset, all_statuses, user) + + no_filter_results = await elastic_search_engine.search( + test_banking_sentiment_dataset, + query=SearchQuery(text=TextQuery(q="payment")), + ) + results = await elastic_search_engine.search( + test_banking_sentiment_dataset, + query=SearchQuery(text=TextQuery(q="payment")), + user_response_status_filter=UserResponseStatusFilter(user=user, statuses=all_statuses), + ) + assert len(no_filter_results.items) == len(results.items) + assert no_filter_results.total == results.total + assert [item.score for item in no_filter_results.items] == [item.score for item in results.items] + @pytest.mark.parametrize(("offset", "limit"), [(0, 50), (10, 5), (0, 0), (90, 100)]) async def test_search_with_pagination( self, @@ -524,3 +533,22 @@ async def test_delete_record_response( results = opensearch.get(index=index_name, id=record.id) assert results["_source"]["responses"] == {} + + async def _configure_record_responses( + self, opensearch: OpenSearch, dataset: Dataset, response_status: List[ResponseStatusFilter], user: User + ): + index_name = f"rg.{dataset.id}" + another_user = await UserFactory.create() + + # Create two responses with the same status (one in each record) + for i, status in enumerate(response_status): + if status == ResponseStatusFilter.missing: + continue + offset = i * 2 + for record in dataset.records[offset : offset + 2]: + users_responses = { + f"{user.username}.status": status.value, + f"{another_user.username}.status": status.value, + } + opensearch.update(index_name, id=record.id, body={"doc": {"responses": users_responses}}) + opensearch.indices.refresh(index=index_name) From 3102ff75af7169bb7ed15bffec28cf1c451fd505 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Wed, 2 Aug 2023 13:22:20 +0200 Subject: [PATCH 9/9] fix: wrong parameter name --- tests/unit/server/api/v1/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index 8839d618b3..935e62ae87 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -3009,7 +3009,7 @@ async def test_search_dataset_records_with_response_status_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query(text=TextQuery(q="Hello", field="input")), - user_response_status_filter=UserResponseStatusFilter(user=owner, status=ResponseStatusFilter.submitted), + user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, )