diff --git a/src/argilla/server/search_engine.py b/src/argilla/server/search_engine.py
index ff1f795303..aef96570c5 100644
--- a/src/argilla/server/search_engine.py
+++ b/src/argilla/server/search_engine.py
@@ -13,7 +13,8 @@
 # limitations under the License.

 import dataclasses
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
+from uuid import UUID

 from opensearchpy import AsyncOpenSearch, helpers
 from pydantic import BaseModel
@@ -27,7 +28,9 @@
     QuestionType,
     Record,
     ResponseStatus,
+    User,
 )
+from argilla.server.schemas.v1.datasets import ResponseStatusFilter
 from argilla.server.settings import settings


@@ -50,6 +53,7 @@ class UserResponse(BaseModel):


 class SearchDocument(BaseModel):
+    id: UUID
     fields: Dict[str, Any]

     responses: Optional[Dict[str, UserResponse]]
@@ -59,6 +63,34 @@ class Config:
         getter_dict = SearchDocumentGetter


+@dataclasses.dataclass
+class TextQuery:
+    q: str
+    field: Optional[str] = None
+
+
+@dataclasses.dataclass
+class Query:
+    text: TextQuery
+
+
+@dataclasses.dataclass
+class UserResponseStatusFilter:
+    user: User
+    statuses: List[ResponseStatusFilter]
+
+
+@dataclasses.dataclass
+class SearchResponseItem:
+    record_id: UUID
+    score: Optional[float]
+
+
+@dataclasses.dataclass
+class SearchResponses:
+    items: List[SearchResponseItem]
+
+
 @dataclasses.dataclass
 class SearchEngine:
     config: Dict[str, Any]
@@ -68,6 +100,7 @@ def __post_init__(self):

     async def create_index(self, dataset: Dataset):
         fields = {
+            "id": {"type": "keyword"},
             "responses": {"dynamic": True, "type": "object"},
         }

@@ -75,7 +108,7 @@ async def create_index(self, dataset: Dataset):
             fields[f"fields.{field.name}"] = self._es_mapping_for_field(field)

         # See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html
-        dynamic_templates = [
+        dynamic_templates: List[dict] = [
             {
                 f"{question.name}_responses": {
                     "path_match": f"responses.*.values.{question.name}",
@@ -96,27 +129,6 @@ async def create_index(self, dataset: Dataset):
         index_name = self._index_name_for_dataset(dataset)
         await self.client.indices.create(index=index_name, body=dict(mappings=mappings))

-    def _field_mapping_for_question(self, question: Question):
-        settings = question.parsed_settings
-
-        if settings.type == QuestionType.rating:
-            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/number.html
-            return {"type": "integer"}
-        elif settings.type in [QuestionType.text, QuestionType.label_selection, QuestionType.multi_label_selection]:
-            # TODO: Review mapping for label selection. Could make sense to use `keyword` mapping instead. See https://www.elastic.co/guide/en/elasticsearch/reference/current/keyword.html
-            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/text.html
-            return {"type": "text", "index": False}
-        else:
-            raise ValueError(f"ElasticSearch mappings for Question of type {settings.type} cannot be generated")
-
-    def _es_mapping_for_field(self, field: Field):
-        field_type = field.settings["type"]
-
-        if field_type == FieldType.text:
-            return {"type": "text"}
-        else:
-            raise ValueError(f"ElasticSearch mappings for Field of type {field_type} cannot be generated")
-
     async def add_records(self, dataset: Dataset, records: Iterable[Record]):
         index_name = self._index_name_for_dataset(dataset)

@@ -139,6 +151,82 @@ async def add_records(self, dataset: Dataset, records: Iterable[Record]):
         if errors:
             raise RuntimeError(errors)

+    async def search(
+        self,
+        dataset: Dataset,
+        query: Union[Query, str],
+        user_response_status_filter: Optional[UserResponseStatusFilter] = None,
+        limit: int = 100,
+    ) -> SearchResponses:
+        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html
+
+        if isinstance(query, str):
+            query = Query(text=TextQuery(q=query))
+
+        text_query = self._text_query_builder(dataset, text=query.text)
+
+        bool_query = {"must": [text_query]}
+        if user_response_status_filter:
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
+            user_response_status_field = f"responses.{user_response_status_filter.user.username}.status"
+            bool_query["filter"] = [{"terms": {user_response_status_field: user_response_status_filter.statuses}}]
+
+        body = {
+            "_source": False,
+            "query": {"bool": bool_query},
+            "sort": ["_score", {"id": "asc"}],
+        }
+        # TODO: Work on search pagination after endpoint integration
+        next_page_token = None
+        if next_page_token:
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
+            body["search_after"] = next_page_token
+
+        response = await self.client.search(index=self._index_name_for_dataset(dataset), size=limit, body=body)
+
+        items = []
+        next_page_token = None
+        for hit in response["hits"]["hits"]:
+            items.append(SearchResponseItem(record_id=hit["_id"], score=hit["_score"]))
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
+            next_page_token = hit.get("sort")
+
+        return SearchResponses(items=items)
+
+    @staticmethod
+    def _text_query_builder(dataset: Dataset, text: TextQuery) -> dict:
+        if not text.field:
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-multi-match-query.html
+            field_names = [
+                f"fields.{field.name}" for field in dataset.fields if field.settings.get("type") == FieldType.text
+            ]
+            return {"multi_match": {"query": text.q, "fields": field_names, "operator": "and"}}
+        else:
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html
+            return {"match": {f"fields.{text.field}": {"query": text.q, "operator": "and"}}}
+
+    def _field_mapping_for_question(self, question: Question):
+        settings = question.parsed_settings
+
+        if settings.type == QuestionType.rating:
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/number.html
+            return {"type": "integer"}
+        elif settings.type in [QuestionType.text, QuestionType.label_selection, QuestionType.multi_label_selection]:
+            # TODO: Review mapping for label selection. Could make sense to use `keyword` mapping instead.
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/keyword.html
+            # See https://www.elastic.co/guide/en/elasticsearch/reference/current/text.html
+            return {"type": "text", "index": False}
+        else:
+            raise ValueError(f"ElasticSearch mappings for Question of type {settings.type} cannot be generated")
+
+    def _es_mapping_for_field(self, field: Field):
+        field_type = field.settings["type"]
+
+        if field_type == FieldType.text:
+            return {"type": "text"}
+        else:
+            raise ValueError(f"ElasticSearch mappings for Field of type {field_type} cannot be generated")
+
     @staticmethod
     def _index_name_for_dataset(dataset):
         return f"rg.{dataset.id}"
diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py
index fc2a8380d4..9a9022e6d5 100644
--- a/tests/server/api/v1/test_datasets.py
+++ b/tests/server/api/v1/test_datasets.py
@@ -1791,21 +1791,30 @@ async def test_create_dataset_records(
     index_name = f"rg.{dataset.id}"
     opensearch.indices.refresh(index=index_name)

-    assert [hit["_source"] for hit in opensearch.search(index=index_name)["hits"]["hits"]] == [
+    es_docs = [hit["_source"] for hit in opensearch.search(index=index_name)["hits"]["hits"]]
+    assert es_docs == [
         {
+            "id": str(db.get(Record, UUID(es_docs[0]["id"])).id),
             "fields": {"input": "Say Hello", "output": "Hello"},
             "responses": {"admin": {"values": {"input_ok": "yes", "output_ok": "yes"}, "status": "submitted"}},
         },
-        {"fields": {"input": "Say Hello", "output": "Hi"}, "responses": {}},
         {
+            "id": str(db.get(Record, UUID(es_docs[1]["id"])).id),
+            "fields": {"input": "Say Hello", "output": "Hi"},
+            "responses": {},
+        },
+        {
+            "id": str(db.get(Record, UUID(es_docs[2]["id"])).id),
             "fields": {"input": "Say Pello", "output": "Hello World"},
             "responses": {"admin": {"values": {"input_ok": "no", "output_ok": "no"}, "status": "submitted"}},
         },
         {
+            "id": str(db.get(Record, UUID(es_docs[3]["id"])).id),
             "fields": {"input": "Say Hello", "output": "Good Morning"},
             "responses": {"admin": {"values": {"input_ok": "yes", "output_ok": "no"}, "status": "discarded"}},
         },
         {
+            "id": str(db.get(Record, UUID(es_docs[4]["id"])).id),
             "fields": {"input": "Say Hello", "output": "Say Hello"},
             "responses": {"admin": {"values": None, "status": "discarded"}},
         },
diff --git a/tests/server/test_search_engine.py b/tests/server/test_search_engine.py
index b3d8d38820..30ecb18a9e 100644
--- a/tests/server/test_search_engine.py
+++ b/tests/server/test_search_engine.py
@@ -15,18 +15,68 @@
 import random

 import pytest
-from argilla.server.search_engine import SearchEngine
+import pytest_asyncio
+from argilla.server.models import Dataset
+from argilla.server.search_engine import Query as SearchQuery
+from argilla.server.search_engine import SearchEngine, TextQuery
 from opensearchpy import OpenSearch, RequestError
 from sqlalchemy.orm import Session

 from tests.factories import (
     DatasetFactory,
     RatingQuestionFactory,
+    RecordFactory,
     TextFieldFactory,
     TextQuestionFactory,
 )


+@pytest_asyncio.fixture()
+async def test_banking_sentiment_dataset(search_engine: SearchEngine):
+    text_question = TextQuestionFactory()
+    rating_question = RatingQuestionFactory()
+
+    dataset = DatasetFactory.create(
+        fields=[TextFieldFactory(name="textId"), TextFieldFactory(name="text"), TextFieldFactory(name="label")],
+        questions=[text_question, rating_question],
+    )
+
+    await search_engine.create_index(dataset)
+
+    await search_engine.add_records(
+        dataset,
+        records=[
+            RecordFactory(
+                dataset=dataset,
+                fields={"textId": "00000", "text": "My card payment had the wrong exchange rate", "label": "negative"},
+            ),
+            RecordFactory(
+                dataset=dataset,
+                fields={
+                    "textId": "00001",
+                    "text": "I believe that a card payment I made was cancelled.",
+                    "label": "neutral",
+                },
+            ),
+            RecordFactory(
+                dataset=dataset,
+                fields={"textId": "00002", "text": "Why was I charged for getting cash?", "label": "neutral"},
+            ),
+            RecordFactory(
+                dataset=dataset,
+                fields={
+                    "textId": "00003",
+                    "text": "I deposited cash into my account a week ago and it is still not available,"
+                    " please tell me why? I need the cash back now.",
+                    "label": "negative",
+                },
+            ),
+        ],
+    )
+
+    return dataset
+
+
 @pytest.mark.asyncio
 class TestSuiteElasticSearchEngine:
     async def test_create_index_for_dataset(self, search_engine: SearchEngine, opensearch: OpenSearch):
@@ -41,6 +91,7 @@ async def test_create_index_for_dataset(self, search_engine: SearchEngine, opens
             "dynamic": "strict",
             "dynamic_templates": [],
             "properties": {
+                "id": {"type": "keyword"},
                 "responses": {"dynamic": "true", "type": "object"},
             },
         }
@@ -64,6 +115,7 @@ async def test_create_index_for_dataset_with_fields(
             "dynamic": "strict",
             "dynamic_templates": [],
             "properties": {
+                "id": {"type": "keyword"},
                 "fields": {"properties": {field.name: {"type": "text"} for field in dataset.fields}},
                 "responses": {"type": "object", "dynamic": "true"},
             },
@@ -95,6 +147,7 @@ async def test_create_index_for_dataset_with_questions(
         assert index["mappings"] == {
             "dynamic": "strict",
             "properties": {
+                "id": {"type": "keyword"},
                 "responses": {"dynamic": "true", "type": "object"},
             },
             "dynamic_templates": [
@@ -136,3 +189,47 @@ async def test_create_index_with_existing_index(

         with pytest.raises(RequestError, match="resource_already_exists_exception"):
             await search_engine.create_index(dataset)
+
+    @pytest.mark.parametrize(
+        ("query", "expected_items"),
+        [
+            ("card", 2),
+            ("account", 1),
+            ("payment", 2),
+            ("cash", 2),
+            ("negative", 2),
+            ("00000", 1),
+            ("card payment", 2),
+            ("nothing", 0),
+            (SearchQuery(text=TextQuery(q="card")), 2),
+            (SearchQuery(text=TextQuery(q="account")), 1),
+            (SearchQuery(text=TextQuery(q="payment")), 2),
+            (SearchQuery(text=TextQuery(q="cash")), 2),
+            (SearchQuery(text=TextQuery(q="card payment")), 2),
+            (SearchQuery(text=TextQuery(q="nothing")), 0),
+            (SearchQuery(text=TextQuery(q="negative", field="label")), 2),
+            (SearchQuery(text=TextQuery(q="00000", field="textId")), 1),
+            (SearchQuery(text=TextQuery(q="card payment", field="text")), 2),
+        ],
+    )
+    async def test_search_with_query_string(
+        self,
+        search_engine: SearchEngine,
+        opensearch: OpenSearch,
+        db: Session,
+        test_banking_sentiment_dataset: Dataset,
+        query: str,
+        expected_items: int,
+    ):
+        opensearch.indices.refresh(index=f"rg.{test_banking_sentiment_dataset.id}")
+
+        result = await search_engine.search(test_banking_sentiment_dataset, query=query)
+        assert len(result.items) == expected_items
+
+        scores = [item.score for item in result.items]
+        assert all(score > 0 for score in scores)
+
+        sorted_scores = scores.copy()
+        sorted_scores.sort(reverse=True)
+
+        assert scores == sorted_scores
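
A minimal usage sketch of the new `SearchEngine.search` API, for reviewers who want to try it out. Assumptions: `dataset` and `annotator` are existing `Dataset` and `User` rows whose index has already been populated via `create_index` and `add_records`; the `config` keys depend on how `SearchEngine.__post_init__` (not shown in this diff) builds the OpenSearch client, so the `hosts` value is illustrative only; and `ResponseStatusFilter.submitted` is assumed to be a valid member of that enum.

```python
from argilla.server.schemas.v1.datasets import ResponseStatusFilter
from argilla.server.search_engine import Query, SearchEngine, TextQuery, UserResponseStatusFilter


async def run_example_searches(dataset, annotator):
    # Hypothetical client config: the real keys depend on how __post_init__
    # wires the AsyncOpenSearch client, which this diff does not show.
    engine = SearchEngine(config={"hosts": ["http://localhost:9200"]})

    # A plain string is expanded to Query(text=TextQuery(q=...)) and searched
    # across every text field of the dataset with a multi_match query.
    by_string = await engine.search(dataset, query="card payment")

    # A structured query scoped to one field, restricted to records that
    # `annotator` has already submitted responses for.
    by_field = await engine.search(
        dataset,
        query=Query(text=TextQuery(q="negative", field="label")),
        user_response_status_filter=UserResponseStatusFilter(
            user=annotator,  # assumed: an argilla.server.models.User row
            statuses=[ResponseStatusFilter.submitted],  # assumed enum member
        ),
        limit=10,
    )

    # Both calls return SearchResponses; each item carries a record id and score.
    return [item.record_id for item in by_string.items], by_field
```

The string form is a convenience wrapper: `search` converts it to `Query(text=TextQuery(q=...))` and runs a `multi_match` over every text field, while a field-scoped `TextQuery` becomes a single-field `match` query.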