Skip to content

Commit

Permalink
Enable search data from SearchEngine component (#3037)
Browse files Browse the repository at this point in the history
# Description

Adding `search` functionality to the `SearchEngine` class. This
functionality will be integrated with the new search endpoint in another
PR.

This search applies a basic full-text search over all record fields, or over a single selected field.

### Update  

The search pagination will be tackled after the search endpoint
integration.

Closes #3017 

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

New test cases have been added 

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and davidberenstein1957 committed Jun 5, 2023
1 parent c520d9b commit 4f2f78b
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 26 deletions.
134 changes: 111 additions & 23 deletions src/argilla/server/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.

import dataclasses
from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional, Union
from uuid import UUID

from opensearchpy import AsyncOpenSearch, helpers
from pydantic import BaseModel
Expand All @@ -27,7 +28,9 @@
QuestionType,
Record,
ResponseStatus,
User,
)
from argilla.server.schemas.v1.datasets import ResponseStatusFilter
from argilla.server.settings import settings


Expand All @@ -50,6 +53,7 @@ class UserResponse(BaseModel):


class SearchDocument(BaseModel):
id: UUID
fields: Dict[str, Any]

responses: Optional[Dict[str, UserResponse]]
Expand All @@ -59,6 +63,34 @@ class Config:
getter_dict = SearchDocumentGetter


@dataclasses.dataclass
class TextQuery:
    """A free-text query over the record fields of a dataset."""

    # Query string to match against record fields.
    q: str
    # Optional record field name; when set, only that field is searched
    # (see `SearchEngine._text_query_builder`).
    field: Optional[str] = None


@dataclasses.dataclass
class Query:
    """Top-level search query. Currently it only wraps a text query."""

    text: TextQuery


@dataclasses.dataclass
class UserResponseStatusFilter:
    """Restricts search results to records whose response authored by `user`
    has one of the given `statuses` (applied as a terms filter in `search`)."""

    user: User
    statuses: List[ResponseStatusFilter]


@dataclasses.dataclass
class SearchResponseItem:
    """One search hit: a record id plus its relevance score."""

    # NOTE(review): annotated as UUID but `search` assigns the raw string
    # `hit["_id"]` without coercion — dataclasses do not convert. Confirm
    # whether callers expect a UUID instance.
    record_id: UUID
    # Relevance score reported by the search backend; may be None.
    score: Optional[float]


@dataclasses.dataclass
class SearchResponses:
    """Container for the ordered list of search hits returned by `search`."""

    items: List[SearchResponseItem]


@dataclasses.dataclass
class SearchEngine:
config: Dict[str, Any]
Expand All @@ -71,14 +103,15 @@ def __post_init__(self):

async def create_index(self, dataset: Dataset):
fields = {
"id": {"type": "keyword"},
"responses": {"dynamic": True, "type": "object"},
}

for field in dataset.fields:
fields[f"fields.{field.name}"] = self._es_mapping_for_field(field)

