Enable search data from SearchEngine component #3037

Merged
134 changes: 111 additions & 23 deletions src/argilla/server/search_engine.py
@@ -13,7 +13,8 @@
# limitations under the License.

import dataclasses
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional, Union
from uuid import UUID

from opensearchpy import AsyncOpenSearch, helpers
from pydantic import BaseModel
@@ -27,7 +28,9 @@
QuestionType,
Record,
ResponseStatus,
User,
)
from argilla.server.schemas.v1.datasets import ResponseStatusFilter
from argilla.server.settings import settings


@@ -50,6 +53,7 @@ class UserResponse(BaseModel):


class SearchDocument(BaseModel):
id: UUID
fields: Dict[str, Any]

responses: Optional[Dict[str, UserResponse]]
@@ -59,6 +63,34 @@ class Config:
getter_dict = SearchDocumentGetter


@dataclasses.dataclass
class TextQuery:
q: str
field: Optional[str] = None


@dataclasses.dataclass
class Query:
text: TextQuery


@dataclasses.dataclass
class UserResponseStatusFilter:
user: User
statuses: List[ResponseStatusFilter]


@dataclasses.dataclass
class SearchResponseItem:
record_id: UUID
score: Optional[float]


@dataclasses.dataclass
class SearchResponses:
items: List[SearchResponseItem]


@dataclasses.dataclass
class SearchEngine:
config: Dict[str, Any]
@@ -68,14 +100,15 @@ def __post_init__(self):

async def create_index(self, dataset: Dataset):
fields = {
"id": {"type": "keyword"},
"responses": {"dynamic": True, "type": "object"},
}

for field in dataset.fields:
fields[f"fields.{field.name}"] = self._es_mapping_for_field(field)

# See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html
dynamic_templates = [
dynamic_templates: List[dict] = [
{
f"{question.name}_responses": {
"path_match": f"responses.*.values.{question.name}",
@@ -96,27 +129,6 @@ async def create_index(self, dataset: Dataset):
index_name = self._index_name_for_dataset(dataset)
await self.client.indices.create(index=index_name, body=dict(mappings=mappings))

def _field_mapping_for_question(self, question: Question):
settings = question.parsed_settings

if settings.type == QuestionType.rating:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/number.html
return {"type": "integer"}
elif settings.type in [QuestionType.text, QuestionType.label_selection, QuestionType.multi_label_selection]:
# TODO: Review mapping for label selection. Could make sense to use `keyword` mapping instead. See https://www.elastic.co/guide/en/elasticsearch/reference/current/keyword.html
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/text.html
return {"type": "text", "index": False}
else:
raise ValueError(f"ElasticSearch mappings for Question of type {settings.type} cannot be generated")

def _es_mapping_for_field(self, field: Field):
field_type = field.settings["type"]

if field_type == FieldType.text:
return {"type": "text"}
else:
raise ValueError(f"ElasticSearch mappings for Field of type {field_type} cannot be generated")

async def add_records(self, dataset: Dataset, records: Iterable[Record]):
index_name = self._index_name_for_dataset(dataset)

@@ -139,6 +151,82 @@ async def add_records(self, dataset: Dataset, records: Iterable[Record]):
if errors:
raise RuntimeError(errors)

async def search(
self,
dataset: Dataset,
query: Union[Query, str],
user_response_status_filter: Optional[UserResponseStatusFilter] = None,
limit: int = 100,
) -> SearchResponses:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html

if isinstance(query, str):
query = Query(text=TextQuery(q=query))

text_query = self._text_query_builder(dataset, text=query.text)

bool_query = {"must": [text_query]}
if user_response_status_filter:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
user_response_status_field = f"responses.{user_response_status_filter.user.username}.status"
bool_query["filter"] = [{"terms": {user_response_status_field: user_response_status_filter.statuses}}]

body = {
"_source": False,
"query": {"bool": bool_query},
"sort": ["_score", {"id": "asc"}],
}
# TODO: Work on search pagination after endpoint integration
next_page_token = None
if next_page_token:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
body["search_after"] = next_page_token

response = await self.client.search(index=self._index_name_for_dataset(dataset), size=limit, body=body)

items = []
next_page_token = None
for hit in response["hits"]["hits"]:
items.append(SearchResponseItem(record_id=hit["_id"], score=hit["_score"]))
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
next_page_token = hit.get("sort")

return SearchResponses(items=items)

@staticmethod
def _text_query_builder(dataset: Dataset, text: TextQuery) -> dict:
if not text.field:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-multi-match-query.html
field_names = [
f"fields.{field.name}" for field in dataset.fields if field.settings.get("type") == FieldType.text
]
return {"multi_match": {"query": text.q, "fields": field_names, "operator": "and"}}
else:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html
return {"match": {f"fields.{text.field}": {"query": text.q, "operator": "and"}}}

def _field_mapping_for_question(self, question: Question):
settings = question.parsed_settings

if settings.type == QuestionType.rating:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/number.html
return {"type": "integer"}
elif settings.type in [QuestionType.text, QuestionType.label_selection, QuestionType.multi_label_selection]:
# TODO: Review mapping for label selection. Could make sense to use `keyword` mapping instead.
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/keyword.html
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/text.html
return {"type": "text", "index": False}
else:
raise ValueError(f"ElasticSearch mappings for Question of type {settings.type} cannot be generated")

def _es_mapping_for_field(self, field: Field):
field_type = field.settings["type"]

if field_type == FieldType.text:
return {"type": "text"}
else:
raise ValueError(f"ElasticSearch mappings for Field of type {field_type} cannot be generated")

@staticmethod
def _index_name_for_dataset(dataset: Dataset):
return f"rg.{dataset.id}"
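For reference, a minimal usage sketch of the search API added above. It is not part of the diff: the `engine`, `dataset` and `user` objects, the `example` wrapper function, and the `ResponseStatusFilter.submitted` member name are assumptions for illustration.

# Hypothetical usage sketch of SearchEngine.search (assumes an initialized engine,
# an indexed dataset, and a User instance; not part of this pull request).
from argilla.server.schemas.v1.datasets import ResponseStatusFilter
from argilla.server.search_engine import Query, SearchEngine, TextQuery, UserResponseStatusFilter


async def example(engine: SearchEngine, dataset, user):
    # A bare string is wrapped into Query(text=TextQuery(q=...)) and searched as a
    # multi_match query over every text field of the dataset.
    responses = await engine.search(dataset, query="card payment", limit=10)

    # A field-scoped query combined with a per-user response status filter, which is
    # translated into a terms filter on responses.<username>.status.
    responses = await engine.search(
        dataset,
        query=Query(text=TextQuery(q="negative", field="label")),
        user_response_status_filter=UserResponseStatusFilter(
            user=user, statuses=[ResponseStatusFilter.submitted]
        ),
    )

    for item in responses.items:
        print(item.record_id, item.score)
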
13 changes: 11 additions & 2 deletions tests/server/api/v1/test_datasets.py
@@ -1791,21 +1791,30 @@ async def test_create_dataset_records(

index_name = f"rg.{dataset.id}"
opensearch.indices.refresh(index=index_name)
assert [hit["_source"] for hit in opensearch.search(index=index_name)["hits"]["hits"]] == [
es_docs = [hit["_source"] for hit in opensearch.search(index=index_name)["hits"]["hits"]]
assert es_docs == [
{
"id": str(db.get(Record, UUID(es_docs[0]["id"])).id),
"fields": {"input": "Say Hello", "output": "Hello"},
"responses": {"admin": {"values": {"input_ok": "yes", "output_ok": "yes"}, "status": "submitted"}},
},
{"fields": {"input": "Say Hello", "output": "Hi"}, "responses": {}},
{
"id": str(db.get(Record, UUID(es_docs[1]["id"])).id),
"fields": {"input": "Say Hello", "output": "Hi"},
"responses": {},
},
{
"id": str(db.get(Record, UUID(es_docs[2]["id"])).id),
"fields": {"input": "Say Pello", "output": "Hello World"},
"responses": {"admin": {"values": {"input_ok": "no", "output_ok": "no"}, "status": "submitted"}},
},
{
"id": str(db.get(Record, UUID(es_docs[3]["id"])).id),
"fields": {"input": "Say Hello", "output": "Good Morning"},
"responses": {"admin": {"values": {"input_ok": "yes", "output_ok": "no"}, "status": "discarded"}},
},
{
"id": str(db.get(Record, UUID(es_docs[4]["id"])).id),
"fields": {"input": "Say Hello", "output": "Say Hello"},
"responses": {"admin": {"values": None, "status": "discarded"}},
},
99 changes: 98 additions & 1 deletion tests/server/test_search_engine.py
@@ -15,18 +15,68 @@
import random

import pytest
from argilla.server.search_engine import SearchEngine
import pytest_asyncio
from argilla.server.models import Dataset
from argilla.server.search_engine import Query as SearchQuery
from argilla.server.search_engine import SearchEngine, TextQuery
from opensearchpy import OpenSearch, RequestError
from sqlalchemy.orm import Session

from tests.factories import (
DatasetFactory,
RatingQuestionFactory,
RecordFactory,
TextFieldFactory,
TextQuestionFactory,
)


@pytest_asyncio.fixture()
async def test_banking_sentiment_dataset(search_engine: SearchEngine):
text_question = TextQuestionFactory()
rating_question = RatingQuestionFactory()

dataset = DatasetFactory.create(
fields=[TextFieldFactory(name="textId"), TextFieldFactory(name="text"), TextFieldFactory(name="label")],
questions=[text_question, rating_question],
)

await search_engine.create_index(dataset)

await search_engine.add_records(
dataset,
records=[
RecordFactory(
dataset=dataset,
fields={"textId": "00000", "text": "My card payment had the wrong exchange rate", "label": "negative"},
),
RecordFactory(
dataset=dataset,
fields={
"textId": "00001",
"text": "I believe that a card payment I made was cancelled.",
"label": "neutral",
},
),
RecordFactory(
dataset=dataset,
fields={"textId": "00002", "text": "Why was I charged for getting cash?", "label": "neutral"},
),
RecordFactory(
dataset=dataset,
fields={
"textId": "00003",
"text": "I deposited cash into my account a week ago and it is still not available,"
" please tell me why? I need the cash back now.",
"label": "negative",
},
),
],
)

return dataset


@pytest.mark.asyncio
class TestSuiteElasticSearchEngine:
async def test_create_index_for_dataset(self, search_engine: SearchEngine, opensearch: OpenSearch):
@@ -41,6 +91,7 @@ async def test_create_index_for_dataset(self, search_engine: SearchEngine, opens
"dynamic": "strict",
"dynamic_templates": [],
"properties": {
"id": {"type": "keyword"},
"responses": {"dynamic": "true", "type": "object"},
},
}
@@ -64,6 +115,7 @@ async def test_create_index_for_dataset_with_fields(
"dynamic": "strict",
"dynamic_templates": [],
"properties": {
"id": {"type": "keyword"},
"fields": {"properties": {field.name: {"type": "text"} for field in dataset.fields}},
"responses": {"type": "object", "dynamic": "true"},
},
@@ -95,6 +147,7 @@ async def test_create_index_for_dataset_with_questions(
assert index["mappings"] == {
"dynamic": "strict",
"properties": {
"id": {"type": "keyword"},
"responses": {"dynamic": "true", "type": "object"},
},
"dynamic_templates": [
@@ -136,3 +189,47 @@ async def test_create_index_with_existing_index(

with pytest.raises(RequestError, match="resource_already_exists_exception"):
await search_engine.create_index(dataset)

@pytest.mark.parametrize(
("query", "expected_items"),
[
("card", 2),
("account", 1),
("payment", 2),
("cash", 2),
("negative", 2),
("00000", 1),
("card payment", 2),
("nothing", 0),
(SearchQuery(text=TextQuery(q="card")), 2),
(SearchQuery(text=TextQuery(q="account")), 1),
(SearchQuery(text=TextQuery(q="payment")), 2),
(SearchQuery(text=TextQuery(q="cash")), 2),
(SearchQuery(text=TextQuery(q="card payment")), 2),
(SearchQuery(text=TextQuery(q="nothing")), 0),
(SearchQuery(text=TextQuery(q="negative", field="label")), 2),
(SearchQuery(text=TextQuery(q="00000", field="textId")), 1),
(SearchQuery(text=TextQuery(q="card payment", field="text")), 2),
],
)
async def test_search_with_query_string(
self,
search_engine: SearchEngine,
opensearch: OpenSearch,
db: Session,
test_banking_sentiment_dataset: Dataset,
query: str,
expected_items: int,
):
opensearch.indices.refresh(index=f"rg.{test_banking_sentiment_dataset.id}")

result = await search_engine.search(test_banking_sentiment_dataset, query=query)
assert len(result.items) == expected_items

scores = [item.score for item in result.items]
assert all(score > 0 for score in scores)

sorted_scores = scores.copy()
sorted_scores.sort(reverse=True)

assert scores == sorted_scores
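
For completeness, a sketch of the OpenSearch request body that the new search method would build for the banking sentiment fixture above when a response status filter is applied. This is illustrative only and not part of the diff; the "admin" username and the serialized "submitted" status value are assumptions.

# Illustrative only: approximate body produced by SearchEngine.search for the query
# "card" over the test_banking_sentiment_dataset fields, filtered by the responses of
# a hypothetical user "admin" (the status enum is assumed to serialize to a string).
expected_body = {
    "_source": False,
    "query": {
        "bool": {
            "must": [
                {
                    "multi_match": {
                        "query": "card",
                        "fields": ["fields.textId", "fields.text", "fields.label"],
                        "operator": "and",
                    }
                }
            ],
            "filter": [{"terms": {"responses.admin.status": ["submitted"]}}],
        }
    },
    "sort": ["_score", {"id": "asc"}],
}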