From 9e741abe22dd88eaa046e7fc646d057169764112 Mon Sep 17 00:00:00 2001
From: Ryan Chan
Date: Wed, 1 Nov 2023 05:50:26 +0000
Subject: [PATCH] Fix #7765: Fix exceed context length error (#8530)

* fix context length error in chat engine

* fix mypy issues

* fix mypy issues 2

* add type hint for kwargs
---
 llama_index/chat_engine/context.py       | 37 ++++++++++++++++++++----
 llama_index/chat_engine/simple.py        | 36 ++++++++++++++++++++---
 llama_index/memory/chat_memory_buffer.py | 14 ++++++---
 llama_index/memory/types.py              |  4 +--
 4 files changed, 76 insertions(+), 15 deletions(-)

diff --git a/llama_index/chat_engine/context.py b/llama_index/chat_engine/context.py
index dde6e5698da21..ce0019118a578 100644
--- a/llama_index/chat_engine/context.py
+++ b/llama_index/chat_engine/context.py
@@ -154,8 +154,14 @@ def chat(
         context_str_template, nodes = self._generate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
 
-        all_messages = prefix_messages + self._memory.get()
-
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
         chat_response = self._llm.chat(all_messages)
         ai_message = chat_response.message
         self._memory.put(ai_message)
@@ -183,7 +189,14 @@ def stream_chat(
         context_str_template, nodes = self._generate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
 
-        all_messages = prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             chat_stream=self._llm.stream_chat(all_messages),
@@ -214,7 +227,14 @@ async def achat(
         context_str_template, nodes = await self._agenerate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
 
-        all_messages = prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = await self._llm.achat(all_messages)
         ai_message = chat_response.message
@@ -243,7 +263,14 @@ async def astream_chat(
         context_str_template, nodes = await self._agenerate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
 
-        all_messages = prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             achat_stream=await self._llm.astream_chat(all_messages),
diff --git a/llama_index/chat_engine/simple.py b/llama_index/chat_engine/simple.py
index 6cbf8918ad148..7ec546e49f859 100644
--- a/llama_index/chat_engine/simple.py
+++ b/llama_index/chat_engine/simple.py
@@ -76,7 +76,14 @@ def chat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = self._llm.chat(all_messages)
         ai_message = chat_response.message
@@ -91,7 +98,14 @@ def stream_chat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             chat_stream=self._llm.stream_chat(all_messages)
@@ -110,7 +124,14 @@ async def achat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = await self._llm.achat(all_messages)
         ai_message = chat_response.message
@@ -125,7 +146,14 @@ async def astream_chat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             achat_stream=await self._llm.astream_chat(all_messages)
diff --git a/llama_index/memory/chat_memory_buffer.py b/llama_index/memory/chat_memory_buffer.py
index afe0527d914fc..6577d27394d86 100644
--- a/llama_index/memory/chat_memory_buffer.py
+++ b/llama_index/memory/chat_memory_buffer.py
@@ -1,7 +1,7 @@
 from typing import Any, Callable, Dict, List, Optional, cast
 
 from llama_index.bridge.pydantic import Field, root_validator
-from llama_index.llms.base import LLM, ChatMessage
+from llama_index.llms.base import LLM, ChatMessage, MessageRole
 from llama_index.memory.types import BaseMemory
 from llama_index.utils import GlobalsHelper
 
@@ -82,20 +82,26 @@ def to_dict(self) -> dict:
     def from_dict(cls, json_dict: dict) -> "ChatMemoryBuffer":
         return cls.parse_obj(json_dict)
 
-    def get(self) -> List[ChatMessage]:
+    def get(self, initial_token_count: int = 0, **kwargs: Any) -> List[ChatMessage]:
         """Get chat history."""
         message_count = len(self.chat_history)
         message_str = " ".join(
             [str(m.content) for m in self.chat_history[-message_count:]]
         )
-        token_count = len(self.tokenizer_fn(message_str))
+        token_count = initial_token_count + len(self.tokenizer_fn(message_str))
 
         while token_count > self.token_limit and message_count > 1:
             message_count -= 1
+            if self.chat_history[-message_count].role == MessageRole.ASSISTANT:
+                # we cannot have an assistant message at the start of the chat history
+                # if after removal of the first, we have an assistant message,
+                # we need to remove the assistant message too
+                message_count -= 1
+
             message_str = " ".join(
                 [str(m.content) for m in self.chat_history[-message_count:]]
             )
-            token_count = len(self.tokenizer_fn(message_str))
+            token_count = initial_token_count + len(self.tokenizer_fn(message_str))
 
         # catch one message longer than token limit
         if token_count > self.token_limit:
diff --git a/llama_index/memory/types.py b/llama_index/memory/types.py
index a2f6dd780d549..5375f5d408290 100644
--- a/llama_index/memory/types.py
+++ b/llama_index/memory/types.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from llama_index.bridge.pydantic import BaseModel
 from llama_index.llms.base import LLM, ChatMessage
@@ -21,7 +21,7 @@ def from_defaults(
         """Create a chat memory from defaults."""
 
     @abstractmethod
-    def get(self) -> List[ChatMessage]:
+    def get(self, **kwargs: Any) -> List[ChatMessage]:
         """Get chat history."""
 
     @abstractmethod
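
A minimal usage sketch of the API change above (illustrative only, not part of the patch): ChatMemoryBuffer.get() now accepts an initial_token_count, so tokens already spent on prefix/system messages count against token_limit when the buffer trims chat history. The token_limit value below and the assumption that ChatMemoryBuffer.from_defaults() accepts a token_limit argument are example choices, not taken from the diff.

from llama_index.llms.base import ChatMessage, MessageRole
from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer

# Hypothetical setup: a memory buffer with a 3000-token budget
# (assumes from_defaults accepts token_limit).
memory = ChatMemoryBuffer.from_defaults(token_limit=3000)
memory.put(ChatMessage(role=MessageRole.USER, content="Hello"))
memory.put(ChatMessage(role=MessageRole.ASSISTANT, content="Hi there!"))

# Prefix/system messages that will be prepended to every LLM request.
prefix_messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a helpful assistant.")
]

# Count the tokens the prefix already consumes, then let get() trim the
# history against the remaining budget instead of the full token_limit,
# mirroring what the chat engines in this patch do.
initial_token_count = len(
    memory.tokenizer_fn(" ".join([(m.content or "") for m in prefix_messages]))
)
all_messages = prefix_messages + memory.get(initial_token_count=initial_token_count)

Note also that after trimming, get() skips a leading assistant message, so the returned history never starts with an assistant turn.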