From 89e7c5bbc1279843be1c8a967d01914e8044bfde Mon Sep 17 00:00:00 2001 From: Karen Shaw Date: Wed, 7 Aug 2024 18:10:14 +0000 Subject: [PATCH] Up default k to 40 but leave size at 5 --- chat/src/event_config.py | 9 ++++++++- chat/src/handlers/chat.py | 2 +- chat/src/helpers/hybrid_query.py | 2 +- chat/src/helpers/metrics.py | 1 + chat/src/helpers/response.py | 2 +- chat/test/helpers/test_metrics.py | 6 ++++-- chat/test/test_event_config.py | 4 +++- 7 files changed, 19 insertions(+), 7 deletions(-) diff --git a/chat/src/event_config.py b/chat/src/event_config.py index 10356537..07e42ee7 100644 --- a/chat/src/event_config.py +++ b/chat/src/event_config.py @@ -17,9 +17,10 @@ CHAIN_TYPE = "stuff" DOCUMENT_VARIABLE_NAME = "context" -K_VALUE = 5 +K_VALUE = 40 MAX_K = 100 MAX_TOKENS = 1000 +SIZE = 5 TEMPERATURE = 0.2 TEXT_KEY = "id" VERSION = "2024-02-01" @@ -56,6 +57,7 @@ class EventConfig: ref: str = field(init=False) request_context: dict = field(init=False) temperature: float = field(init=False) + size: int = field(init=False) socket: Websocket = field(init=False, default=None) stream_response: bool = field(init=False) text_key: str = field(init=False) @@ -76,6 +78,7 @@ def __post_init__(self): self.request_context = self.event.get("requestContext", {}) self.question = self.payload.get("question") self.ref = self.payload.get("ref") + self.size = self._get_size() self.stream_response = self.payload.get("stream_response", not self.debug_mode) self.temperature = self._get_temperature() self.text_key = self._get_text_key() @@ -130,6 +133,9 @@ def _get_openai_api_version(self): def _get_prompt_text(self): return self._get_payload_value_with_superuser_check("prompt", prompt_template()) + def _get_size(self): + return self._get_payload_value_with_superuser_check("size", SIZE) + def _get_temperature(self): return self._get_payload_value_with_superuser_check("temperature", TEMPERATURE) @@ -151,6 +157,7 @@ def debug_message(self): "prompt": self.prompt_text, "question": self.question, "ref": self.ref, + "size": self.ref, "temperature": self.temperature, "text_key": self.text_key, }, diff --git a/chat/src/handlers/chat.py b/chat/src/handlers/chat.py index db1c9ef0..8ec897ed 100644 --- a/chat/src/handlers/chat.py +++ b/chat/src/handlers/chat.py @@ -13,7 +13,7 @@ RESPONSE_TYPES = { "base": ["answer", "ref"], "debug": ["answer", "attributes", "azure_endpoint", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "temperature", "text_key", "token_counts"], - "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "source_documents", "temperature", "token_counts"] + "log": ["answer", "deployment_name", "is_superuser", "k", "openai_api_version", "prompt", "question", "ref", "size", "source_documents", "temperature", "token_counts"] } def handler(event, context): diff --git a/chat/src/helpers/hybrid_query.py b/chat/src/helpers/hybrid_query.py index 47e2d910..d0cb8287 100644 --- a/chat/src/helpers/hybrid_query.py +++ b/chat/src/helpers/hybrid_query.py @@ -13,7 +13,7 @@ def filter(query: dict): def hybrid_query(query: str, model_id: str, vector_field: str = "embedding", k: int = 10, **kwargs: Any): result = { - "size": k, + "size": kwargs.get("size", 5), "query": { "hybrid": { "queries": [ diff --git a/chat/src/helpers/metrics.py b/chat/src/helpers/metrics.py index cf6d9b6e..9610eac7 100644 --- a/chat/src/helpers/metrics.py +++ b/chat/src/helpers/metrics.py @@ -14,6 +14,7 @@ def debug_response(config, response, original_question): "prompt": config.prompt_text, "question": config.question, "ref": config.ref, + "size": config.size, "source_documents": source_urls, "temperature": config.temperature, "text_key": config.text_key, diff --git a/chat/src/helpers/response.py b/chat/src/helpers/response.py index dea2098b..94c9678c 100644 --- a/chat/src/helpers/response.py +++ b/chat/src/helpers/response.py @@ -39,7 +39,7 @@ def get_and_send_original_question(docs): def prepare_response(self): try: - retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "_source": {"excludes": ["embedding"]}}) + retriever = self.config.opensearch.as_retriever(search_type="similarity", search_kwargs={"k": self.config.k, "size": self.config.size, "_source": {"excludes": ["embedding"]}}) chain = ( {"context": retriever, "question": RunnablePassthrough()} | self.original_question_passthrough() diff --git a/chat/test/helpers/test_metrics.py b/chat/test/helpers/test_metrics.py index 0438a662..de147ab6 100644 --- a/chat/test/helpers/test_metrics.py +++ b/chat/test/helpers/test_metrics.py @@ -60,11 +60,12 @@ def setUp(self): "body": json.dumps({ "deployment_name": "test", "index": "test", - "k": 5, + "k": 40, "openai_api_version": "2019-05-06", "prompt": "This is a test prompt.", "question": self.question, "ref": "test", + "size": 5, "temperature": 0.5, "text_key": "text", "auth": "test123" @@ -78,9 +79,10 @@ def setUp(self): def test_debug_response(self): result = debug_response(self.config, self.response, self.original_question) - self.assertEqual(result["k"], 5) + self.assertEqual(result["k"], 40) self.assertEqual(result["question"], self.question) self.assertEqual(result["ref"], "test") + self.assertEqual(result["size"], 5) self.assertEqual( result["source_documents"], [ diff --git a/chat/test/test_event_config.py b/chat/test/test_event_config.py index 625d854d..9f9f39f5 100644 --- a/chat/test/test_event_config.py +++ b/chat/test/test_event_config.py @@ -43,6 +43,7 @@ def test_attempt_override_without_superuser_status(self): "openai_api_version": "2024-01-01", "question": "test question", "ref": "test ref", + "size": 90, "temperature": 0.9, "text_key": "accession_number", } @@ -52,9 +53,10 @@ def test_attempt_override_without_superuser_status(self): expected_output = { "attributes": EventConfig.DEFAULT_ATTRIBUTES, "azure_endpoint": "https://test.openai.azure.com/", - "k": 5, + "k": 40, "openai_api_version": "2024-02-01", "question": "test question", + "size": 5, "ref": "test ref", "temperature": 0.2, "text_key": "id",