diff --git a/CHANGELOG.md b/CHANGELOG.md
index c69b31bfcb..141833445f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,6 +27,7 @@ These are the section headers that we use:
 - Added new status `draft` for the `Response` model.
 - Added `LabelSelectionQuestionSettings` class allowing to create label selection (single-choice) questions in the API ([#3005])
 - Added `MultiLabelSelectionQuestionSettings` class allowing to create multi-label selection (multi-choice) questions in the API ([#3010]).
+- Added `POST /api/v1/me/datasets/{dataset_id}/records/search` endpoint ([#3068]).
 
 ### Changed
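For context, a hedged sketch of how a client might call the new endpoint. The server URL, API key, and dataset id below are placeholders; the response shape follows the `SearchRecordsResult` schema added in this patch, and `API_KEY_HEADER_NAME` is imported just as the tests below import it:

```python
# Illustrative only: calling the new search endpoint from a client.
import requests

from argilla._constants import API_KEY_HEADER_NAME

API_URL = "http://localhost:6900"  # placeholder server URL
API_KEY = "argilla.apikey"  # placeholder API key
DATASET_ID = "00000000-0000-0000-0000-000000000000"  # placeholder dataset id

response = requests.post(
    f"{API_URL}/api/v1/me/datasets/{DATASET_ID}/records/search",
    headers={API_KEY_HEADER_NAME: API_KEY},
    params={"limit": 10, "include": "responses"},
    json={"query": {"text": {"q": "Hello", "field": "input"}}},
)
response.raise_for_status()
for item in response.json()["items"]:
    # each item pairs a record with the score assigned by the search engine
    print(item["query_score"], item["record"]["fields"])
```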
diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py
index 237b0f7dc7..635e48a48e 100644
--- a/src/argilla/server/apis/v1/handlers/datasets.py
+++ b/src/argilla/server/apis/v1/handlers/datasets.py
@@ -21,6 +21,7 @@
 from argilla.server.commons.telemetry import TelemetryClient, get_telemetry_client
 from argilla.server.contexts import accounts, datasets
 from argilla.server.database import get_db
+from argilla.server.enums import ResponseStatusFilter
 from argilla.server.models import User
 from argilla.server.policies import DatasetPolicyV1, authorize
 from argilla.server.schemas.v1.datasets import (
@@ -37,9 +38,15 @@
     RecordInclude,
     Records,
     RecordsCreate,
-    ResponseStatusFilter,
+    SearchRecord,
+    SearchRecordsQuery,
+    SearchRecordsResult,
+)
+from argilla.server.search_engine import (
+    SearchEngine,
+    UserResponseStatusFilter,
+    get_search_engine,
 )
-from argilla.server.search_engine import SearchEngine, get_search_engine
 from argilla.server.security import auth
 
 LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50
@@ -106,7 +113,7 @@ def list_current_user_dataset_records(
     *,
     db: Session = Depends(get_db),
     dataset_id: UUID,
-    include: Optional[List[RecordInclude]] = Query([]),
+    include: List[RecordInclude] = Query([]),
     response_status: Optional[ResponseStatusFilter] = Query(None),
     offset: int = 0,
     limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
@@ -285,6 +292,70 @@ async def create_dataset_records(
         raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))
 
 
+@router.post(
+    "/me/datasets/{dataset_id}/records/search",
+    status_code=status.HTTP_200_OK,
+    response_model=SearchRecordsResult,
+    response_model_exclude_unset=True,
+)
+async def search_dataset_records(
+    *,
+    db: Session = Depends(get_db),
+    search_engine: SearchEngine = Depends(get_search_engine),
+    telemetry_client: TelemetryClient = Depends(get_telemetry_client),
+    dataset_id: UUID,
+    query: SearchRecordsQuery,
+    include: List[RecordInclude] = Query([]),
+    response_status: Optional[ResponseStatusFilter] = Query(None),
+    limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
+    current_user: User = Security(auth.get_current_user),
+):
+    dataset = _get_dataset(db, dataset_id)
+    authorize(current_user, DatasetPolicyV1.search_records(dataset))
+
+    search_engine_query = query.query
+    if search_engine_query.text.field and not datasets.get_field_by_name_and_dataset_id(
+        db, search_engine_query.text.field, dataset_id
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=f"Field `{search_engine_query.text.field}` not found in dataset `{dataset_id}`.",
+        )
+
+    user_response_status_filter = None
+    if response_status:
+        user_response_status_filter = UserResponseStatusFilter(
+            user=current_user,
+            statuses=[response_status.value],
+        )
+
+    search_responses = await search_engine.search(
+        dataset=dataset,
+        query=search_engine_query,
+        user_response_status_filter=user_response_status_filter,
+        limit=limit,
+    )
+
+    record_id_score_map = {
+        response.record_id: {"query_score": response.score, "search_record": None}
+        for response in search_responses.items
+    }
+    records = datasets.get_records_by_ids(
+        db=db,
+        dataset_id=dataset_id,
+        record_ids=list(record_id_score_map.keys()),
+        include=include,
+        user_id=current_user.id,
+    )
+
+    for record in records:
+        record_id_score_map[record.id]["search_record"] = SearchRecord(
+            record=record, query_score=record_id_score_map[record.id]["query_score"]
+        )
+
+    return SearchRecordsResult(items=[record["search_record"] for record in record_id_score_map.values()])
+
+
 @router.put("/datasets/{dataset_id}/publish", response_model=Dataset)
 async def publish_dataset(
     *,
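The handler above searches in two steps: the engine returns ranked `(record_id, score)` pairs, and the matching records are then hydrated from the database. The `record_id_score_map` dict keeps the engine's ranking even if the database returns rows in a different order, because Python dicts preserve insertion order. A self-contained sketch of that pattern (illustrative names, not the handler's code):

```python
# Minimal sketch of the id -> score mapping pattern used in the handler.
from uuid import uuid4

hits = [(uuid4(), 14.2), (uuid4(), 12.2)]  # ranked (record_id, score) pairs from the engine

# build the map in engine order, then fill it from (possibly unordered) DB rows
score_map = {record_id: {"query_score": score, "search_record": None} for record_id, score in hits}

for record_id, _ in reversed(hits):  # simulate the DB returning rows in a different order
    score_map[record_id]["search_record"] = f"record {record_id}"

ranked = [entry["search_record"] for entry in score_map.values()]
assert ranked == [f"record {rid}" for rid, _ in hits]  # engine ranking preserved
```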
diff --git a/src/argilla/server/contexts/datasets.py b/src/argilla/server/contexts/datasets.py
index e2c42bcc8c..4a0d566a35 100644
--- a/src/argilla/server/contexts/datasets.py
+++ b/src/argilla/server/contexts/datasets.py
@@ -18,9 +18,10 @@
 
 from fastapi.encoders import jsonable_encoder
 from sqlalchemy import and_, func
-from sqlalchemy.orm import Session, contains_eager, joinedload
+from sqlalchemy.orm import Session, contains_eager, joinedload, noload
 
 from argilla.server.contexts import accounts
+from argilla.server.enums import ResponseStatusFilter
 from argilla.server.models import (
     Dataset,
     DatasetStatus,
@@ -37,7 +38,6 @@
     QuestionCreate,
     RecordInclude,
     RecordsCreate,
-    ResponseStatusFilter,
 )
 from argilla.server.schemas.v1.records import ResponseCreate
 from argilla.server.schemas.v1.responses import ResponseUpdate
@@ -189,6 +189,26 @@
 def get_record_by_id(db: Session, record_id: UUID):
     return db.get(Record, record_id)
 
 
+def get_records_by_ids(
+    db: Session,
+    dataset_id: UUID,
+    record_ids: List[UUID],
+    include: List[RecordInclude] = [],
+    user_id: Optional[UUID] = None,
+) -> List[Record]:
+    query = db.query(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(record_ids))
+    if RecordInclude.responses in include:
+        if user_id:
+            query = query.outerjoin(
+                Response, and_(Response.record_id == Record.id, Response.user_id == user_id)
+            ).options(contains_eager(Record.responses))
+        else:
+            query = query.options(joinedload(Record.responses))
+    else:
+        query = query.options(noload(Record.responses))
+    return query.all()
+
+
 def list_records_by_dataset_id(
     db: Session,
     dataset_id: UUID,
diff --git a/src/argilla/server/enums.py b/src/argilla/server/enums.py
new file mode 100644
index 0000000000..b73f660168
--- /dev/null
+++ b/src/argilla/server/enums.py
@@ -0,0 +1,22 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+
+class ResponseStatusFilter(str, Enum):
+    draft = "draft"
+    missing = "missing"
+    submitted = "submitted"
+    discarded = "discarded"
diff --git a/src/argilla/server/policies.py b/src/argilla/server/policies.py
index e1078bdb66..eba4b1b9eb 100644
--- a/src/argilla/server/policies.py
+++ b/src/argilla/server/policies.py
@@ -167,6 +167,19 @@ def create_question(cls, actor: User) -> bool:
     def create_records(cls, actor: User) -> bool:
         return actor.is_admin
 
+    @classmethod
+    def search_records(cls, dataset: Dataset) -> PolicyAction:
+        return lambda actor: (
+            actor.is_admin
+            or bool(
+                accounts.get_workspace_user_by_workspace_id_and_user_id(
+                    Session.object_session(actor),
+                    dataset.workspace_id,
+                    actor.id,
+                )
+            )
+        )
+
     @classmethod
     def publish(cls, actor: User) -> bool:
         return actor.is_admin
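Note that `search_records` differs from its sibling policies: instead of taking the actor and returning a bool, it closes over the dataset and returns a callable that `authorize` later applies to the current user (hence the `PolicyAction` return type rather than `bool`). A minimal, self-contained sketch of this closure-based pattern, with illustrative names that are not argilla's:

```python
# Sketch of a resource-capturing policy: the resource is bound now,
# the actor is supplied at authorization time.
from dataclasses import dataclass
from typing import Callable


@dataclass
class Actor:
    is_admin: bool
    workspaces: frozenset


@dataclass
class Resource:
    workspace_id: str


PolicyAction = Callable[[Actor], bool]


def search_records(resource: Resource) -> PolicyAction:
    return lambda actor: actor.is_admin or resource.workspace_id in actor.workspaces


def authorize(actor: Actor, policy: PolicyAction) -> None:
    if not policy(actor):
        raise PermissionError("Forbidden")


annotator = Actor(is_admin=False, workspaces=frozenset({"ws-1"}))
authorize(annotator, search_records(Resource(workspace_id="ws-1")))  # allowed: same workspace
```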
diff --git a/src/argilla/server/schemas/v1/datasets.py b/src/argilla/server/schemas/v1/datasets.py
index 57a42f738d..ab931dc3f3 100644
--- a/src/argilla/server/schemas/v1/datasets.py
+++ b/src/argilla/server/schemas/v1/datasets.py
@@ -20,6 +20,8 @@
 from pydantic import BaseModel, PositiveInt, conlist, constr, root_validator, validator
 from pydantic import Field as PydanticField
 
+from argilla.server.search_engine import Query
+
 try:
     from typing import Annotated, Literal
 except ImportError:
@@ -285,13 +287,6 @@ class RecordInclude(str, Enum):
     responses = "responses"
 
 
-class ResponseStatusFilter(str, Enum):
-    draft = "draft"
-    missing = "missing"
-    submitted = "submitted"
-    discarded = "discarded"
-
-
 class Record(BaseModel):
     id: UUID
     fields: Dict[str, Any]
@@ -347,3 +342,16 @@ def check_user_id_is_unique(cls, values):
 
 class RecordsCreate(BaseModel):
     items: conlist(item_type=RecordCreate, min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS)
+
+
+class SearchRecordsQuery(BaseModel):
+    query: Query
+
+
+class SearchRecord(BaseModel):
+    record: Record
+    query_score: Optional[float]
+
+
+class SearchRecordsResult(BaseModel):
+    items: List[SearchRecord]
diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py
index e58eae89d4..06adb8d03e 100644
--- a/src/argilla/server/search_engine.py
+++ b/src/argilla/server/search_engine.py
@@ -13,13 +13,14 @@
 # limitations under the License.
 
 import dataclasses
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, AsyncGenerator, Dict, Iterable, List, Optional, Union
 from uuid import UUID
 
 from opensearchpy import AsyncOpenSearch, helpers
 from pydantic import BaseModel
 from pydantic.utils import GetterDict
 
+from argilla.server.enums import ResponseStatusFilter
 from argilla.server.models import (
     Dataset,
     Field,
@@ -31,7 +32,6 @@
     ResponseStatus,
     User,
 )
-from argilla.server.schemas.v1.datasets import ResponseStatusFilter
 from argilla.server.settings import settings
 
@@ -205,7 +205,7 @@ async def search(
         items = []
         next_page_token = None
         for hit in response["hits"]["hits"]:
-            items.append(SearchResponseItem(record_id=hit["_id"], score=hit["_score"]))
+            items.append(SearchResponseItem(record_id=UUID(hit["_id"]), score=hit["_score"]))
             # See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
             next_page_token = hit.get("_sort")
 
@@ -278,7 +278,7 @@ def _index_name_for_dataset(dataset: Dataset):
     return f"rg.{dataset.id}"
 
 
-async def get_search_engine():
+async def get_search_engine() -> AsyncGenerator[SearchEngine, None]:
     config = dict(
         hosts=settings.elasticsearch,
         verify_certs=settings.elasticsearch_ssl_verify,
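The three new schema classes define the endpoint's request/response contract. A quick check of how a request body parses, using simplified pydantic stand-ins for `Query`/`TextQuery` (the real ones live in `argilla.server.search_engine`; field names match the diff):

```python
# Sketch of how the request body maps onto the new schema classes.
from typing import Optional

from pydantic import BaseModel


class TextQuery(BaseModel):
    q: str
    field: Optional[str] = None


class Query(BaseModel):
    text: TextQuery


class SearchRecordsQuery(BaseModel):
    query: Query


body = {"query": {"text": {"q": "Hello", "field": "input"}}}
parsed = SearchRecordsQuery.parse_obj(body)  # pydantic v1 API, as used by argilla at this point
assert parsed.query.text.q == "Hello"
assert parsed.query.text.field == "input"
```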
"yes"}}, + status=ResponseStatus.submitted, + user=user, + ), + ResponseFactory.create( + record=records[1], + values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + status=ResponseStatus.submitted, + user=user, + ), + ResponseFactory.create( + record=records[2], + values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + status=ResponseStatus.submitted, + user=user, + ), + ResponseFactory.create( + record=records[3], + values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, + status=ResponseStatus.submitted, + user=user, + ), + ] + # Add some responses from other users + ResponseFactory.create_batch(10, record=records[0], status=ResponseStatus.submitted) + return dataset, records, responses + + +def test_search_dataset_records( + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict +): + dataset, records, _ = create_dataset_for_search(user=admin) + + mock_search_engine.search.return_value = SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] + ) + + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + response = client.post( + f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json + ) + + mock_search_engine.search.assert_called_once_with( + dataset=dataset, + query=Query( + text=TextQuery( + q="Hello", + field="input", + ) + ), + user_response_status_filter=None, + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "record": { + "id": str(records[0].id), + "fields": { + "input": "Say Hello", + "output": "Hello", + }, + "external_id": records[0].external_id, + "responses": [], + "inserted_at": records[0].inserted_at.isoformat(), + "updated_at": records[0].updated_at.isoformat(), + }, + "query_score": 14.2, + }, + { + "record": { + "id": str(records[1].id), + "fields": { + "input": "Hello", + "output": "Hi", + }, + "external_id": records[1].external_id, + "responses": [], + "inserted_at": records[1].inserted_at.isoformat(), + "updated_at": records[1].updated_at.isoformat(), + }, + "query_score": 12.2, + }, + ] + } + + +def test_search_dataset_records_including_responses( + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict +): + dataset, records, responses = create_dataset_for_search(user=admin) + + mock_search_engine.search.return_value = SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] + ) + + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + response = client.post( + f"/api/v1/me/datasets/{dataset.id}/records/search", + headers=admin_auth_header, + json=query_json, + params={"include": RecordInclude.responses.value}, + ) + + mock_search_engine.search.assert_called_once_with( + dataset=dataset, + query=Query( + text=TextQuery( + q="Hello", + field="input", + ) + ), + user_response_status_filter=None, + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + assert response.json() == { + "items": [ + { + "record": { + "id": str(records[0].id), + "fields": { + "input": "Say Hello", + "output": "Hello", + }, + "external_id": records[0].external_id, + "responses": [ + { + "id": str(responses[0].id), + "values": { + "input_ok": {"value": "yes"}, + "output_ok": {"value": "yes"}, + }, + "status": 
"submitted", + "user_id": str(responses[0].user_id), + "inserted_at": responses[0].inserted_at.isoformat(), + "updated_at": responses[0].updated_at.isoformat(), + } + ], + "inserted_at": records[0].inserted_at.isoformat(), + "updated_at": records[0].updated_at.isoformat(), + }, + "query_score": 14.2, + }, + { + "record": { + "id": str(records[1].id), + "fields": { + "input": "Hello", + "output": "Hi", + }, + "external_id": records[1].external_id, + "responses": [ + { + "id": str(responses[1].id), + "values": { + "input_ok": {"value": "yes"}, + "output_ok": {"value": "yes"}, + }, + "status": "submitted", + "user_id": str(responses[1].user_id), + "inserted_at": responses[1].inserted_at.isoformat(), + "updated_at": responses[1].updated_at.isoformat(), + } + ], + "inserted_at": records[1].inserted_at.isoformat(), + "updated_at": records[1].updated_at.isoformat(), + }, + "query_score": 12.2, + }, + ] + } + + +def test_search_dataset_records_with_response_status_filter( + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict +): + dataset, _, _ = create_dataset_for_search(user=admin) + mock_search_engine.search.return_value = SearchResponses(items=[]) + + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + response = client.post( + f"/api/v1/me/datasets/{dataset.id}/records/search", + headers=admin_auth_header, + json=query_json, + params={"response_status": ResponseStatus.submitted.value}, + ) + + mock_search_engine.search.assert_called_once_with( + dataset=dataset, + query=Query(text=TextQuery(q="Hello", field="input")), + user_response_status_filter=UserResponseStatusFilter(user=admin, statuses=[ResponseStatus.submitted]), + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + + +def test_search_dataset_records_with_limit( + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict +): + dataset, _, _ = create_dataset_for_search(user=admin) + mock_search_engine.search.return_value = SearchResponses(items=[]) + + query_json = {"query": {"text": {"q": "Hello", "field": "input"}}} + response = client.post( + f"/api/v1/me/datasets/{dataset.id}/records/search", + headers=admin_auth_header, + json=query_json, + params={"limit": 10}, + ) + + mock_search_engine.search.assert_called_once_with( + dataset=dataset, + query=Query(text=TextQuery(q="Hello", field="input")), + user_response_status_filter=None, + limit=10, + ) + assert response.status_code == 200 + + +def test_search_dataset_records_as_annotator(client: TestClient, admin: User, mock_search_engine: SearchEngine): + dataset, records, _ = create_dataset_for_search(user=admin) + annotator = AnnotatorFactory.create(workspaces=[dataset.workspace]) + + mock_search_engine.search.return_value = SearchResponses( + items=[ + SearchResponseItem(record_id=records[0].id, score=14.2), + SearchResponseItem(record_id=records[1].id, score=12.2), + ] + ) + + query_json = {"query": {"text": {"q": "unit test", "field": "input"}}} + response = client.post( + f"/api/v1/me/datasets/{dataset.id}/records/search", + headers={API_KEY_HEADER_NAME: annotator.api_key}, + json=query_json, + ) + + mock_search_engine.search.assert_called_once_with( + dataset=dataset, + query=Query( + text=TextQuery( + q="unit test", + field="input", + ) + ), + user_response_status_filter=None, + limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT, + ) + assert response.status_code == 200 + + +def test_search_dataset_records_as_annotator_from_different_workspace(client: TestClient): + 
+
+
+def test_search_dataset_records_as_annotator_from_different_workspace(client: TestClient):
+    dataset, _, _ = create_dataset_for_search()
+    annotator = AnnotatorFactory.create(workspaces=[WorkspaceFactory.create()])
+
+    query_json = {"query": {"text": {"q": "unit test", "field": "input"}}}
+    response = client.post(
+        f"/api/v1/me/datasets/{dataset.id}/records/search",
+        headers={API_KEY_HEADER_NAME: annotator.api_key},
+        json=query_json,
+    )
+
+    assert response.status_code == 403
+
+
+def test_search_dataset_records_with_non_existent_field(client: TestClient, admin_auth_header: dict):
+    dataset, _, _ = create_dataset_for_search()
+
+    query_json = {"query": {"text": {"q": "unit test", "field": "i do not exist"}}}
+    response = client.post(
+        f"/api/v1/me/datasets/{dataset.id}/records/search", headers=admin_auth_header, json=query_json
+    )
+
+    assert response.status_code == 422
+
+
+def test_search_dataset_with_non_existent_dataset(client: TestClient, admin_auth_header: dict):
+    query_json = {"query": {"text": {"q": "unit test", "field": "input"}}}
+    response = client.post(f"/api/v1/me/datasets/{uuid4()}/records/search", headers=admin_auth_header, json=query_json)
+
+    assert response.status_code == 404
+
+
+@pytest.mark.asyncio
 def test_publish_dataset(
     client: TestClient,
     db: Session,
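The tests rely on a `mock_search_engine` fixture that is not part of this patch; they only need the engine to be a mock injected in place of the real `get_search_engine` dependency. A hedged sketch of how such a fixture could be wired with FastAPI dependency overrides (the app import path and fixture details are assumptions, not the suite's actual implementation):

```python
# Hypothetical fixture: replaces the search engine dependency with an async mock.
from unittest.mock import AsyncMock

import pytest

from argilla.server.search_engine import SearchEngine, get_search_engine


@pytest.fixture
def mock_search_engine():
    engine = AsyncMock(spec=SearchEngine)
    # hypothetical import path for the FastAPI app object; adjust to the codebase
    from argilla.server.server import app

    # every request during the test receives the mock instead of a live client
    app.dependency_overrides[get_search_engine] = lambda: engine
    yield engine
    app.dependency_overrides.pop(get_search_engine, None)
```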