From 19907d5e62f41cb5910db6fb8de50a291cb11664 Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 15:13:37 +0200
Subject: [PATCH 01/11] Add last message query generator

---
 .../chat_engine/query_generator/__init__.py   |  1 +
 .../query_generator/function_calling.py       |  3 +-
 .../query_generator/last_message.py           | 17 +++++++++++
 .../test_last_message_query_generator.py      | 29 +++++++++++++++++++
 4 files changed, 49 insertions(+), 1 deletion(-)
 create mode 100644 src/canopy/chat_engine/query_generator/last_message.py
 create mode 100644 tests/unit/query_generators/test_last_message_query_generator.py

diff --git a/src/canopy/chat_engine/query_generator/__init__.py b/src/canopy/chat_engine/query_generator/__init__.py
index 13ffd0d0..9005d02b 100644
--- a/src/canopy/chat_engine/query_generator/__init__.py
+++ b/src/canopy/chat_engine/query_generator/__init__.py
@@ -1,2 +1,3 @@
 from .base import QueryGenerator
 from .function_calling import FunctionCallingQueryGenerator
+from .last_message import LastMessageQueryGenerator
diff --git a/src/canopy/chat_engine/query_generator/function_calling.py b/src/canopy/chat_engine/query_generator/function_calling.py
index 7fb21577..09b41acc 100644
--- a/src/canopy/chat_engine/query_generator/function_calling.py
+++ b/src/canopy/chat_engine/query_generator/function_calling.py
@@ -16,7 +16,6 @@ class FunctionCallingQueryGenerator(QueryGenerator):
-
     _DEFAULT_COMPONENTS = {
         "llm": OpenAILLM,
     }
@@ -64,3 +63,5 @@ def _function(self) -> Function:
                 ]
             ),
         )
+
+
diff --git a/src/canopy/chat_engine/query_generator/last_message.py b/src/canopy/chat_engine/query_generator/last_message.py
new file mode 100644
index 00000000..7ffa02d5
--- /dev/null
+++ b/src/canopy/chat_engine/query_generator/last_message.py
@@ -0,0 +1,17 @@
+from typing import List
+
+from canopy.chat_engine.query_generator import QueryGenerator
+from canopy.models.data_models import Messages, Query
+
+
+class LastMessageQueryGenerator(QueryGenerator):
+
+    def generate(self,
+                 messages: Messages,
+                 max_prompt_tokens: int) -> List[Query]:
+        return [Query(text=messages[-1].content)]
+
+    async def agenerate(self,
+                        messages: Messages,
+                        max_prompt_tokens: int) -> List[Query]:
+        return self.generate(messages, max_prompt_tokens)
diff --git a/tests/unit/query_generators/test_last_message_query_generator.py b/tests/unit/query_generators/test_last_message_query_generator.py
new file mode 100644
index 00000000..d760a9de
--- /dev/null
+++ b/tests/unit/query_generators/test_last_message_query_generator.py
@@ -0,0 +1,29 @@
+import pytest
+
+from canopy.chat_engine.query_generator import LastMessageQueryGenerator
+from canopy.models.data_models import UserMessage, Query
+
+
+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What is photosynthesis?")
+    ]
+
+
+@pytest.fixture
+def query_generator():
+    return LastMessageQueryGenerator()
+
+
+def test_generate(query_generator, sample_messages):
+    expected = [Query(text=sample_messages[-1].content)]
+    actual = query_generator.generate(sample_messages, 0)
+    assert actual == expected
+
+
+@pytest.mark.asyncio
+async def test_agenerate(query_generator, sample_messages):
+    expected = [Query(text=sample_messages[-1].content)]
+    actual = await query_generator.agenerate(sample_messages, 0)
+    assert actual == expected
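Note: the snippet below is an illustrative usage sketch of the generator introduced
in this patch, based only on the imports and models shown in the accompanying test
file; it is not part of the patch itself.

    from canopy.chat_engine.query_generator import LastMessageQueryGenerator
    from canopy.models.data_models import UserMessage

    # Build the generator and a minimal chat history.
    generator = LastMessageQueryGenerator()
    history = [UserMessage(content="What is photosynthesis?")]

    # At this stage of the series the generator simply wraps the last message;
    # max_prompt_tokens is accepted by the interface but not used here.
    queries = generator.generate(history, max_prompt_tokens=0)
    print(queries[0].text)  # "What is photosynthesis?"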
From fc6776b421633b500676b5c5070a4d9b09de18e9 Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 15:18:25 +0200
Subject: [PATCH 02/11] Fix lint

---
 src/canopy/chat_engine/query_generator/function_calling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/canopy/chat_engine/query_generator/function_calling.py b/src/canopy/chat_engine/query_generator/function_calling.py
index 09b41acc..1a070538 100644
--- a/src/canopy/chat_engine/query_generator/function_calling.py
+++ b/src/canopy/chat_engine/query_generator/function_calling.py
@@ -16,6 +16,7 @@ class FunctionCallingQueryGenerator(QueryGenerator):
+
     _DEFAULT_COMPONENTS = {
         "llm": OpenAILLM,
     }
@@ -64,4 +65,3 @@ def _function(self) -> Function:
             ),
         )
-

From d4dfc2935c7ea09e2a3705bf7ee535ca84822008 Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 15:19:44 +0200
Subject: [PATCH 03/11] Fix lint

---
 src/canopy/chat_engine/query_generator/function_calling.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/canopy/chat_engine/query_generator/function_calling.py b/src/canopy/chat_engine/query_generator/function_calling.py
index 1a070538..7fb21577 100644
--- a/src/canopy/chat_engine/query_generator/function_calling.py
+++ b/src/canopy/chat_engine/query_generator/function_calling.py
@@ -64,4 +64,3 @@ def _function(self) -> Function:
                 ]
            ),
        )
-

From 04fa9acc6da334aea13db235854d0edf05504047 Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 15:51:55 +0200
Subject: [PATCH 04/11] Improve docs

---
 .../chat_engine/query_generator/last_message.py | 13 ++++++++++++-
 .../test_last_message_query_generator.py        |  5 +++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/canopy/chat_engine/query_generator/last_message.py b/src/canopy/chat_engine/query_generator/last_message.py
index 7ffa02d5..efdc1f23 100644
--- a/src/canopy/chat_engine/query_generator/last_message.py
+++ b/src/canopy/chat_engine/query_generator/last_message.py
@@ -5,10 +5,21 @@


 class LastMessageQueryGenerator(QueryGenerator):
-
+    """
+    Just returns the last message as a query without running any LLMs. This can be
+    considered as the most basic query generation. Please use other query generators
+    for more accurate results.
+    """
     def generate(self,
                  messages: Messages,
                  max_prompt_tokens: int) -> List[Query]:
+        """
+        max_prompt_token is dismissed since we do not consume any token for
+        generating the queries.
+        """
+        if len(messages) == 0:
+            raise ValueError("Passed chat history does not contain any messages."
+                             " Please include at least one message in the history.")
         return [Query(text=messages[-1].content)]
diff --git a/tests/unit/query_generators/test_last_message_query_generator.py b/tests/unit/query_generators/test_last_message_query_generator.py
index d760a9de..47c0f92d 100644
--- a/tests/unit/query_generators/test_last_message_query_generator.py
+++ b/tests/unit/query_generators/test_last_message_query_generator.py
@@ -27,3 +27,8 @@ async def test_agenerate(query_generator, sample_messages):
     expected = [Query(text=sample_messages[-1].content)]
     actual = await query_generator.agenerate(sample_messages, 0)
     assert actual == expected
+
+
+def test_generate_fails_with_empty_history(query_generator):
+    with pytest.raises(ValueError):
+        query_generator.generate([], 0)

From bd943fff513911f8e6934da5ce04c89501f18d97 Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 15:53:29 +0200
Subject: [PATCH 05/11] Fix docstring

---
 src/canopy/chat_engine/query_generator/last_message.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/canopy/chat_engine/query_generator/last_message.py b/src/canopy/chat_engine/query_generator/last_message.py
index efdc1f23..6c4358a3 100644
--- a/src/canopy/chat_engine/query_generator/last_message.py
+++ b/src/canopy/chat_engine/query_generator/last_message.py
@@ -6,7 +6,7 @@

 class LastMessageQueryGenerator(QueryGenerator):
     """
-    Just returns the last message as a query without running any LLMs. This can be
+    Returns the last message as a query without running any LLMs. This can be
     considered as the most basic query generation. Please use other query generators
     for more accurate results.
     """
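Note: an illustrative sketch of the validation added in patches 04-05 (not part of
the series itself): an empty chat history is now rejected up front, and
max_prompt_tokens is accepted only to satisfy the QueryGenerator interface.

    from canopy.chat_engine.query_generator import LastMessageQueryGenerator

    generator = LastMessageQueryGenerator()

    # The new guard raises immediately instead of failing later on messages[-1].
    try:
        generator.generate([], max_prompt_tokens=0)
    except ValueError as err:
        print(err)  # "Passed chat history does not contain any messages. ..."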
From d3a1df65b6ee14da4c82faf0cbf518b23b75195e Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 16:16:30 +0200
Subject: [PATCH 06/11] Take only user messages

---
 .../query_generator/last_message.py      | 16 +++++----
 .../test_last_message_query_generator.py | 36 ++++++++++++++-----
 2 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/src/canopy/chat_engine/query_generator/last_message.py b/src/canopy/chat_engine/query_generator/last_message.py
index 6c4358a3..1deee36e 100644
--- a/src/canopy/chat_engine/query_generator/last_message.py
+++ b/src/canopy/chat_engine/query_generator/last_message.py
@@ -1,15 +1,16 @@
 from typing import List

 from canopy.chat_engine.query_generator import QueryGenerator
-from canopy.models.data_models import Messages, Query
+from canopy.models.data_models import Messages, Query, Role


 class LastMessageQueryGenerator(QueryGenerator):
     """
-    Returns the last message as a query without running any LLMs. This can be
+    Returns the last user message as a query without running any LLMs. This can be
     considered as the most basic query generation. Please use other query generators
     for more accurate results.
     """
+
     def generate(self,
                  messages: Messages,
                  max_prompt_tokens: int) -> List[Query]:
@@ -17,10 +18,13 @@ def generate(self,
         max_prompt_token is dismissed since we do not consume any token for
         generating the queries.
         """
-        if len(messages) == 0:
-            raise ValueError("Passed chat history does not contain any messages."
-                             " Please include at least one message in the history.")
-        return [Query(text=messages[-1].content)]
+        user_messages = [message for message in messages if message.role == Role.USER]
+
+        if len(user_messages) == 0:
+            raise ValueError("Passed chat history does not contain any user messages."
+                             " Please include at least one user message in the history.")
+
+        return [Query(text=user_messages[-1].content)]

     async def agenerate(self,
                         messages: Messages,
diff --git a/tests/unit/query_generators/test_last_message_query_generator.py b/tests/unit/query_generators/test_last_message_query_generator.py
index 47c0f92d..38b8a844 100644
--- a/tests/unit/query_generators/test_last_message_query_generator.py
+++ b/tests/unit/query_generators/test_last_message_query_generator.py
@@ -1,34 +1,54 @@
 import pytest

 from canopy.chat_engine.query_generator import LastMessageQueryGenerator
-from canopy.models.data_models import UserMessage, Query
+from canopy.models.data_models import UserMessage, Query, AssistantMessage, Role


 @pytest.fixture
-def sample_messages():
+def sample_user_messages():
     return [
         UserMessage(content="What is photosynthesis?")
     ]


+@pytest.fixture
+def sample_messages():
+    return [
+        UserMessage(content="What is photosynthesis?"),
+        AssistantMessage(content="Oh! I don't know.")
+    ]
+
+
 @pytest.fixture
 def query_generator():
     return LastMessageQueryGenerator()


-def test_generate(query_generator, sample_messages):
-    expected = [Query(text=sample_messages[-1].content)]
-    actual = query_generator.generate(sample_messages, 0)
+def test_generate(query_generator, sample_user_messages):
+    expected = [Query(text=sample_user_messages[-1].content)]
+    actual = query_generator.generate(sample_user_messages, 0)
     assert actual == expected


 @pytest.mark.asyncio
-async def test_agenerate(query_generator, sample_messages):
-    expected = [Query(text=sample_messages[-1].content)]
-    actual = await query_generator.agenerate(sample_messages, 0)
+async def test_agenerate(query_generator, sample_user_messages):
+    expected = [Query(text=sample_user_messages[-1].content)]
+    actual = await query_generator.agenerate(sample_user_messages, 0)
     assert actual == expected


+def test_generate_with_user_and_assistant(query_generator, sample_messages):
+    last_user_message = next(message for message in sample_messages if message.role == Role.USER)
+    expected = [Query(text=last_user_message.content)]
+    actual = query_generator.generate(sample_messages, 0)
+    assert actual == expected
+
+
 def test_generate_fails_with_empty_history(query_generator):
     with pytest.raises(ValueError):
         query_generator.generate([], 0)
+
+
+def test_generate_fails_with_no_user_message(query_generator):
+    with pytest.raises(ValueError):
+        query_generator.generate([AssistantMessage(content="Hi! How can I help you?")], 0)

From 16a5df6cd868c2adda98754cc0ce2a481c5a0c6f Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 16:20:43 +0200
Subject: [PATCH 07/11] Fix lint

---
 src/canopy/chat_engine/query_generator/last_message.py | 5 +++--
 .../test_last_message_query_generator.py               | 9 ++++++---
 2 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/src/canopy/chat_engine/query_generator/last_message.py b/src/canopy/chat_engine/query_generator/last_message.py
index 1deee36e..05016cab 100644
--- a/src/canopy/chat_engine/query_generator/last_message.py
+++ b/src/canopy/chat_engine/query_generator/last_message.py
@@ -21,8 +21,9 @@ def generate(self,
         user_messages = [message for message in messages if message.role == Role.USER]

         if len(user_messages) == 0:
-            raise ValueError("Passed chat history does not contain any user messages."
-                             " Please include at least one user message in the history.")
+            raise ValueError("Passed chat history does not contain any user "
+                             "messages. Please include at least one user message"
+                             " in the history.")

         return [Query(text=user_messages[-1].content)]
diff --git a/tests/unit/query_generators/test_last_message_query_generator.py b/tests/unit/query_generators/test_last_message_query_generator.py
index 38b8a844..6f56ac86 100644
--- a/tests/unit/query_generators/test_last_message_query_generator.py
+++ b/tests/unit/query_generators/test_last_message_query_generator.py
@@ -38,8 +38,9 @@


 def test_generate_with_user_and_assistant(query_generator, sample_messages):
-    last_user_message = next(message for message in sample_messages if message.role == Role.USER)
-    expected = [Query(text=last_user_message.content)]
+    user_messages = (msg for msg in sample_messages if msg.role == Role.USER)
+    last_user_msg = next(user_messages)
+    expected = [Query(text=last_user_msg.content)]
     actual = query_generator.generate(sample_messages, 0)
     assert actual == expected
@@ -51,4 +52,6 @@ def test_generate_fails_with_empty_history(query_generator):

 def test_generate_fails_with_no_user_message(query_generator):
     with pytest.raises(ValueError):
-        query_generator.generate([AssistantMessage(content="Hi! How can I help you?")], 0)
+        query_generator.generate([
+            AssistantMessage(content="Hi! How can I help you?")
+        ], 0)

From 1252248779877549dadca4aab8c265250743823e Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 16:38:01 +0200
Subject: [PATCH 08/11] Fix behavior

---
 .../query_generator/last_message.py      | 17 ++++++++++-------
 .../test_last_message_query_generator.py | 18 +-----------------
 2 files changed, 11 insertions(+), 24 deletions(-)

diff --git a/src/canopy/chat_engine/query_generator/last_message.py b/src/canopy/chat_engine/query_generator/last_message.py
index 05016cab..74e15661 100644
--- a/src/canopy/chat_engine/query_generator/last_message.py
+++ b/src/canopy/chat_engine/query_generator/last_message.py
@@ -6,7 +6,7 @@

 class LastMessageQueryGenerator(QueryGenerator):
     """
-    Returns the last user message as a query without running any LLMs. This can be
+    Returns the last message as a query without running any LLMs. This can be
     considered as the most basic query generation. Please use other query generators
     for more accurate results.
     """
@@ -18,14 +18,17 @@ def generate(self,
         max_prompt_token is dismissed since we do not consume any token for
         generating the queries.
         """
-        user_messages = [message for message in messages if message.role == Role.USER]

-        if len(user_messages) == 0:
-            raise ValueError("Passed chat history does not contain any user "
-                             "messages. Please include at least one user message"
-                             " in the history.")
+        if len(messages) == 0:
+            raise ValueError("Passed chat history does not contain any messages. "
+                             "Please include at least one message in the history.")

-        return [Query(text=user_messages[-1].content)]
+        last_message = messages[-1]
+
+        if last_message.role != Role.USER:
+            raise ValueError(f"Expected a UserMessage, got {type(last_message)}.")
+
+        return [Query(text=last_message.content)]

     async def agenerate(self,
                         messages: Messages,
diff --git a/tests/unit/query_generators/test_last_message_query_generator.py b/tests/unit/query_generators/test_last_message_query_generator.py
index 6f56ac86..360fc6d9 100644
--- a/tests/unit/query_generators/test_last_message_query_generator.py
+++ b/tests/unit/query_generators/test_last_message_query_generator.py
@@ -1,7 +1,7 @@
 import pytest

 from canopy.chat_engine.query_generator import LastMessageQueryGenerator
-from canopy.models.data_models import UserMessage, Query, AssistantMessage, Role
+from canopy.models.data_models import UserMessage, Query, AssistantMessage


 @pytest.fixture
@@ -11,14 +11,6 @@ def sample_user_messages():
     ]


-@pytest.fixture
-def sample_messages():
-    return [
-        UserMessage(content="What is photosynthesis?"),
-        AssistantMessage(content="Oh! I don't know.")
-    ]
-
-
 @pytest.fixture
 def query_generator():
     return LastMessageQueryGenerator()
@@ -29,14 +21,6 @@ async def test_agenerate(query_generator, sample_user_messages):
     assert actual == expected


-def test_generate_with_user_and_assistant(query_generator, sample_messages):
-    user_messages = (msg for msg in sample_messages if msg.role == Role.USER)
-    last_user_msg = next(user_messages)
-    expected = [Query(text=last_user_msg.content)]
-    actual = query_generator.generate(sample_messages, 0)
-    assert actual == expected
-
-
 def test_generate_fails_with_empty_history(query_generator):
     with pytest.raises(ValueError):
         query_generator.generate([], 0)

From 16759c67e5ca36ad58f0cb2eb56325b61d6c6c57 Mon Sep 17 00:00:00 2001
From: Izel Levy
Date: Sun, 10 Dec 2023 16:39:40 +0200
Subject: [PATCH 09/11] Fix sample messages

---
 .../test_last_message_query_generator.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/unit/query_generators/test_last_message_query_generator.py b/tests/unit/query_generators/test_last_message_query_generator.py
index 360fc6d9..308c1b4d 100644
--- a/tests/unit/query_generators/test_last_message_query_generator.py
+++ b/tests/unit/query_generators/test_last_message_query_generator.py
@@ -5,7 +5,7 @@


 @pytest.fixture
-def sample_user_messages():
+def sample_messages():
     return [
         UserMessage(content="What is photosynthesis?")
     ]
@@ -16,16 +16,16 @@ def query_generator():
     return LastMessageQueryGenerator()


-def test_generate(query_generator, sample_user_messages):
-    expected = [Query(text=sample_user_messages[-1].content)]
-    actual = query_generator.generate(sample_user_messages, 0)
+def test_generate(query_generator, sample_messages):
+    expected = [Query(text=sample_messages[-1].content)]
+    actual = query_generator.generate(sample_messages, 0)
     assert actual == expected


 @pytest.mark.asyncio
-async def test_agenerate(query_generator, sample_user_messages):
-    expected = [Query(text=sample_user_messages[-1].content)]
-    actual = await query_generator.agenerate(sample_user_messages, 0)
+async def test_agenerate(query_generator, sample_messages):
+    expected = [Query(text=sample_messages[-1].content)]
+    actual = await query_generator.agenerate(sample_messages, 0)
     assert actual == expected
From c425f83e350e91054d103679bd9fc74ba254418b Mon Sep 17 00:00:00 2001
From: ilai
Date: Sun, 10 Dec 2023 20:57:47 +0200
Subject: [PATCH 10/11] [config] Changed Anyscale config to use
 LastMessageQueryGenerator

Since Anyscale doesn't support function calling for now

---
 config/anyscale.yaml | 2 +-
 config/config.yaml   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/config/anyscale.yaml b/config/anyscale.yaml
index 816a0ffc..dd6ce1b8 100644
--- a/config/anyscale.yaml
+++ b/config/anyscale.yaml
@@ -24,7 +24,7 @@ chat_engine:
   # The query builder is responsible for generating textual queries given user message history.
   # --------------------------------------------------------------------
   query_builder:
-    type: FunctionCallingQueryGenerator   # Options: [FunctionCallingQueryGenerator]
+    type: LastMessageQueryGenerator   # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
     params:
       prompt: *query_builder_prompt   # The query builder's system prompt for calling the LLM
       function_description:           # A function description passed to the LLM's `function_calling` API
diff --git a/config/config.yaml b/config/config.yaml
index 0d3576cd..8fc78572 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -59,7 +59,7 @@ chat_engine:
   # The query builder is responsible for generating textual queries given user message history.
   # --------------------------------------------------------------------
   query_builder:
-    type: FunctionCallingQueryGenerator   # Options: [FunctionCallingQueryGenerator]
+    type: FunctionCallingQueryGenerator   # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
     params:
       prompt: *query_builder_prompt   # The query builder's system prompt for calling the LLM
       function_description:           # A function description passed to the LLM's `function_calling` API

From 47734f389754f72b59da5577fa1199bb19329b51 Mon Sep 17 00:00:00 2001
From: ilai
Date: Sun, 10 Dec 2023 21:09:24 +0200
Subject: [PATCH 11/11] [config] Update anyscale.yaml - The new QueryGenerator
 doesn't need an LLM

Otherwise it will error out

---
 config/anyscale.yaml | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/config/anyscale.yaml b/config/anyscale.yaml
index dd6ce1b8..7f5f28ef 100644
--- a/config/anyscale.yaml
+++ b/config/anyscale.yaml
@@ -24,14 +24,4 @@ chat_engine:
   # The query builder is responsible for generating textual queries given user message history.
   # --------------------------------------------------------------------
   query_builder:
-    type: LastMessageQueryGenerator   # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
-    params:
-      prompt: *query_builder_prompt   # The query builder's system prompt for calling the LLM
-      function_description:           # A function description passed to the LLM's `function_calling` API
-        Query search engine for relevant information
-
-      llm:   # The LLM that the query builder will use to generate queries.
-        #Use OpenAI for function call for now
-        type: OpenAILLM
-        params:
-          model_name: gpt-3.5-turbo
+    type: LastMessageQueryGenerator   # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
\ No newline at end of file
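Note: the trimmed query_builder section works because LastMessageQueryGenerator takes
no prompt, function description, or LLM (patch 11: "The new QueryGenerator doesn't
need an LLM"). The sketch below is a hypothetical, simplified illustration of how a
config-driven factory could map the type field to the classes exported in patch 01;
it is not Canopy's actual config-loading code.

    from canopy.chat_engine.query_generator import (
        FunctionCallingQueryGenerator,
        LastMessageQueryGenerator,
    )

    # Hypothetical dispatch table, for illustration only.
    QUERY_BUILDER_TYPES = {
        "FunctionCallingQueryGenerator": FunctionCallingQueryGenerator,
        "LastMessageQueryGenerator": LastMessageQueryGenerator,
    }

    def build_query_builder(config: dict):
        # A bare {"type": "LastMessageQueryGenerator"} needs no params block,
        # which is why anyscale.yaml drops the prompt/function_description/llm keys.
        cls = QUERY_BUILDER_TYPES[config["type"]]
        return cls(**config.get("params", {}))

    query_builder = build_query_builder({"type": "LastMessageQueryGenerator"})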