Skip to content

Commit

Permalink
fix(reranking): uses semantic search score
Browse files Browse the repository at this point in the history
Fills up the documents returned by the retriever with
the top scored semantic search documents
  • Loading branch information
massi-ang authored and bigadsoleiman committed Nov 7, 2023
1 parent 62cfc62 commit 5776d33
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 17 deletions.
35 changes: 23 additions & 12 deletions lib/shared/layers/python-sdk/python/genai_core/aurora/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,14 @@ def query_workspace_aurora(
score_dict[unique_items[i]["chunk_id"]] = score

unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True)
unique_items = unique_items[:limit]


for record in vector_search_records:
record["score"] = score_dict[record["chunk_id"]]
for record in keyword_search_records:
record["score"] = score_dict[record["chunk_id"]]

if full_response:
unique_items = unique_items[:limit]
ret_value = {
"engine": "aurora",
"query_language": language_name,
Expand All @@ -216,24 +216,35 @@ def query_workspace_aurora(
"keyword_search_items": convert_types(keyword_search_records),
}
else:
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[:limit]
if len(ret_items) < limit:
unique_items = sorted(
unique_items, key=lambda x: x["vector_search_score"], reverse=True
)
ret_items = ret_items + (
list(
filter(lambda val: val["vector_search_score"] > 0.5, unique_items)
)[: (limit - len(ret_items))]
)
# inner product metric is negative hence we sort ascending
if metric == "inner":
unique_items = sorted(
unique_items, key=lambda x: x["vector_search_score"] or 1, reverse=False
)
ret_items = ret_items + (
list(
filter(lambda val: (val["vector_search_score"] or 1) < -0.5, unique_items)
)[: (limit - len(ret_items))]
)
else:
unique_items = sorted(
unique_items, key=lambda x: x["vector_search_score"] or -1, reverse=True
)
ret_items = ret_items + (
list(
filter(lambda val: (val["vector_search_score"] or -1) > 0.5 , unique_items)
)[: (limit - len(ret_items))]
)

ret_value = {
"engine": "aurora",
"query_language": language_name,
"supported_languages": languages,
"detected_languages": detected_languages,
"items": convert_types(
list(filter(lambda val: val["score"] > 0, unique_items))
ret_items
),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ def query_workspace_open_search(
unique_items[i]["score"] = score
score_dict[unique_items[i]["chunk_id"]] = score
unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True)
unique_items = unique_items[:limit]


for record in vector_search_records:
record["score"] = score_dict[record["chunk_id"]]
for record in keyword_search_records:
record["score"] = score_dict[record["chunk_id"]]

if full_response:
unique_items = unique_items[:limit]
ret_value = {
"engine": "opensearch",
"supported_languages": languages,
Expand All @@ -124,14 +124,14 @@ def query_workspace_open_search(
"keyword_search_items": keyword_search_records,
}
else:
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))
ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[:limit]
if len(ret_items) < limit:
unique_items = sorted(
unique_items, key=lambda x: x["vector_search_score"], reverse=True
unique_items, key=lambda x: x["vector_search_score"] or -1, reverse=True
)
ret_items = ret_items + (
list(
filter(lambda val: val["vector_search_score"] > 0.5, unique_items)
filter(lambda val: (val["vector_search_score"] or -1) > 0.5, unique_items)
)[: (limit - len(ret_items))]
)

Expand Down

0 comments on commit 5776d33

Please sign in to comment.