From b2b384dbba2a6575a0e28c62800d4192fa004ed2 Mon Sep 17 00:00:00 2001 From: Massimiliano Angelino Date: Tue, 7 Nov 2023 17:52:04 +0100 Subject: [PATCH] fix(reranking): uses semantic search score Fills up the documents returned by the retriever with the top scored semantic search documents --- .../python/genai_core/aurora/query.py | 35 ++++++++++++------- .../python/genai_core/opensearch/query.py | 10 +++--- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py b/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py index c7cf5f74b..1478b4155 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py +++ b/lib/shared/layers/python-sdk/python/genai_core/aurora/query.py @@ -197,14 +197,14 @@ def query_workspace_aurora( score_dict[unique_items[i]["chunk_id"]] = score unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) - unique_items = unique_items[:limit] - + for record in vector_search_records: record["score"] = score_dict[record["chunk_id"]] for record in keyword_search_records: record["score"] = score_dict[record["chunk_id"]] if full_response: + unique_items = unique_items[:limit] ret_value = { "engine": "aurora", "query_language": language_name, @@ -216,16 +216,27 @@ def query_workspace_aurora( "keyword_search_items": convert_types(keyword_search_records), } else: - ret_items = list(filter(lambda val: val["score"] > threshold, unique_items)) + ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[:limit] if len(ret_items) < limit: - unique_items = sorted( - unique_items, key=lambda x: x["vector_search_score"], reverse=True - ) - ret_items = ret_items + ( - list( - filter(lambda val: val["vector_search_score"] > 0.5, unique_items) - )[: (limit - len(ret_items))] - ) + # inner product metric is negative hence we sort ascending + if metric == "inner": + unique_items = sorted( + unique_items, key=lambda x: x["vector_search_score"] or 1, reverse=False + ) + ret_items = ret_items + ( + list( + filter(lambda val: (val["vector_search_score"] or 1) < -0.5, unique_items) + )[: (limit - len(ret_items))] + ) + else: + unique_items = sorted( + unique_items, key=lambda x: x["vector_search_score"] or -1, reverse=True + ) + ret_items = ret_items + ( + list( + filter(lambda val: (val["vector_search_score"] or -1) > 0.5 , unique_items) + )[: (limit - len(ret_items))] + ) ret_value = { "engine": "aurora", @@ -233,7 +244,7 @@ def query_workspace_aurora( "supported_languages": languages, "detected_languages": detected_languages, "items": convert_types( - list(filter(lambda val: val["score"] > 0, unique_items)) + ret_items ), } diff --git a/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py b/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py index 4145deb50..e4c297785 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py +++ b/lib/shared/layers/python-sdk/python/genai_core/opensearch/query.py @@ -107,14 +107,14 @@ def query_workspace_open_search( unique_items[i]["score"] = score score_dict[unique_items[i]["chunk_id"]] = score unique_items = sorted(unique_items, key=lambda x: x["score"], reverse=True) - unique_items = unique_items[:limit] - + for record in vector_search_records: record["score"] = score_dict[record["chunk_id"]] for record in keyword_search_records: record["score"] = score_dict[record["chunk_id"]] if full_response: + unique_items = unique_items[:limit] ret_value = { "engine": "opensearch", "supported_languages": languages, @@ -124,14 +124,14 @@ def query_workspace_open_search( "keyword_search_items": keyword_search_records, } else: - ret_items = list(filter(lambda val: val["score"] > threshold, unique_items)) + ret_items = list(filter(lambda val: val["score"] > threshold, unique_items))[:limit] if len(ret_items) < limit: unique_items = sorted( - unique_items, key=lambda x: x["vector_search_score"], reverse=True + unique_items, key=lambda x: x["vector_search_score"] or -1, reverse=True ) ret_items = ret_items + ( list( - filter(lambda val: val["vector_search_score"] > 0.5, unique_items) + filter(lambda val: (val["vector_search_score"] or -1) > 0.5, unique_items) )[: (limit - len(ret_items))] )