diff --git a/CHANGELOG.md b/CHANGELOG.md index 49ea45e14f..56a53ce393 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ These are the section headers that we use: - Limit rating questions values to the positive range [1, 10] (Closes [#3451](https://github.com/argilla-io/argilla/issues/3451)). - Updated `POST /api/users` endpoint to be able to provide a list of workspace names to which the user should be linked to ([#3462](https://github.com/argilla-io/argilla/pull/3462)). - Updated Python client `User.create` method to be able to provide a list of workspace names to which the user should be linked to ([#3462](https://github.com/argilla-io/argilla/pull/3462)). +- Updated `GET /api/v1/me/datasets/{dataset_id}/records` endpoint to allow getting records matching one of the response statuses provided via query param ([#3359](https://github.com/argilla-io/argilla/pull/3359)). +- Updated `POST /api/v1/me/datasets/{dataset_id}/records` endpoint to allow searching records matching one of the response statuses provided via query param ([#3359](https://github.com/argilla-io/argilla/pull/3359)). +- Updated `SearchEngine.search` method to allow searching records matching one of the response statuses provided ([#3359](https://github.com/argilla-io/argilla/pull/3359)). ## [1.13.3](https://github.com/argilla-io/argilla/compare/v1.13.2...v1.13.3) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 6a548ea887..105271d29b 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional +from typing import List from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Security, status @@ -108,7 +108,7 @@ async def list_current_user_dataset_records( db: AsyncSession = Depends(get_async_db), dataset_id: UUID, include: List[RecordInclude] = Query([], description="Relationships to include in the response"), - response_status: Optional[ResponseStatusFilter] = Query(None), + response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), current_user: User = Security(auth.get_current_user), @@ -118,7 +118,13 @@ async def list_current_user_dataset_records( await authorize(current_user, DatasetPolicyV1.get(dataset)) records = await datasets.list_records_by_dataset_id_and_user_id( - db, dataset_id, current_user.id, include=include, response_status=response_status, offset=offset, limit=limit + db, + dataset_id, + current_user.id, + include=include, + response_statuses=response_statuses, + offset=offset, + limit=limit, ) return Records(items=records) @@ -297,7 +303,7 @@ async def search_dataset_records( dataset_id: UUID, query: SearchRecordsQuery, include: List[RecordInclude] = Query([]), - response_status: Optional[ResponseStatusFilter] = Query(None), + response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"), offset: int = Query(0, ge=0), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), current_user: User = Security(auth.get_current_user), @@ -315,8 +321,8 @@ async def search_dataset_records( ) user_response_status_filter = None - if response_status: - user_response_status_filter = UserResponseStatusFilter(user=current_user, status=response_status) + if response_statuses: + user_response_status_filter = UserResponseStatusFilter(user=current_user, statuses=response_statuses) 
search_responses = await search_engine.search( dataset=dataset, diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index cc3c673a18..bb4e0045c4 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -16,7 +16,7 @@ from uuid import UUID from fastapi.encoders import jsonable_encoder -from sqlalchemy import and_, delete, func, select +from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import contains_eager, joinedload, selectinload from argilla.server.contexts import accounts @@ -267,20 +267,35 @@ async def list_records_by_dataset_id_and_user_id( dataset_id: UUID, user_id: UUID, include: List[RecordInclude] = [], - response_status: Optional[ResponseStatusFilter] = None, + response_statuses: List[ResponseStatusFilter] = [], offset: int = 0, limit: int = LIST_RECORDS_LIMIT, ) -> List[Record]: + response_statuses_ = [ + ResponseStatus(response_status) + for response_status in response_statuses + if response_status != ResponseStatusFilter.missing + ] + + response_status_filter_expressions = [] + + if response_statuses_: + response_status_filter_expressions.append(Response.status.in_(response_statuses_)) + + if ResponseStatusFilter.missing in response_statuses: + response_status_filter_expressions.append(Response.status.is_(None)) + query = ( select(Record) .filter(Record.dataset_id == dataset_id) - .outerjoin(Response, and_(Response.record_id == Record.id, Response.user_id == user_id)) + .outerjoin( + Response, + and_(Response.record_id == Record.id, Response.user_id == user_id), + ) ) - if response_status == ResponseStatusFilter.missing: - query = query.filter(Response.status == None) # noqa: E711 - elif response_status is not None: - query = query.filter(Response.status == ResponseStatus(response_status)) + if response_status_filter_expressions: + query = query.filter(or_(*response_status_filter_expressions)) if RecordInclude.responses in include: query = 
query.options(contains_eager(Record.responses)) diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py index 00440985fc..9245722e58 100644 --- a/src/argilla/server/search_engine.py +++ b/src/argilla/server/search_engine.py @@ -78,7 +78,7 @@ class Query: @dataclasses.dataclass class UserResponseStatusFilter: user: User - status: ResponseStatusFilter + statuses: List[ResponseStatusFilter] @dataclasses.dataclass @@ -190,15 +190,11 @@ async def search( text_query = self._text_query_builder(dataset, text=query.text) - bool_query: dict = {"must": [text_query]} + bool_query = {"must": [text_query]} if user_response_status_filter: bool_query["filter"] = self._response_status_filter_builder(user_response_status_filter) - body = { - "_source": False, - "query": {"bool": bool_query}, - # "sort": [{"_score": "desc"}, {"id": "asc"}], - } + body = {"_source": False, "query": {"bool": bool_query}} response = await self.client.search( index=self._index_name_for_dataset(dataset), @@ -235,7 +231,7 @@ def _mapping_for_fields(self, fields: List[Field]): def _dynamic_templates_for_question_responses(self, questions: List[Question]) -> List[dict]: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html return [ - {"status_responses": {"path_match": f"responses.*.status", "mapping": {"type": "keyword"}}}, + {"status_responses": {"path_match": "responses.*.status", "mapping": {"type": "keyword"}}}, *[ { f"{question.name}_responses": { @@ -285,15 +281,26 @@ async def _get_index_or_raise(self, dataset: Dataset): def _index_name_for_dataset(dataset: Dataset): return f"rg.{dataset.id}" - def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter): + def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> Optional[Dict[str, Any]]: + if not status_filter.statuses: + return None + user_response_field = f"responses.{status_filter.user.username}" - if status_filter.status == 
ResponseStatusFilter.missing: + statuses = [ + ResponseStatus(status).value for status in status_filter.statuses if status != ResponseStatusFilter.missing + ] + + filters = [] + if ResponseStatusFilter.missing in status_filter.statuses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html - return [{"bool": {"must_not": {"exists": {"field": user_response_field}}}}] - else: + filters.append({"bool": {"must_not": {"exists": {"field": user_response_field}}}}) + + if statuses: # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html - return [{"term": {f"{user_response_field}.status": status_filter.status}}] + filters.append({"terms": {f"{user_response_field}.status": statuses}}) + + return {"bool": {"should": filters, "minimum_should_match": 1}} async def _bulk_op(self, actions: List[Dict[str, Any]]): _, errors = await helpers.async_bulk(client=self.client, actions=actions, raise_on_error=False) diff --git a/tests/integration/client/sdk/api/test_users.py b/tests/integration/client/sdk/api/test_users.py index f7700d6c4d..24ff1e6567 100644 --- a/tests/integration/client/sdk/api/test_users.py +++ b/tests/integration/client/sdk/api/test_users.py @@ -50,9 +50,10 @@ def test_whoami_errors() -> None: whoami(AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey")) -def test_list_users(owner: "ServerUser") -> None: - UserFactory.create(username="user_1") - UserFactory.create(username="user_2") +@pytest.mark.asyncio +async def test_list_users(owner: "ServerUser") -> None: + await UserFactory.create(username="user_1") + await UserFactory.create(username="user_2") httpx_client = ArgillaSingleton.init(api_key=owner.api_key).http_client.httpx response = list_users(client=httpx_client) diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 38a7a2c6be..a81a2b30c7 100644 --- a/tests/server/api/v1/test_datasets.py +++ 
b/tests/server/api/v1/test_datasets.py @@ -20,6 +20,7 @@ import pytest from argilla._constants import API_KEY_HEADER_NAME from argilla.server.apis.v1.handlers.datasets import LIST_DATASET_RECORDS_LIMIT_DEFAULT +from argilla.server.enums import ResponseStatusFilter from argilla.server.models import ( Dataset, DatasetStatus, @@ -937,10 +938,21 @@ async def create_records_with_response( await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status) -@pytest.mark.parametrize("response_status_filter", ["missing", "discarded", "submitted", "draft"]) +@pytest.mark.parametrize( + "response_status_filters", + [ + [ResponseStatusFilter.missing], + [ResponseStatusFilter.draft], + [ResponseStatusFilter.submitted], + [ResponseStatusFilter.discarded], + [ResponseStatusFilter.missing, ResponseStatusFilter.draft], + [ResponseStatusFilter.submitted, ResponseStatusFilter.discarded], + [ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded], + ], +) @pytest.mark.asyncio async def test_list_current_user_dataset_records_with_response_status_filter( - client: TestClient, owner: "User", owner_auth_header: dict, response_status_filter: str + client: TestClient, owner: "User", owner_auth_header: dict, response_status_filters: List[ResponseStatusFilter] ): num_responses_per_status = 10 response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}} @@ -960,20 +972,14 @@ async def test_list_current_user_dataset_records_with_response_status_filter( other_dataset = await DatasetFactory.create() await RecordFactory.create_batch(size=2, dataset=other_dataset) - response = client.get( - f"/api/v1/me/datasets/{dataset.id}/records?response_status={response_status_filter}&include=responses", - headers=owner_auth_header, - ) + params = [("include", RecordInclude.responses.value)] + params.extend(("response_status", status_filter.value) for status_filter in response_status_filters) + response = 
client.get(f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params=params) assert response.status_code == 200 response_json = response.json() - assert len(response_json["items"]) == num_responses_per_status - - if response_status_filter == "missing": - assert all([len(record["responses"]) == 0 for record in response_json["items"]]) - else: - assert all([record["responses"][0]["status"] == response_status_filter for record in response_json["items"]]) + assert len(response_json["items"]) == num_responses_per_status * len(response_status_filters) @pytest.mark.asyncio @@ -3019,7 +3025,7 @@ async def test_search_dataset_records_with_response_status_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query(text=TextQuery(q="Hello", field="input")), - user_response_status_filter=UserResponseStatusFilter(user=owner, status=ResponseStatus.submitted), + user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatus.submitted]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ) diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index 8839d618b3..935e62ae87 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -3009,7 +3009,7 @@ async def test_search_dataset_records_with_response_status_filter( mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query(text=TextQuery(q="Hello", field="input")), - user_response_status_filter=UserResponseStatusFilter(user=owner, status=ResponseStatusFilter.submitted), + user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]), offset=0, limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, ) diff --git a/tests/unit/server/test_search_engine.py b/tests/unit/server/test_search_engine.py index cc30c3f241..04e0a8eb52 100644 --- a/tests/unit/server/test_search_engine.py +++ 
b/tests/unit/server/test_search_engine.py @@ -13,12 +13,12 @@ # limitations under the License. import random -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List import pytest import pytest_asyncio from argilla.server.enums import ResponseStatusFilter -from argilla.server.models import Dataset +from argilla.server.models import Dataset, User from argilla.server.search_engine import Query as SearchQuery from argilla.server.search_engine import ( SearchEngine, @@ -109,15 +109,18 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> ), await RecordFactory.create( dataset=dataset, - fields={"textId": "00002", "text": "Why was I charged for getting cash?", "label": "neutral"}, + fields={ + "textId": "00002", + "text": "I tried to make a payment with my card and it was declined.", + "label": "negative", + }, responses=[], ), await RecordFactory.create( dataset=dataset, fields={ "textId": "00003", - "text": "I deposited cash into my account a week ago and it is still not available," - " please tell me why? 
I need the cash back now.", + "text": "My credit card was declined when I tried to make a payment.", "label": "negative", }, responses=[], @@ -126,8 +129,8 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> dataset=dataset, fields={ "textId": "00004", - "text": "Why was I charged for getting cash?", - "label": "neutral", + "text": "I made a successful payment towards my mortgage loan earlier today.", + "label": "positive", }, responses=[], ), @@ -135,8 +138,8 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> dataset=dataset, fields={ "textId": "00005", - "text": "I tried to make a payment with my card and it was declined.", - "label": "negative", + "text": "Please confirm the receipt of my payment for the credit card bill due on the 15th.", + "label": "neutral", }, responses=[], ), @@ -144,7 +147,22 @@ async def test_banking_sentiment_dataset(elastic_search_engine: SearchEngine) -> dataset=dataset, fields={ "textId": "00006", - "text": "My credit card was declined when I tried to make a payment.", + "text": "Why was I charged for getting cash?", + "label": "neutral", + }, + responses=[], + ), + await RecordFactory.create( + dataset=dataset, + fields={"textId": "00007", "text": "Why was I charged for getting cash?", "label": "neutral"}, + responses=[], + ), + await RecordFactory.create( + dataset=dataset, + fields={ + "textId": "00008", + "text": "I deposited cash into my account a week ago and it is still not available," + " please tell me why? 
I need the cash back now.", "label": "negative", }, responses=[], @@ -309,23 +327,23 @@ async def test_create_index_with_existing_index( @pytest.mark.parametrize( ("query", "expected_items"), [ - ("card", 4), + ("card", 5), ("account", 1), - ("payment", 4), + ("payment", 6), ("cash", 3), ("negative", 4), ("00000", 1), - ("card payment", 4), + ("card payment", 5), ("nothing", 0), - (SearchQuery(text=TextQuery(q="card")), 4), + (SearchQuery(text=TextQuery(q="card")), 5), (SearchQuery(text=TextQuery(q="account")), 1), - (SearchQuery(text=TextQuery(q="payment")), 4), + (SearchQuery(text=TextQuery(q="payment")), 6), (SearchQuery(text=TextQuery(q="cash")), 3), - (SearchQuery(text=TextQuery(q="card payment")), 4), + (SearchQuery(text=TextQuery(q="card payment")), 5), (SearchQuery(text=TextQuery(q="nothing")), 0), (SearchQuery(text=TextQuery(q="negative", field="label")), 4), (SearchQuery(text=TextQuery(q="00000", field="textId")), 1), - (SearchQuery(text=TextQuery(q="card payment", field="text")), 4), + (SearchQuery(text=TextQuery(q="card payment", field="text")), 5), ], ) async def test_search_with_query_string( @@ -352,12 +370,16 @@ async def test_search_with_query_string( assert scores == sorted_scores @pytest.mark.parametrize( - "status", + "statuses, expected_items", [ - ResponseStatusFilter.discarded, - ResponseStatusFilter.submitted, - ResponseStatusFilter.draft, - ResponseStatusFilter.missing, + ([], 6), + ([ResponseStatusFilter.missing], 6), + ([ResponseStatusFilter.draft], 2), + ([ResponseStatusFilter.submitted], 2), + ([ResponseStatusFilter.discarded], 2), + ([ResponseStatusFilter.missing, ResponseStatusFilter.draft], 6), + ([ResponseStatusFilter.submitted, ResponseStatusFilter.discarded], 4), + ([ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded], 6), ], ) async def test_search_with_response_status_filter( @@ -365,28 +387,41 @@ async def test_search_with_response_status_filter( elastic_search_engine: SearchEngine, 
opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset, - status: ResponseStatusFilter, + statuses: List[ResponseStatusFilter], + expected_items: int, ): - index_name = f"rg.{test_banking_sentiment_dataset.id}" user = await UserFactory.create() - another_user = await UserFactory.create() - if status != ResponseStatusFilter.missing: - for record in test_banking_sentiment_dataset.records: - users_responses = { - f"{user.username}.status": status.value, - f"{another_user.username}.status": status.value, - } - opensearch.update(index_name, id=record.id, body={"doc": {"responses": users_responses}}) + await self._configure_record_responses(opensearch, test_banking_sentiment_dataset, statuses, user) - opensearch.indices.refresh(index=index_name) result = await elastic_search_engine.search( test_banking_sentiment_dataset, query=SearchQuery(text=TextQuery(q="payment")), - user_response_status_filter=UserResponseStatusFilter(user=user, status=status), + user_response_status_filter=UserResponseStatusFilter(user=user, statuses=statuses), + ) + assert len(result.items) == expected_items + assert result.total == expected_items + + async def test_search_with_response_status_filter_does_not_affect_the_result_scores( + self, elastic_search_engine: SearchEngine, opensearch: OpenSearch, test_banking_sentiment_dataset: Dataset + ): + user = await UserFactory.create() + + all_statuses = [ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded] + await self._configure_record_responses(opensearch, test_banking_sentiment_dataset, all_statuses, user) + + no_filter_results = await elastic_search_engine.search( + test_banking_sentiment_dataset, + query=SearchQuery(text=TextQuery(q="payment")), ) - assert len(result.items) == 4 - assert result.total == 4 + results = await elastic_search_engine.search( + test_banking_sentiment_dataset, + query=SearchQuery(text=TextQuery(q="payment")), + user_response_status_filter=UserResponseStatusFilter(user=user, 
statuses=all_statuses), + ) + assert len(no_filter_results.items) == len(results.items) + assert no_filter_results.total == results.total + assert [item.score for item in no_filter_results.items] == [item.score for item in results.items] @pytest.mark.parametrize(("offset", "limit"), [(0, 50), (10, 5), (0, 0), (90, 100)]) async def test_search_with_pagination( @@ -533,3 +568,22 @@ async def test_delete_record_response( results = opensearch.get(index=index_name, id=record.id) assert results["_source"]["responses"] == {} + + async def _configure_record_responses( + self, opensearch: OpenSearch, dataset: Dataset, response_status: List[ResponseStatusFilter], user: User + ): + index_name = f"rg.{dataset.id}" + another_user = await UserFactory.create() + + # Create two responses with the same status (one in each record) + for i, status in enumerate(response_status): + if status == ResponseStatusFilter.missing: + continue + offset = i * 2 + for record in dataset.records[offset : offset + 2]: + users_responses = { + f"{user.username}.status": status.value, + f"{another_user.username}.status": status.value, + } + opensearch.update(index_name, id=record.id, body={"doc": {"responses": users_responses}}) + opensearch.indices.refresh(index=index_name)