# See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html
dynamic_templates = [
dynamic_templates: List[dict] = [
{
f"{question.name}_responses": {
"path_match": f"responses.*.values.{question.name}",
Expand Down Expand Up @@ -108,27 +141,6 @@ async def delete_index(self, dataset: Dataset):
index_name = self._index_name_for_dataset(dataset)
await self.client.indices.delete(index_name, ignore=[404], ignore_unavailable=True)

def _field_mapping_for_question(self, question: Question):
    """Return the ES mapping used for response values of `question`.

    Raises:
        ValueError: for question types with no supported mapping.
    """
    settings = question.parsed_settings

    if settings.type == QuestionType.rating:
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/number.html
        return {"type": "integer"}
    elif settings.type in [QuestionType.text, QuestionType.label_selection, QuestionType.multi_label_selection]:
        # TODO: Review mapping for label selection. Could make sense to use `keyword` mapping instead. See https://www.elastic.co/guide/en/elasticsearch/reference/current/keyword.html
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/text.html
        return {"type": "text", "index": False}
    else:
        raise ValueError(f"ElasticSearch mappings for Question of type {settings.type} cannot be generated")

def _es_mapping_for_field(self, field: Field):
    """Return the ES mapping for a dataset record field.

    Raises:
        ValueError: for field types other than text.
    """
    field_type = field.settings["type"]

    if field_type == FieldType.text:
        return {"type": "text"}
    else:
        raise ValueError(f"ElasticSearch mappings for Field of type {field_type} cannot be generated")

async def add_records(self, dataset: Dataset, records: Iterable[Record]):
index_name = self._index_name_for_dataset(dataset)

Expand All @@ -151,6 +163,82 @@ async def add_records(self, dataset: Dataset, records: Iterable[Record]):
if errors:
raise RuntimeError(errors)

async def search(
    self,
    dataset: Dataset,
    query: Union[Query, str],
    user_response_status_filter: Optional[UserResponseStatusFilter] = None,
    limit: int = 100,
) -> SearchResponses:
    """Run a full-text search over the dataset index.

    Args:
        dataset: Dataset whose index is searched.
        query: Either a `Query` object, or a plain string which is treated as
            a text query over all text fields of the dataset.
        user_response_status_filter: Optional filter keeping only records
            whose response for the given user has one of the given statuses.
        limit: Maximum number of hits to return.

    Returns:
        SearchResponses with one item (record id + relevance score) per hit,
        ordered best-score-first.
    """
    # See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html

    if isinstance(query, str):
        query = Query(text=TextQuery(q=query))

    text_query = self._text_query_builder(dataset, text=query.text)

    bool_query = {"must": [text_query]}
    if user_response_status_filter:
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
        user_response_status_field = f"responses.{user_response_status_filter.user.username}.status"
        bool_query["filter"] = [{"terms": {user_response_status_field: user_response_status_filter.statuses}}]

    body = {
        "_source": False,
        "query": {"bool": bool_query},
        # Tiebreak on the keyword `id` field so the ordering is deterministic
        # and usable with `search_after` pagination.
        "sort": ["_score", {"id": "asc"}],
    }
    # TODO: Work on search pagination after endpoint integration
    next_page_token = None
    if next_page_token:
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
        body["search_after"] = next_page_token

    response = await self.client.search(index=self._index_name_for_dataset(dataset), size=limit, body=body)

    items = []
    next_page_token = None
    for hit in response["hits"]["hits"]:
        items.append(SearchResponseItem(record_id=hit["_id"], score=hit["_score"]))
        # Fix: hits expose their sort values under the "sort" key, not
        # "_sort", so `hit.get("_sort")` always yielded None and pagination
        # could never resume from the last hit.
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
        next_page_token = hit.get("sort")

    return SearchResponses(items=items)

@staticmethod
def _text_query_builder(dataset: Dataset, text: TextQuery) -> dict:
    """Build the search-backend query dict for a text query.

    When `text.field` is set, a `match` query is issued against that single
    record field; otherwise a `multi_match` query is issued over every text
    field of the dataset.
    """
    if text.field:
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html
        return {"match": {f"fields.{text.field}": {"query": text.q, "operator": "and"}}}

    # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-multi-match-query.html
    searchable_fields = [
        f"fields.{field.name}" for field in dataset.fields if field.settings.get("type") == FieldType.text
    ]
    return {"multi_match": {"query": text.q, "fields": searchable_fields, "operator": "and"}}

def _field_mapping_for_question(self, question: Question):
    """Return the ES mapping used for response values of `question`.

    Raises:
        ValueError: for question types with no supported mapping.
    """
    # Renamed from `settings` to avoid shadowing the module-level settings import.
    question_settings = question.parsed_settings

    if question_settings.type == QuestionType.rating:
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/number.html
        return {"type": "integer"}

    if question_settings.type in (QuestionType.text, QuestionType.label_selection, QuestionType.multi_label_selection):
        # TODO: Review mapping for label selection. Could make sense to use `keyword` mapping instead.
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/keyword.html
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/text.html
        return {"type": "text", "index": False}

    raise ValueError(f"ElasticSearch mappings for Question of type {question_settings.type} cannot be generated")

def _es_mapping_for_field(self, field: Field):
    """Return the ES mapping for a dataset record field.

    Raises:
        ValueError: for field types other than text.
    """
    field_type = field.settings["type"]

    # Guard clause: only text fields have a supported mapping today.
    if field_type != FieldType.text:
        raise ValueError(f"ElasticSearch mappings for Field of type {field_type} cannot be generated")

    # See https://www.elastic.co/guide/en/elasticsearch/reference/current/text.html
    return {"type": "text"}

@staticmethod
def _index_name_for_dataset(dataset: Dataset):
    """Compute the per-dataset index name, namespaced under the `rg.` prefix."""
    return "rg.{}".format(dataset.id)
Expand Down
13 changes: 11 additions & 2 deletions tests/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,21 +1791,30 @@ async def test_create_dataset_records(

index_name = f"rg.{dataset.id}"
opensearch.indices.refresh(index=index_name)
assert [hit["_source"] for hit in opensearch.search(index=index_name)["hits"]["hits"]] == [
es_docs = [hit["_source"] for hit in opensearch.search(index=index_name)["hits"]["hits"]]
assert es_docs == [
{
"id": str(db.get(Record, UUID(es_docs[0]["id"])).id),
"fields": {"input": "Say Hello", "output": "Hello"},
"responses": {"admin": {"values": {"input_ok": "yes", "output_ok": "yes"}, "status": "submitted"}},
},
{"fields": {"input": "Say Hello", "output": "Hi"}, "responses": {}},
{
"id": str(db.get(Record, UUID(es_docs[1]["id"])).id),
"fields": {"input": "Say Hello", "output": "Hi"},
"responses": {},
},
{
"id": str(db.get(Record, UUID(es_docs[2]["id"])).id),
"fields": {"input": "Say Pello", "output": "Hello World"},
"responses": {"admin": {"values": {"input_ok": "no", "output_ok": "no"}, "status": "submitted"}},
},
{
"id": str(db.get(Record, UUID(es_docs[3]["id"])).id),
"fields": {"input": "Say Hello", "output": "Good Morning"},
"responses": {"admin": {"values": {"input_ok": "yes", "output_ok": "no"}, "status": "discarded"}},
},
{
"id": str(db.get(Record, UUID(es_docs[4]["id"])).id),
"fields": {"input": "Say Hello", "output": "Say Hello"},
"responses": {"admin": {"values": None, "status": "discarded"}},
},
Expand Down
99 changes: 98 additions & 1 deletion tests/server/test_search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,68 @@
import random

import pytest
from argilla.server.search_engine import SearchEngine
import pytest_asyncio
from argilla.server.models import Dataset
from argilla.server.search_engine import Query as SearchQuery
from argilla.server.search_engine import SearchEngine, TextQuery
from opensearchpy import OpenSearch, RequestError
from sqlalchemy.orm import Session

from tests.factories import (
DatasetFactory,
RatingQuestionFactory,
RecordFactory,
TextFieldFactory,
TextQuestionFactory,
)


@pytest_asyncio.fixture()
async def test_banking_sentiment_dataset(search_engine: SearchEngine):
    """Create and index a small banking-sentiment dataset with four records.

    The records mix labels ("negative"/"neutral") and overlapping text terms
    ("card", "payment", "cash", ...) so search tests can assert per-query
    match counts.
    """
    text_question = TextQuestionFactory()
    rating_question = RatingQuestionFactory()

    dataset = DatasetFactory.create(
        fields=[TextFieldFactory(name="textId"), TextFieldFactory(name="text"), TextFieldFactory(name="label")],
        questions=[text_question, rating_question],
    )

    # Index must exist before records can be added to it.
    await search_engine.create_index(dataset)

    await search_engine.add_records(
        dataset,
        records=[
            RecordFactory(
                dataset=dataset,
                fields={"textId": "00000", "text": "My card payment had the wrong exchange rate", "label": "negative"},
            ),
            RecordFactory(
                dataset=dataset,
                fields={
                    "textId": "00001",
                    "text": "I believe that a card payment I made was cancelled.",
                    "label": "neutral",
                },
            ),
            RecordFactory(
                dataset=dataset,
                fields={"textId": "00002", "text": "Why was I charged for getting cash?", "label": "neutral"},
            ),
            RecordFactory(
                dataset=dataset,
                fields={
                    "textId": "00003",
                    "text": "I deposited cash into my account a week ago and it is still not available,"
                    " please tell me why? I need the cash back now.",
                    "label": "negative",
                },
            ),
        ],
    )

    return dataset


@pytest.mark.asyncio
class TestSuiteElasticSearchEngine:
async def test_create_index_for_dataset(self, search_engine: SearchEngine, opensearch: OpenSearch):
Expand All @@ -41,6 +91,7 @@ async def test_create_index_for_dataset(self, search_engine: SearchEngine, opens
"dynamic": "strict",
"dynamic_templates": [],
"properties": {
"id": {"type": "keyword"},
"responses": {"dynamic": "true", "type": "object"},
},
}
Expand All @@ -66,6 +117,7 @@ async def test_create_index_for_dataset_with_fields(
"dynamic": "strict",
"dynamic_templates": [],
"properties": {
"id": {"type": "keyword"},
"fields": {"properties": {field.name: {"type": "text"} for field in dataset.fields}},
"responses": {"type": "object", "dynamic": "true"},
},
Expand Down Expand Up @@ -97,6 +149,7 @@ async def test_create_index_for_dataset_with_questions(
assert index["mappings"] == {
"dynamic": "strict",
"properties": {
"id": {"type": "keyword"},
"responses": {"dynamic": "true", "type": "object"},
},
"dynamic_templates": [
Expand Down Expand Up @@ -138,3 +191,47 @@ async def test_create_index_with_existing_index(

with pytest.raises(RequestError, match="resource_already_exists_exception"):
await search_engine.create_index(dataset)

@pytest.mark.parametrize(
    ("query", "expected_items"),
    [
        ("card", 2),
        ("account", 1),
        ("payment", 2),
        ("cash", 2),
        ("negative", 2),
        ("00000", 1),
        ("card payment", 2),
        ("nothing", 0),
        (SearchQuery(text=TextQuery(q="card")), 2),
        (SearchQuery(text=TextQuery(q="account")), 1),
        (SearchQuery(text=TextQuery(q="payment")), 2),
        (SearchQuery(text=TextQuery(q="cash")), 2),
        (SearchQuery(text=TextQuery(q="card payment")), 2),
        (SearchQuery(text=TextQuery(q="nothing")), 0),
        (SearchQuery(text=TextQuery(q="negative", field="label")), 2),
        (SearchQuery(text=TextQuery(q="00000", field="textId")), 1),
        (SearchQuery(text=TextQuery(q="card payment", field="text")), 2),
    ],
)
async def test_search_with_query_string(
    self,
    search_engine: SearchEngine,
    opensearch: OpenSearch,
    db: Session,
    test_banking_sentiment_dataset: Dataset,
    query: str,
    expected_items: int,
):
    """Search must return the expected hits, with positive scores ranked best-first."""
    # Refresh so records added by the fixture are visible to search.
    opensearch.indices.refresh(index=f"rg.{test_banking_sentiment_dataset.id}")

    result = await search_engine.search(test_banking_sentiment_dataset, query=query)
    assert len(result.items) == expected_items

    # Fix: the previous version collected booleans (`item.score > 0`) and then
    # sorted them, which never actually verified the score ordering. Collect
    # the raw scores instead.
    scores = [item.score for item in result.items]
    assert all(score > 0 for score in scores)

    # Hits must come back ordered by descending relevance score.
    assert scores == sorted(scores, reverse=True)

0 comments on commit 4f2f78b

Please sign in to comment.