Skip to content

Commit

Permalink
Refactor search engine
Browse files Browse the repository at this point in the history
* Move text query build logic to a method
* Rename `TextFieldQuery` to `TextQuery`
  • Loading branch information
frascuchon committed Jun 1, 2023
1 parent 8caa5ac commit fcf6d1f
Showing 1 changed file with 18 additions and 26 deletions.
44 changes: 18 additions & 26 deletions src/argilla/server/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ class Config:


@dataclasses.dataclass
class TextFieldQuery:
class TextQuery:
q: str
field: Optional[str] = None


@dataclasses.dataclass
class Query:
text: TextFieldQuery
text: TextQuery


@dataclasses.dataclass
Expand Down Expand Up @@ -108,7 +108,7 @@ async def create_index(self, dataset: Dataset):
fields[f"fields.{field.name}"] = self._es_mapping_for_field(field)

# See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html
dynamic_templates = [
dynamic_templates: List[dict] = [
{
f"{question.name}_responses": {
"path_match": f"responses.*.values.{question.name}",
Expand Down Expand Up @@ -161,31 +161,11 @@ async def search(
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html

if isinstance(query, str):
query = Query(text=TextFieldQuery(q=query))
query = Query(text=TextQuery(q=query))

bool_query = {"must": []}
if not query.text.field:
# https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-combined-fields-query.html
text_query = {
"combined_fields": {
"query": query.text.q,
"fields": [f"fields.{field.name}" for field in dataset.fields],
"operator": "and",
}
}
else:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html
text_query = {
"match_phrase": {
f"fields.{query.text.field}": {
"query": query.text.q,
"operator": "and",
}
}
}

bool_query["must"].append(text_query)
text_query = self._text_query_builder(dataset, text=query.text)

bool_query = {"must": [text_query]}
if user_response_status_filter:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
user_response_status_field = f"responses.{user_response_status_filter.user.username}.status"
Expand Down Expand Up @@ -213,6 +193,18 @@ async def search(

return SearchResponses(items=items)

@staticmethod
def _text_query_builder(dataset: Dataset, text: TextQuery) -> dict:
    """Build the Elasticsearch text clause for a ``TextQuery``.

    When ``text.field`` is set, the clause is a ``match`` query against
    ``fields.<text.field>``. Otherwise it is a ``combined_fields`` query over
    every field of ``dataset`` whose ``type`` setting equals
    ``FieldType.text``. Both variants pass ``"operator": "and"``.

    Args:
        dataset: Dataset whose ``fields`` are inspected when no specific
            field is requested.
        text: Query text (``q``) and optional target field name (``field``).

    Returns:
        A dict holding a single Elasticsearch query clause.
    """
    if text.field:
        # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-match-query.html
        target = f"fields.{text.field}"
        return {"match": {target: {"query": text.q, "operator": "and"}}}

    # See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-combined-fields-query.html
    searchable = []
    for field in dataset.fields:
        if field.settings.get("type") == FieldType.text:
            searchable.append(f"fields.{field.name}")
    return {"combined_fields": {"query": text.q, "fields": searchable, "operator": "and"}}

def _field_mapping_for_question(self, question: Question):
settings = question.parsed_settings

Expand Down

0 comments on commit fcf6d1f

Please sign in to comment.