Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Add last message query generator #210

Merged
1 change: 1 addition & 0 deletions src/canopy/chat_engine/query_generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .base import QueryGenerator
from .function_calling import FunctionCallingQueryGenerator
from .last_message import LastMessageQueryGenerator
17 changes: 17 additions & 0 deletions src/canopy/chat_engine/query_generator/last_message.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import List

from canopy.chat_engine.query_generator import QueryGenerator
from canopy.models.data_models import Messages, Query


class LastMessageQueryGenerator(QueryGenerator):

izellevy marked this conversation as resolved.
Show resolved Hide resolved
def generate(self,
messages: Messages,
max_prompt_tokens: int) -> List[Query]:
return [Query(text=messages[-1].content)]
izellevy marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take the last message which is a UserMessage (message.role == 'Role.user' or isinstance()).
If you go backwards on the messages and can't find a UserMessage - then raise an error

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds a bit weird to me. It's more likely that if the last message is not user message, there is something wrong in the pipeline or usage of this class, so IMO it's better to raise an error here instead of silently do something that might not be expected by the user

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see your point @igiloh-pinecone but agree with @acatav. Maybe I can raise if the history is empty or the last message is not a user message @igiloh-pinecone wdyt?


async def agenerate(self,
messages: Messages,
max_prompt_tokens: int) -> List[Query]:
return self.generate(messages, max_prompt_tokens)
29 changes: 29 additions & 0 deletions tests/unit/query_generators/test_last_message_query_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest

from canopy.chat_engine.query_generator import LastMessageQueryGenerator
from canopy.models.data_models import UserMessage, Query


@pytest.fixture
def sample_messages():
return [
UserMessage(content="What is photosynthesis?")
]


@pytest.fixture
def query_generator():
return LastMessageQueryGenerator()


def test_generate(query_generator, sample_messages):
expected = [Query(text=sample_messages[-1].content)]
actual = query_generator.generate(sample_messages, 0)
assert actual == expected


@pytest.mark.asyncio
async def test_agenerate(query_generator, sample_messages):
expected = [Query(text=sample_messages[-1].content)]
actual = await query_generator.agenerate(sample_messages, 0)
assert actual == expected