From 8bb875672b88eb74e2800503f1e30943f0014cbf Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Thu, 1 Jun 2023 16:14:01 +0200 Subject: [PATCH 01/18] feat: add `search_records` policy method --- src/argilla/server/policies.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index e1078bdb66..45cd66ac8d 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -167,6 +167,10 @@ def create_question(cls, actor: User) -> bool: def create_records(cls, actor: User) -> bool: return actor.is_admin + @classmethod + def search_records(cls, actor: User) -> bool: + return True + @classmethod def publish(cls, actor: User) -> bool: return actor.is_admin From c5e37314cbedbe3f59396ed72bf9a087a2ef0001 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Thu, 1 Jun 2023 16:14:28 +0200 Subject: [PATCH 02/18] feat: add `get_search_engine` return type hint --- src/argilla/server/search_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py index 0d54451f5e..42bd5bed7c 100644 --- a/src/argilla/server/search_engine.py +++ b/src/argilla/server/search_engine.py @@ -13,7 +13,7 @@ # limitations under the License. import dataclasses -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Union from uuid import UUID from opensearchpy import AsyncOpenSearch, helpers @@ -249,7 +249,7 @@ def _index_name_for_dataset(dataset: Dataset): return f"rg.{dataset.id}" -async def get_search_engine(): +async def get_search_engine() -> Generator[SearchEngine, None, None]: config = dict( hosts=settings.elasticsearch, verify_certs=settings.elasticsearch_ssl_verify, From 673b41accffebcbc67c8358a48948e5335cf61c8 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Thu, 1 Jun 2023 16:14:48 +0200 Subject: [PATCH 03/18] feat: add search records endpoint --- .../server/apis/v1/handlers/datasets.py | 74 ++++++++++++++++++- src/argilla/server/contexts/datasets.py | 13 +++- src/argilla/server/schemas/v1/datasets.py | 9 +++ 3 files changed, 92 insertions(+), 4 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 05e41a836d..c26d1e96e8 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Security, status @@ -38,10 +38,20 @@ Records, RecordsCreate, ResponseStatusFilter, + SearchRecord, + SearchRecordsResult, +) +from argilla.server.search_engine import Query as SearchEngineQuery +from argilla.server.search_engine import ( + SearchEngine, + UserResponseStatusFilter, + get_search_engine, ) -from argilla.server.search_engine import SearchEngine, get_search_engine from argilla.server.security import auth +if TYPE_CHECKING: + from argilla.server.search_engine import SearchResponses + LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50 LIST_DATASET_RECORDS_LIMIT_LTE = 1000 @@ -59,6 +69,16 @@ def _get_dataset(db: Session, dataset_id: UUID): return dataset +def _merge_search_records(search_responses: "SearchResponses", records: List[Records]) -> List[SearchRecord]: + search_records = [] + for response in search_responses.items: + record = next((r for r in records if r.id == UUID(response.record_id)), None) + if record: + search_records.append(SearchRecord(record=record, query_score=response.score)) + records.remove(record) + return search_records + + @router.get("/me/datasets", response_model=Datasets) def list_current_user_datasets( *, @@ -106,7 +126,7 @@ def list_current_user_dataset_records( *, db: Session = Depends(get_db), dataset_id: UUID, - include: Optional[List[RecordInclude]] = Query([]), + include: List[RecordInclude] = Query([]), response_status: Optional[ResponseStatusFilter] = Query(None), offset: int = 0, limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), @@ -278,6 +298,54 @@ async def create_dataset_records( raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err)) +@router.post( + "/datasets/{dataset_id}/records/search", + status_code=status.HTTP_200_OK, + response_model=SearchRecordsResult, + response_model_exclude_unset=True, +) +async def search_dataset_records( + *, + db: Session = Depends(get_db), + search_engine: SearchEngine = Depends(get_search_engine), + telemetry_client: TelemetryClient = Depends(get_telemetry_client), + dataset_id: UUID, + query: SearchEngineQuery, + include: List[RecordInclude] = Query([]), + response_status: Optional[ResponseStatusFilter] = Query(None), + limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), + current_user: User = Security(auth.get_current_user), +): + authorize(current_user, DatasetPolicyV1.search_records) + + if query.text.field and not datasets.get_field_by_name_and_dataset_id(db, query.text.field, dataset_id): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"Field `{query.text.field}` not found in dataset `{dataset_id}`.", + ) + + dataset = _get_dataset(db, dataset_id) + + user_response_status_filter = None + if response_status: + user_response_status_filter = UserResponseStatusFilter( + user=current_user, + statuses=[response_status.value], + ) + + search_responses = await search_engine.search( + dataset=dataset, + query=query, + user_response_status_filter=user_response_status_filter, + limit=limit, + ) + + record_ids = [UUID(response.record_id) for response in search_responses.items] + records = datasets.get_records_by_ids(db, dataset_id, record_ids, include) + + return SearchRecordsResult(items=_merge_search_records(search_responses, records)) + + @router.put("/datasets/{dataset_id}/publish", 
response_model=Dataset) async def publish_dataset( *, diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 41918a95a9..5d58179a80 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -18,7 +18,7 @@ from fastapi.encoders import jsonable_encoder from sqlalchemy import and_, func -from sqlalchemy.orm import Session, contains_eager, joinedload +from sqlalchemy.orm import Session, contains_eager, joinedload, noload from argilla.server.contexts import accounts from argilla.server.models import ( @@ -189,6 +189,17 @@ def get_record_by_id(db: Session, record_id: UUID): return db.get(Record, record_id) +def get_records_by_ids( + db: Session, dataset_id: UUID, record_ids: List[UUID], include: List[RecordInclude] = [] +) -> List[Record]: + query = db.query(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(record_ids)) + if RecordInclude.responses in include: + query = query.options(contains_eager(Record.responses)) + else: + query = query.options(noload(Record.responses)) + return query.all() + + def list_records_by_dataset_id( db: Session, dataset_id: UUID, diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index bf27a251dd..07ad8f5c48 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ b/src/argilla/server/schemas/v1/datasets.py @@ -323,3 +323,12 @@ def check_user_id_is_unique(cls, values): class RecordsCreate(BaseModel): items: conlist(item_type=RecordCreate, min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS) + + +class SearchRecord(BaseModel): + record: Record + query_score: float + + +class SearchRecordsResult(BaseModel): + items: List[SearchRecord] From 4450a3bf7f15608fdce4310c89fa6726795faa31 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Thu, 1 Jun 2023 16:33:50 +0200 Subject: [PATCH 04/18] feat: update `search_records` policy --- src/argilla/server/apis/v1/handlers/datasets.py | 5 ++--- src/argilla/server/policies.py | 13 +++++++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index c26d1e96e8..b3d3dbfe01 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -316,7 +316,8 @@ async def search_dataset_records( limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), current_user: User = Security(auth.get_current_user), ): - authorize(current_user, DatasetPolicyV1.search_records) + dataset = _get_dataset(db, dataset_id) + authorize(current_user, DatasetPolicyV1.search_records(dataset)) if query.text.field and not datasets.get_field_by_name_and_dataset_id(db, query.text.field, dataset_id): raise HTTPException( @@ -324,8 +325,6 @@ async def search_dataset_records( detail=f"Field `{query.text.field}` not found in dataset `{dataset_id}`.", ) - dataset = _get_dataset(db, dataset_id) - user_response_status_filter = None if response_status: user_response_status_filter = UserResponseStatusFilter( diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py index 45cd66ac8d..eba4b1b9eb 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -168,8 +168,17 @@ def create_records(cls, actor: User) -> bool: return actor.is_admin @classmethod - def search_records(cls, actor: User) -> bool: - return True + def search_records(cls, dataset: Dataset) -> bool: + 
return lambda actor: ( + actor.is_admin + or bool( + accounts.get_workspace_user_by_workspace_id_and_user_id( + Session.object_session(actor), + dataset.workspace_id, + actor.id, + ) + ) + ) @classmethod def publish(cls, actor: User) -> bool: From 4da4b7ae0671da7b062a3664f6de9c0f929ce15a Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 2 Jun 2023 15:45:52 +0200 Subject: [PATCH 05/18] fix: `record_id` was `str` instead of `UUID` --- src/argilla/server/apis/v1/handlers/datasets.py | 4 ++-- src/argilla/server/search_engine.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index e8a7f2dfb2..4efa443f61 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -72,7 +72,7 @@ def _get_dataset(db: Session, dataset_id: UUID): def _merge_search_records(search_responses: "SearchResponses", records: List[Records]) -> List[SearchRecord]: search_records = [] for response in search_responses.items: - record = next((r for r in records if r.id == UUID(response.record_id)), None) + record = next((r for r in records if r.id == response.record_id), None) if record: search_records.append(SearchRecord(record=record, query_score=response.score)) records.remove(record) @@ -346,7 +346,7 @@ async def search_dataset_records( limit=limit, ) - record_ids = [UUID(response.record_id) for response in search_responses.items] + record_ids = [response.record_id for response in search_responses.items] records = datasets.get_records_by_ids(db, dataset_id, record_ids, include) return SearchRecordsResult(items=_merge_search_records(search_responses, records)) diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py index 42bd5bed7c..19c3241ba6 100644 --- a/src/argilla/server/search_engine.py +++ b/src/argilla/server/search_engine.py @@ -184,7 +184,7 @@ async def search( items = [] next_page_token = None for hit in response["hits"]["hits"]: - items.append(SearchResponseItem(record_id=hit["_id"], score=hit["_score"])) + items.append(SearchResponseItem(record_id=UUID(hit["_id"]), score=hit["_score"])) # See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html next_page_token = hit.get("_sort") From 25c9cfda0b55e2f7fa9c10966d6bcc6d18c4856f Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 2 Jun 2023 15:46:21 +0200 Subject: [PATCH 06/18] fix: include response returning `[]` as query result --- src/argilla/server/contexts/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index fa24ad433f..d9442d8b74 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -194,7 +194,7 @@ def get_records_by_ids( ) -> List[Record]: query = db.query(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(record_ids)) if RecordInclude.responses in include: - query = query.options(contains_eager(Record.responses)) + query = query.options(joinedload(Record.responses)) else: query = query.options(noload(Record.responses)) return query.all() From a78dc1348950e204fde07ad469caef99d80e524b Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 2 Jun 2023 16:22:53 +0200 Subject: [PATCH 07/18] feat: add unit tests for search endpoint --- tests/server/api/v1/test_datasets.py | 326 ++++++++++++++++++++++++++- 1 file changed, 324 insertions(+), 2 
deletions(-) diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 76408798c7..43aa6e6de2 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -13,12 +13,13 @@ # limitations under the License. from datetime import datetime -from typing import Optional, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type from unittest.mock import MagicMock from uuid import UUID, uuid4 import pytest from argilla._constants import API_KEY_HEADER_NAME +from argilla.server.apis.v1.handlers.datasets import LIST_DATASET_RECORDS_LIMIT_DEFAULT from argilla.server.models import ( Dataset, DatasetStatus, @@ -46,7 +47,14 @@ RECORDS_CREATE_MIN_ITEMS, RecordInclude, ) -from argilla.server.search_engine import SearchEngine +from argilla.server.search_engine import ( + Query, + SearchEngine, + SearchResponseItem, + SearchResponses, + TextQuery, + UserResponseStatusFilter, +) from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -65,6 +73,9 @@ WorkspaceFactory, ) +if TYPE_CHECKING: + from pytest_mock import MockerFixture + def test_list_current_user_datasets(client: TestClient, admin_auth_header: dict): dataset_a = DatasetFactory.create(name="dataset-a") @@ -2404,6 +2415,317 @@ def test_create_dataset_records_with_nonexistent_dataset_id(client: TestClient, assert db.query(Response).count() == 0 +def create_dataset_for_search() -> Tuple[Dataset, List[Record]]: + dataset = DatasetFactory.create(status=DatasetStatus.ready) + TextFieldFactory.create(name="input", dataset=dataset) + TextFieldFactory.create(name="output", dataset=dataset) + TextQuestionFactory.create(name="input_ok", dataset=dataset) + TextQuestionFactory.create(name="output_ok", dataset=dataset) + records = [ + RecordFactory.create(dataset=dataset, fields={"input": "Say Hello", "output": "Hello"}), + RecordFactory.create(dataset=dataset, fields={"input": "Hello", "output": "Hi"}), + RecordFactory.create(dataset=dataset, fields={"input": "Say Goodbye", "output": "Goodbye"}), + RecordFactory.create(dataset=dataset, fields={"input": "Say bye", "output": "Bye"}), + ] + responses = [ + ResponseFactory.create( + record=records[0], + values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + status=ResponseStatus.submitted, + ), + ResponseFactory.create( + record=records[1], + values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + status=ResponseStatus.submitted, + ), + ResponseFactory.create( + record=records[2], + values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + status=ResponseStatus.submitted, + ), + ResponseFactory.create( + record=records[3], + values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + status=ResponseStatus.submitted, + ), + ] + return dataset, records, responses + + +def test_search_dataset_records( + mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine, admin: User, admin_auth_header: dict +): + dataset, records, _ = create_dataset_for_search() + + search_mock = mocker.patch.object( + search_engine, + "search", + return_value=SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] + ), + ) + + query_json = {"text": {"q": "Hello", "field": "input"}} + response = client.post(f"/api/v1/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json) + + search_mock.assert_called_once_with( + dataset=dataset, + query=Query( + 
text=TextQuery( + q="Hello", + field="input", + ) + ), + user_response_status_filter=None, + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "record": { + "id": str(records[0].id), + "fields": { + "input": "Say Hello", + "output": "Hello", + }, + "external_id": records[0].external_id, + "responses": [], + "inserted_at": records[0].inserted_at.isoformat(), + "updated_at": records[0].updated_at.isoformat(), + }, + "query_score": 14.2, + }, + { + "record": { + "id": str(records[1].id), + "fields": { + "input": "Hello", + "output": "Hi", + }, + "external_id": records[1].external_id, + "responses": [], + "inserted_at": records[1].inserted_at.isoformat(), + "updated_at": records[1].updated_at.isoformat(), + }, + "query_score": 12.2, + }, + ] + } + + +def test_search_dataset_records_including_responses( + mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine, admin_auth_header: dict +): + dataset, records, responses = create_dataset_for_search() + + search_mock = mocker.patch.object( + search_engine, + "search", + return_value=SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] + ), + ) + + query_json = {"text": {"q": "Hello", "field": "input"}} + response = client.post( + f"/api/v1/datasets/{dataset.id}/records/search", + headers=admin_auth_header, + json=query_json, + params={"include": RecordInclude.responses.value}, + ) + + search_mock.assert_called_once_with( + dataset=dataset, + query=Query( + text=TextQuery( + q="Hello", + field="input", + ) + ), + user_response_status_filter=None, + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "record": { + "id": str(records[0].id), + "fields": { + "input": "Say Hello", + "output": "Hello", + }, + "external_id": records[0].external_id, + "responses": [ + { + "id": str(responses[0].id), + "values": { + "input_ok": {"value": "yes"}, + "output_ok": {"value": "yes"}, + }, + "status": "submitted", + "user_id": str(responses[0].user_id), + "inserted_at": responses[0].inserted_at.isoformat(), + "updated_at": responses[0].updated_at.isoformat(), + } + ], + "inserted_at": records[0].inserted_at.isoformat(), + "updated_at": records[0].updated_at.isoformat(), + }, + "query_score": 14.2, + }, + { + "record": { + "id": str(records[1].id), + "fields": { + "input": "Hello", + "output": "Hi", + }, + "external_id": records[1].external_id, + "responses": [ + { + "id": str(responses[1].id), + "values": { + "input_ok": {"value": "yes"}, + "output_ok": {"value": "yes"}, + }, + "status": "submitted", + "user_id": str(responses[1].user_id), + "inserted_at": responses[1].inserted_at.isoformat(), + "updated_at": responses[1].updated_at.isoformat(), + } + ], + "inserted_at": records[1].inserted_at.isoformat(), + "updated_at": records[1].updated_at.isoformat(), + }, + "query_score": 12.2, + }, + ] + } + + +def test_search_dataset_records_with_response_status_filter( + mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine, admin: User, admin_auth_header: dict +): + dataset, _, _ = create_dataset_for_search() + search_mock = mocker.patch.object(search_engine, "search", return_value=SearchResponses(items=[])) + + query_json = {"text": {"q": "Hello", "field": "input"}} + response = client.post( + f"/api/v1/datasets/{dataset.id}/records/search", + headers=admin_auth_header, + 
json=query_json, + params={"response_status": ResponseStatus.submitted.value}, + ) + + search_mock.assert_called_once_with( + dataset=dataset, + query=Query(text=TextQuery(q="Hello", field="input")), + user_response_status_filter=UserResponseStatusFilter(user=admin, statuses=[ResponseStatus.submitted]), + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + + +def test_search_dataset_records_with_limit( + mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine, admin_auth_header: dict +): + dataset, _, _ = create_dataset_for_search() + search_mock = mocker.patch.object(search_engine, "search", return_value=SearchResponses(items=[])) + + query_json = {"text": {"q": "Hello", "field": "input"}} + response = client.post( + f"/api/v1/datasets/{dataset.id}/records/search", + headers=admin_auth_header, + json=query_json, + params={"limit": 10}, + ) + + search_mock.assert_called_once_with( + dataset=dataset, + query=Query(text=TextQuery(q="Hello", field="input")), + user_response_status_filter=None, + limit=10, + ) + assert response.status_code == 200 + + +def test_search_dataset_records_as_annotator(mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine): + dataset, records, _ = create_dataset_for_search() + annotator = AnnotatorFactory.create(workspaces=[dataset.workspace]) + + search_mock = mocker.patch.object( + search_engine, + "search", + return_value=SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] + ), + ) + + query_json = {"text": {"q": "unit test", "field": "input"}} + response = client.post( + f"/api/v1/datasets/{dataset.id}/records/search", + headers={API_KEY_HEADER_NAME: annotator.api_key}, + json=query_json, + ) + + search_mock.assert_called_once_with( + dataset=dataset, + query=Query( + text=TextQuery( + q="unit test", + field="input", + ) + ), + user_response_status_filter=None, + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + + +def test_search_dataset_records_as_annotator_from_different_workspace( + mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine +): + dataset, _, _ = create_dataset_for_search() + annotator = AnnotatorFactory.create(workspaces=[WorkspaceFactory.create()]) + + query_json = {"text": {"q": "unit test", "field": "input"}} + response = client.post( + f"/api/v1/datasets/{dataset.id}/records/search", + headers={API_KEY_HEADER_NAME: annotator.api_key}, + json=query_json, + ) + + assert response.status_code == 403 + + +def test_search_dataset_records_with_non_existent_field(client: TestClient, admin_auth_header: dict): + dataset, _, _ = create_dataset_for_search() + + query_json = {"text": {"q": "unit test", "field": "i do not exist"}} + response = client.post(f"/api/v1/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json) + + assert response.status_code == 422 + + +def test_search_dataset_with_non_existent_dataset(client: TestClient, admin_auth_header: dict): + query_json = {"text": {"q": "unit test", "field": "input"}} + response = client.post(f"/api/v1/datasets/{uuid4()}/records/search", headers=admin_auth_header, json=query_json) + + assert response.status_code == 404 + + +@pytest.mark.asyncio def test_publish_dataset( mocker, client: TestClient, From 32f9fc6aea2413c51067d3cc4b0996311b00434e Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 2 Jun 2023 17:11:29 +0200 Subject: [PATCH 08/18] fix: add missing 
`/me` --- .../server/apis/v1/handlers/datasets.py | 2 +- tests/server/api/v1/test_datasets.py | 20 +++++++++++-------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 4efa443f61..e9c03355ab 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -306,7 +306,7 @@ async def create_dataset_records( @router.post( - "/datasets/{dataset_id}/records/search", + "/me/datasets/{dataset_id}/records/search", status_code=status.HTTP_200_OK, response_model=SearchRecordsResult, response_model_exclude_unset=True, diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 43aa6e6de2..2ca3a3714b 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -2469,7 +2469,9 @@ def test_search_dataset_records( ) query_json = {"text": {"q": "Hello", "field": "input"}} - response = client.post(f"/api/v1/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json) + response = client.post( + f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json + ) search_mock.assert_called_once_with( dataset=dataset, @@ -2535,7 +2537,7 @@ def test_search_dataset_records_including_responses( query_json = {"text": {"q": "Hello", "field": "input"}} response = client.post( - f"/api/v1/datasets/{dataset.id}/records/search", + f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json, params={"include": RecordInclude.responses.value}, @@ -2619,7 +2621,7 @@ def test_search_dataset_records_with_response_status_filter( query_json = {"text": {"q": "Hello", "field": "input"}} response = client.post( - f"/api/v1/datasets/{dataset.id}/records/search", + f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json, params={"response_status": ResponseStatus.submitted.value}, @@ -2642,7 +2644,7 @@ def test_search_dataset_records_with_limit( query_json = {"text": {"q": "Hello", "field": "input"}} response = client.post( - f"/api/v1/datasets/{dataset.id}/records/search", + f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json, params={"limit": 10}, @@ -2674,7 +2676,7 @@ def test_search_dataset_records_as_annotator(mocker: "MockerFixture", client: Te query_json = {"text": {"q": "unit test", "field": "input"}} response = client.post( - f"/api/v1/datasets/{dataset.id}/records/search", + f"/api/v1/me/datasets/{dataset.id}/records/search", headers={API_KEY_HEADER_NAME: annotator.api_key}, json=query_json, ) @@ -2701,7 +2703,7 @@ def test_search_dataset_records_as_annotator_from_different_workspace( query_json = {"text": {"q": "unit test", "field": "input"}} response = client.post( - f"/api/v1/datasets/{dataset.id}/records/search", + f"/api/v1/me/datasets/{dataset.id}/records/search", headers={API_KEY_HEADER_NAME: annotator.api_key}, json=query_json, ) @@ -2713,14 +2715,16 @@ def test_search_dataset_records_with_non_existent_field(client: TestClient, admi dataset, _, _ = create_dataset_for_search() query_json = {"text": {"q": "unit test", "field": "i do not exist"}} - response = client.post(f"/api/v1/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json) + response = client.post( + f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json + ) assert response.status_code 
== 422 def test_search_dataset_with_non_existent_dataset(client: TestClient, admin_auth_header: dict): query_json = {"text": {"q": "unit test", "field": "input"}} - response = client.post(f"/api/v1/datasets/{uuid4()}/records/search", headers=admin_auth_header, json=query_json) + response = client.post(f"/api/v1/me/datasets/{uuid4()}/records/search", headers=admin_auth_header, json=query_json) assert response.status_code == 404 From b63c14e0236f1a70e27c57415c78eb008ffaab80 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 2 Jun 2023 17:37:46 +0200 Subject: [PATCH 09/18] feat: remove `_merge_search_records` function --- .../server/apis/v1/handlers/datasets.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index e9c03355ab..6be9673ab1 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Security, status @@ -69,16 +69,6 @@ def _get_dataset(db: Session, dataset_id: UUID): return dataset -def _merge_search_records(search_responses: "SearchResponses", records: List[Records]) -> List[SearchRecord]: - search_records = [] - for response in search_responses.items: - record = next((r for r in records if r.id == response.record_id), None) - if record: - search_records.append(SearchRecord(record=record, query_score=response.score)) - records.remove(record) - return search_records - - @router.get("/me/datasets", response_model=Datasets) def list_current_user_datasets( *, @@ -346,10 +336,14 @@ async def search_dataset_records( limit=limit, ) - record_ids = [response.record_id for response in search_responses.items] - records = datasets.get_records_by_ids(db, dataset_id, record_ids, include) + record_id_score_map = {response.record_id: response.score for response in search_responses.items} + records = datasets.get_records_by_ids( + db=db, dataset_id=dataset_id, record_ids=list(record_id_score_map.keys()), include=include + ) - return SearchRecordsResult(items=_merge_search_records(search_responses, records)) + return SearchRecordsResult( + items=[SearchRecord(record=record, query_score=record_id_score_map[record.id]) for record in records] + ) @router.put("/datasets/{dataset_id}/publish", response_model=Dataset) From b90db2be551f1ce645b83a64806babf1ee85ed76 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 2 Jun 2023 17:41:12 +0200 Subject: [PATCH 10/18] feat: return results sorted based on score --- src/argilla/server/apis/v1/handlers/datasets.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 6be9673ab1..4ac38fc355 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -342,7 +342,11 @@ async def search_dataset_records( ) return SearchRecordsResult( - items=[SearchRecord(record=record, query_score=record_id_score_map[record.id]) for record in records] + items=sorted( + [SearchRecord(record=record, query_score=record_id_score_map[record.id]) for record in records], + key=lambda x: x.query_score, + 
reverse=True, + ) ) From 43b5be8fbe32adb565852359510e4a4f31f92ab2 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Fri, 2 Jun 2023 18:03:41 +0200 Subject: [PATCH 11/18] docs: add search endpoint --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c69b31bfcb..141833445f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ These are the section headers that we use: - Added new status `draft` for the `Response` model. - Added `LabelSelectionQuestionSettings` class allowing to create label selection (single-choice) questions in the API ([#3005]) - Added `MultiLabelSelectionQuestionSettings` class allowing to create multi-label selection (multi-choice) questions in the API ([#3010]). +- Added `POST /api/v1/me/datasets/{dataset_id}/records/search` endpoint ([#3068]). ### Changed From 7d1605fdc5b1f0d54cca9b2f0b554b9fdbf68fb6 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 5 Jun 2023 10:54:13 +0200 Subject: [PATCH 12/18] feat: remove `Dict` import --- src/argilla/server/apis/v1/handlers/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 4ac38fc355..5712ccca34 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, List, Optional from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Security, status From c291df0b99b987740e9f0ce930381425cbd01278 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 5 Jun 2023 11:13:32 +0200 Subject: [PATCH 13/18] feat: update search unit tests to use `mock_search_engine` --- tests/server/api/v1/test_datasets.py | 70 +++++++++++----------------- 1 file changed, 28 insertions(+), 42 deletions(-) diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 8820d09af6..01c56ca32c 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -2446,19 +2446,15 @@ def create_dataset_for_search() -> Tuple[Dataset, List[Record]]: def test_search_dataset_records( - mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine, admin: User, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): dataset, records, _ = create_dataset_for_search() - search_mock = mocker.patch.object( - search_engine, - "search", - return_value=SearchResponses( - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ] - ), + mock_search_engine.search.return_value = SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] ) query_json = {"text": {"q": "Hello", "field": "input"}} @@ -2466,7 +2462,7 @@ def test_search_dataset_records( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json ) - search_mock.assert_called_once_with( + mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query( text=TextQuery( @@ -2513,19 +2509,15 @@ def test_search_dataset_records( def test_search_dataset_records_including_responses( - mocker: "MockerFixture", client: 
TestClient, search_engine: SearchEngine, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin_auth_header: dict ): dataset, records, responses = create_dataset_for_search() - search_mock = mocker.patch.object( - search_engine, - "search", - return_value=SearchResponses( - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - SearchResponseItem(record_id=records[1].id, score=12.2), - ] - ), + mock_search_engine.search.return_value = SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] ) query_json = {"text": {"q": "Hello", "field": "input"}} @@ -2536,7 +2528,7 @@ def test_search_dataset_records_including_responses( params={"include": RecordInclude.responses.value}, ) - search_mock.assert_called_once_with( + mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query( text=TextQuery( @@ -2607,10 +2599,10 @@ def test_search_dataset_records_including_responses( def test_search_dataset_records_with_response_status_filter( - mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine, admin: User, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): dataset, _, _ = create_dataset_for_search() - search_mock = mocker.patch.object(search_engine, "search", return_value=SearchResponses(items=[])) + mock_search_engine.search.return_value = SearchResponses(items=[]) query_json = {"text": {"q": "Hello", "field": "input"}} response = client.post( @@ -2620,7 +2612,7 @@ def test_search_dataset_records_with_response_status_filter( params={"response_status": ResponseStatus.submitted.value}, ) - search_mock.assert_called_once_with( + mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query(text=TextQuery(q="Hello", field="input")), user_response_status_filter=UserResponseStatusFilter(user=admin, statuses=[ResponseStatus.submitted]), @@ -2630,10 +2622,10 @@ def test_search_dataset_records_with_response_status_filter( def test_search_dataset_records_with_limit( - mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin_auth_header: dict ): dataset, _, _ = create_dataset_for_search() - search_mock = mocker.patch.object(search_engine, "search", return_value=SearchResponses(items=[])) + mock_search_engine.search.return_value = SearchResponses(items=[]) query_json = {"text": {"q": "Hello", "field": "input"}} response = client.post( @@ -2643,7 +2635,7 @@ def test_search_dataset_records_with_limit( params={"limit": 10}, ) - search_mock.assert_called_once_with( + mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query(text=TextQuery(q="Hello", field="input")), user_response_status_filter=None, @@ -2652,19 +2644,15 @@ def test_search_dataset_records_with_limit( assert response.status_code == 200 -def test_search_dataset_records_as_annotator(mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine): +def test_search_dataset_records_as_annotator(client: TestClient, mock_search_engine: SearchEngine): dataset, records, _ = create_dataset_for_search() annotator = AnnotatorFactory.create(workspaces=[dataset.workspace]) - search_mock = mocker.patch.object( - search_engine, - "search", - return_value=SearchResponses( - items=[ - SearchResponseItem(record_id=records[0].id, score=14.2), - 
SearchResponseItem(record_id=records[1].id, score=12.2), - ] - ), + mock_search_engine.search.return_value = SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] ) query_json = {"text": {"q": "unit test", "field": "input"}} @@ -2674,7 +2662,7 @@ def test_search_dataset_records_as_annotator(mocker: "MockerFixture", client: Te json=query_json, ) - search_mock.assert_called_once_with( + mock_search_engine.search.assert_called_once_with( dataset=dataset, query=Query( text=TextQuery( @@ -2688,9 +2676,7 @@ def test_search_dataset_records_as_annotator(mocker: "MockerFixture", client: Te assert response.status_code == 200 -def test_search_dataset_records_as_annotator_from_different_workspace( - mocker: "MockerFixture", client: TestClient, search_engine: SearchEngine -): +def test_search_dataset_records_as_annotator_from_different_workspace(client: TestClient): dataset, _, _ = create_dataset_for_search() annotator = AnnotatorFactory.create(workspaces=[WorkspaceFactory.create()]) From 50ae3504f8331cda2ba7e45ef79389b346c086d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 5 Jun 2023 11:34:35 +0200 Subject: [PATCH 14/18] feat: allow `query_score` to be optional Co-authored-by: Francisco Aranda --- src/argilla/server/schemas/v1/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index 349fd396df..819bc41466 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ b/src/argilla/server/schemas/v1/datasets.py @@ -351,7 +351,7 @@ class RecordsCreate(BaseModel): class SearchRecord(BaseModel): record: Record - query_score: float + query_score: Optional[float] class SearchRecordsResult(BaseModel): From 86b27dfbc98d5b36cc906435d5afec60f760ecd8 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 5 Jun 2023 12:27:27 +0200 Subject: [PATCH 15/18] feat: add `user_id` to query --- src/argilla/server/apis/v1/handlers/datasets.py | 6 +++++- src/argilla/server/contexts/datasets.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 5712ccca34..48350f354d 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -338,7 +338,11 @@ async def search_dataset_records( record_id_score_map = {response.record_id: response.score for response in search_responses.items} records = datasets.get_records_by_ids( - db=db, dataset_id=dataset_id, record_ids=list(record_id_score_map.keys()), include=include + db=db, + dataset_id=dataset_id, + record_ids=list(record_id_score_map.keys()), + include=include, + user_id=current_user.id, ) return SearchRecordsResult( diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index ca89694bad..846ba1c431 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -190,11 +190,20 @@ def get_record_by_id(db: Session, record_id: UUID): def get_records_by_ids( - db: Session, dataset_id: UUID, record_ids: List[UUID], include: List[RecordInclude] = [] + db: Session, + dataset_id: UUID, + record_ids: List[UUID], + include: List[RecordInclude] = [], + user_id: Optional[UUID] = None, ) -> List[Record]: query = db.query(Record).filter(Record.dataset_id == dataset_id, 
Record.id.in_(record_ids)) if RecordInclude.responses in include: - query = query.options(joinedload(Record.responses)) + if user_id: + query = query.outerjoin( + Response, and_(Response.record_id == Record.id, Response.user_id == user_id) + ).options(contains_eager(Record.responses)) + else: + query = query.options(joinedload(Record.responses)) else: query = query.options(noload(Record.responses)) return query.all() From d3748bc5f27814d7bec4820e50f4ce6bc09558d1 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 5 Jun 2023 12:43:04 +0200 Subject: [PATCH 16/18] refactor: records sorting according to `SearchEngine` results --- src/argilla/server/apis/v1/handlers/datasets.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 48350f354d..58567116db 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -336,7 +336,10 @@ async def search_dataset_records( limit=limit, ) - record_id_score_map = {response.record_id: response.score for response in search_responses.items} + record_id_score_map = { + response.record_id: {"query_score": response.score, "search_record": None} + for response in search_responses.items + } records = datasets.get_records_by_ids( db=db, dataset_id=dataset_id, @@ -345,13 +348,12 @@ async def search_dataset_records( user_id=current_user.id, ) - return SearchRecordsResult( - items=sorted( - [SearchRecord(record=record, query_score=record_id_score_map[record.id]) for record in records], - key=lambda x: x.query_score, - reverse=True, + for record in records: + record_id_score_map[record.id]["search_record"] = SearchRecord( + record=record, query_score=record_id_score_map[record.id]["query_score"] ) - ) + + return SearchRecordsResult(items=[record["search_record"] for record in record_id_score_map.values()]) @router.put("/datasets/{dataset_id}/publish", response_model=Dataset) From 75717b3d17d86a76282c402aa15ef3e8789ebba0 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 5 Jun 2023 13:52:20 +0200 Subject: [PATCH 17/18] feat: add other users `Responses` --- tests/server/api/v1/test_datasets.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 01c56ca32c..7ef97781c7 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -2408,7 +2408,7 @@ def test_create_dataset_records_with_nonexistent_dataset_id(client: TestClient, assert db.query(Response).count() == 0 -def create_dataset_for_search() -> Tuple[Dataset, List[Record]]: +def create_dataset_for_search(user: Optional[User] = None) -> Tuple[Dataset, List[Record]]: dataset = DatasetFactory.create(status=DatasetStatus.ready) TextFieldFactory.create(name="input", dataset=dataset) TextFieldFactory.create(name="output", dataset=dataset) @@ -2425,30 +2425,36 @@ def create_dataset_for_search() -> Tuple[Dataset, List[Record]]: record=records[0], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, ), ResponseFactory.create( record=records[1], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, ), ResponseFactory.create( record=records[2], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, 
), ResponseFactory.create( record=records[3], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, ), ] + # Add some responses from other users + ResponseFactory.create_batch(10, record=records[0], status=ResponseStatus.submitted) return dataset, records, responses def test_search_dataset_records( client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, records, _ = create_dataset_for_search() + dataset, records, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses( items=[ @@ -2509,9 +2515,9 @@ def test_search_dataset_records( def test_search_dataset_records_including_responses( - client: TestClient, mock_search_engine: SearchEngine, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, records, responses = create_dataset_for_search() + dataset, records, responses = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses( items=[ @@ -2601,7 +2607,7 @@ def test_search_dataset_records_including_responses( def test_search_dataset_records_with_response_status_filter( client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, _, _ = create_dataset_for_search() + dataset, _, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses(items=[]) query_json = {"text": {"q": "Hello", "field": "input"}} @@ -2622,9 +2628,9 @@ def test_search_dataset_records_with_response_status_filter( def test_search_dataset_records_with_limit( - client: TestClient, mock_search_engine: SearchEngine, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, _, _ = create_dataset_for_search() + dataset, _, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses(items=[]) query_json = {"text": {"q": "Hello", "field": "input"}} @@ -2644,8 +2650,8 @@ def test_search_dataset_records_with_limit( assert response.status_code == 200 -def test_search_dataset_records_as_annotator(client: TestClient, mock_search_engine: SearchEngine): - dataset, records, _ = create_dataset_for_search() +def test_search_dataset_records_as_annotator(client: TestClient, admin: User, mock_search_engine: SearchEngine): + dataset, records, _ = create_dataset_for_search(user=admin) annotator = AnnotatorFactory.create(workspaces=[dataset.workspace]) mock_search_engine.search.return_value = SearchResponses( From fc5a2b7fca3f84324673473648e62cc209221652 Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Mon, 5 Jun 2023 15:53:36 +0200 Subject: [PATCH 18/18] feat: add `SearchRecordsQuery` schema --- .../server/apis/v1/handlers/datasets.py | 20 ++++++++--------- src/argilla/server/contexts/datasets.py | 2 +- src/argilla/server/enums.py | 22 +++++++++++++++++++ src/argilla/server/schemas/v1/datasets.py | 13 +++++------ src/argilla/server/search_engine.py | 2 +- tests/server/api/v1/test_datasets.py | 16 +++++++------- 6 files changed, 48 insertions(+), 27 deletions(-) create mode 100644 src/argilla/server/enums.py diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index 58567116db..635e48a48e 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -12,7 +12,7 @@ # See the 
License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, List, Optional +from typing import List, Optional from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Security, status @@ -21,6 +21,7 @@ from argilla.server.commons.telemetry import TelemetryClient, get_telemetry_client from argilla.server.contexts import accounts, datasets from argilla.server.database import get_db +from argilla.server.enums import ResponseStatusFilter from argilla.server.models import User from argilla.server.policies import DatasetPolicyV1, authorize from argilla.server.schemas.v1.datasets import ( @@ -37,11 +38,10 @@ RecordInclude, Records, RecordsCreate, - ResponseStatusFilter, SearchRecord, + SearchRecordsQuery, SearchRecordsResult, ) -from argilla.server.search_engine import Query as SearchEngineQuery from argilla.server.search_engine import ( SearchEngine, UserResponseStatusFilter, @@ -49,9 +49,6 @@ ) from argilla.server.security import auth -if TYPE_CHECKING: - from argilla.server.search_engine import SearchResponses - LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50 LIST_DATASET_RECORDS_LIMIT_LTE = 1000 @@ -307,7 +304,7 @@ async def search_dataset_records( search_engine: SearchEngine = Depends(get_search_engine), telemetry_client: TelemetryClient = Depends(get_telemetry_client), dataset_id: UUID, - query: SearchEngineQuery, + query: SearchRecordsQuery, include: List[RecordInclude] = Query([]), response_status: Optional[ResponseStatusFilter] = Query(None), limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE), @@ -316,10 +313,13 @@ async def search_dataset_records( dataset = _get_dataset(db, dataset_id) authorize(current_user, DatasetPolicyV1.search_records(dataset)) - if query.text.field and not datasets.get_field_by_name_and_dataset_id(db, query.text.field, dataset_id): + search_engine_query = query.query + if search_engine_query.text.field and not datasets.get_field_by_name_and_dataset_id( + db, search_engine_query.text.field, dataset_id + ): raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail=f"Field `{query.text.field}` not found in dataset `{dataset_id}`.", + detail=f"Field `{search_engine_query.text.field}` not found in dataset `{dataset_id}`.", ) user_response_status_filter = None @@ -331,7 +331,7 @@ async def search_dataset_records( search_responses = await search_engine.search( dataset=dataset, - query=query, + query=search_engine_query, user_response_status_filter=user_response_status_filter, limit=limit, ) diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py index 846ba1c431..4a0d566a35 100644 --- a/src/argilla/server/contexts/datasets.py +++ b/src/argilla/server/contexts/datasets.py @@ -21,6 +21,7 @@ from sqlalchemy.orm import Session, contains_eager, joinedload, noload from argilla.server.contexts import accounts +from argilla.server.enums import ResponseStatusFilter from argilla.server.models import ( Dataset, DatasetStatus, @@ -37,7 +38,6 @@ QuestionCreate, RecordInclude, RecordsCreate, - ResponseStatusFilter, ) from argilla.server.schemas.v1.records import ResponseCreate from argilla.server.schemas.v1.responses import ResponseUpdate diff --git a/src/argilla/server/enums.py b/src/argilla/server/enums.py new file mode 100644 index 0000000000..b73f660168 --- /dev/null +++ b/src/argilla/server/enums.py @@ -0,0 +1,22 @@ +# Copyright 2021-present, the Recognai S.L. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class ResponseStatusFilter(str, Enum): + draft = "draft" + missing = "missing" + submitted = "submitted" + discarded = "discarded" diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py index 819bc41466..ab931dc3f3 100644 --- a/src/argilla/server/schemas/v1/datasets.py +++ b/src/argilla/server/schemas/v1/datasets.py @@ -20,6 +20,8 @@ from pydantic import BaseModel, PositiveInt, conlist, constr, root_validator, validator from pydantic import Field as PydanticField +from argilla.server.search_engine import Query + try: from typing import Annotated, Literal except ImportError: @@ -285,13 +287,6 @@ class RecordInclude(str, Enum): responses = "responses" -class ResponseStatusFilter(str, Enum): - draft = "draft" - missing = "missing" - submitted = "submitted" - discarded = "discarded" - - class Record(BaseModel): id: UUID fields: Dict[str, Any] @@ -349,6 +344,10 @@ class RecordsCreate(BaseModel): items: conlist(item_type=RecordCreate, min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS) +class SearchRecordsQuery(BaseModel): + query: Query + + class SearchRecord(BaseModel): record: Record query_score: Optional[float] diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py index 825b2133f9..06adb8d03e 100644 --- a/src/argilla/server/search_engine.py +++ b/src/argilla/server/search_engine.py @@ -20,6 +20,7 @@ from pydantic import BaseModel from pydantic.utils import GetterDict +from argilla.server.enums import ResponseStatusFilter from argilla.server.models import ( Dataset, Field, @@ -31,7 +32,6 @@ ResponseStatus, User, ) -from argilla.server.schemas.v1.datasets import ResponseStatusFilter from argilla.server.settings import settings diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 7ef97781c7..e48052b760 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -2463,7 +2463,7 @@ def test_search_dataset_records( ] ) - query_json = {"text": {"q": "Hello", "field": "input"}} + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} response = client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json ) @@ -2526,7 +2526,7 @@ def test_search_dataset_records_including_responses( ] ) - query_json = {"text": {"q": "Hello", "field": "input"}} + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} response = client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, @@ -2610,7 +2610,7 @@ def test_search_dataset_records_with_response_status_filter( dataset, _, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses(items=[]) - query_json = {"text": {"q": "Hello", "field": "input"}} + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} response = client.post( 
f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, @@ -2633,7 +2633,7 @@ def test_search_dataset_records_with_limit( dataset, _, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses(items=[]) - query_json = {"text": {"q": "Hello", "field": "input"}} + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} response = client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, @@ -2661,7 +2661,7 @@ def test_search_dataset_records_as_annotator(client: TestClient, admin: User, mo ] ) - query_json = {"text": {"q": "unit test", "field": "input"}} + query_json = {"query": {"text": {"q": "unit test", "field": "input"}}} response = client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers={API_KEY_HEADER_NAME: annotator.api_key}, @@ -2686,7 +2686,7 @@ def test_search_dataset_records_as_annotator_from_different_workspace(client: Te dataset, _, _ = create_dataset_for_search() annotator = AnnotatorFactory.create(workspaces=[WorkspaceFactory.create()]) - query_json = {"text": {"q": "unit test", "field": "input"}} + query_json = {"query": {"text": {"q": "unit test", "field": "input"}}} response = client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers={API_KEY_HEADER_NAME: annotator.api_key}, @@ -2699,7 +2699,7 @@ def test_search_dataset_records_as_annotator_from_different_workspace(client: Te def test_search_dataset_records_with_non_existent_field(client: TestClient, admin_auth_header: dict): dataset, _, _ = create_dataset_for_search() - query_json = {"text": {"q": "unit test", "field": "i do not exist"}} + query_json = {"query": {"text": {"q": "unit test", "field": "i do not exist"}}} response = client.post( f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json ) @@ -2708,7 +2708,7 @@ def test_search_dataset_records_with_non_existent_field(client: TestClient, admi def test_search_dataset_with_non_existent_dataset(client: TestClient, admin_auth_header: dict): - query_json = {"text": {"q": "unit test", "field": "input"}} + query_json = {"query": {"text": {"q": "unit test", "field": "input"}}} response = client.post(f"/api/v1/me/datasets/{uuid4()}/records/search", headers=admin_auth_header, json=query_json) assert response.status_code == 404