
Fix #7765: Fix exceed context length error (#8530)
* fix context length error in chat engine

* fix mypy issues

* fix mypy issues 2

* add type hint for kwargs
rchan26 authored Nov 1, 2023
1 parent dfa2f45 commit 9e741ab
Showing 4 changed files with 76 additions and 15 deletions.
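
The change in a nutshell: ChatMemoryBuffer.get() previously trimmed the stored chat history against its token_limit without knowing about the prefix messages (system prompt plus retrieved context) that the chat engines prepend, so the combined prompt could still exceed the model's context window. Each chat engine entry point now counts the tokens of its prefix messages and passes that number to the memory as initial_token_count, shrinking the history budget accordingly. A minimal sketch of the pattern, assuming the import paths as of this commit and a placeholder system prompt:

# Minimal sketch of the new pattern (illustrative values, not library internals).
from llama_index.llms.base import ChatMessage
from llama_index.memory import ChatMemoryBuffer

memory = ChatMemoryBuffer.from_defaults(token_limit=3000)
prefix_messages = [
    ChatMessage(role="system", content="You are a helpful assistant. <retrieved context here>")
]

# Tokens the prefix will occupy in the final prompt ...
initial_token_count = len(
    memory.tokenizer_fn(" ".join([(m.content or "") for m in prefix_messages]))
)

# ... which the buffer now subtracts from its budget when selecting history.
all_messages = prefix_messages + memory.get(initial_token_count=initial_token_count)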
37 changes: 32 additions & 5 deletions llama_index/chat_engine/context.py
@@ -154,8 +154,14 @@ def chat(
 
         context_str_template, nodes = self._generate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
-        all_messages = prefix_messages + self._memory.get()
-
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
         chat_response = self._llm.chat(all_messages)
         ai_message = chat_response.message
         self._memory.put(ai_message)
@@ -183,7 +189,14 @@ def stream_chat(
 
         context_str_template, nodes = self._generate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
-        all_messages = prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             chat_stream=self._llm.stream_chat(all_messages),
@@ -214,7 +227,14 @@ async def achat(
 
         context_str_template, nodes = await self._agenerate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
-        all_messages = prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = await self._llm.achat(all_messages)
         ai_message = chat_response.message
@@ -243,7 +263,14 @@ async def astream_chat(
 
         context_str_template, nodes = await self._agenerate_context(message)
         prefix_messages = self._get_prefix_messages_with_context(context_str_template)
-        all_messages = prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in prefix_messages])
+            )
+        )
+        all_messages = prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             achat_stream=await self._llm.astream_chat(all_messages),
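
From a user's point of view, the memory's token_limit now budgets for the system prompt and the retrieved context as well as the conversation history, so a context chat engine no longer overflows the model window on long conversations. A hedged end-to-end usage sketch; the data directory, token limit, and as_chat_engine arguments are assumptions based on the documented API around this release, not part of this commit:

# Hedged usage sketch (paths and arguments assumed from the docs of this era).
from llama_index import SimpleDirectoryReader, VectorStoreIndex
from llama_index.memory import ChatMemoryBuffer

documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents)

memory = ChatMemoryBuffer.from_defaults(token_limit=3000)
chat_engine = index.as_chat_engine(
    chat_mode="context",
    memory=memory,
    system_prompt="Answer questions about the loaded documents.",
)
response = chat_engine.chat("What do these documents cover?")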
36 changes: 32 additions & 4 deletions llama_index/chat_engine/simple.py
@@ -76,7 +76,14 @@ def chat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = self._llm.chat(all_messages)
         ai_message = chat_response.message
@@ -91,7 +98,14 @@ def stream_chat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             chat_stream=self._llm.stream_chat(all_messages)
@@ -110,7 +124,14 @@ async def achat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = await self._llm.achat(all_messages)
         ai_message = chat_response.message
@@ -125,7 +146,14 @@ async def astream_chat(
         if chat_history is not None:
             self._memory.set(chat_history)
         self._memory.put(ChatMessage(content=message, role="user"))
-        all_messages = self._prefix_messages + self._memory.get()
+        initial_token_count = len(
+            self._memory.tokenizer_fn(
+                " ".join([(m.content or "") for m in self._prefix_messages])
+            )
+        )
+        all_messages = self._prefix_messages + self._memory.get(
+            initial_token_count=initial_token_count
+        )
 
         chat_response = StreamingAgentChatResponse(
             achat_stream=await self._llm.astream_chat(all_messages)
14 changes: 10 additions & 4 deletions llama_index/memory/chat_memory_buffer.py
@@ -1,7 +1,7 @@
 from typing import Any, Callable, Dict, List, Optional, cast
 
 from llama_index.bridge.pydantic import Field, root_validator
-from llama_index.llms.base import LLM, ChatMessage
+from llama_index.llms.base import LLM, ChatMessage, MessageRole
 from llama_index.memory.types import BaseMemory
 from llama_index.utils import GlobalsHelper
 
@@ -82,20 +82,26 @@ def to_dict(self) -> dict:
     def from_dict(cls, json_dict: dict) -> "ChatMemoryBuffer":
         return cls.parse_obj(json_dict)
 
-    def get(self) -> List[ChatMessage]:
+    def get(self, initial_token_count: int = 0, **kwargs: Any) -> List[ChatMessage]:
         """Get chat history."""
         message_count = len(self.chat_history)
         message_str = " ".join(
             [str(m.content) for m in self.chat_history[-message_count:]]
         )
-        token_count = len(self.tokenizer_fn(message_str))
+        token_count = initial_token_count + len(self.tokenizer_fn(message_str))
 
         while token_count > self.token_limit and message_count > 1:
             message_count -= 1
+            if self.chat_history[-message_count].role == MessageRole.ASSISTANT:
+                # we cannot have an assistant message at the start of the chat history
+                # if after removal of the first, we have an assistant message,
+                # we need to remove the assistant message too
+                message_count -= 1
+
             message_str = " ".join(
                 [str(m.content) for m in self.chat_history[-message_count:]]
             )
-            token_count = len(self.tokenizer_fn(message_str))
+            token_count = initial_token_count + len(self.tokenizer_fn(message_str))
 
         # catch one message longer than token limit
         if token_count > self.token_limit:
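
The selection rule added to ChatMemoryBuffer.get() can be read in isolation: start from the full history, charge the prefix token count up front, drop the oldest message while the total exceeds token_limit, and if the window would then begin with an assistant message, drop that one too. A rough standalone sketch of that rule; plain tuples and a whitespace tokenizer stand in for ChatMessage and tokenizer_fn, and this is not the library code:

from typing import Callable, List, Tuple


def select_history(
    history: List[Tuple[str, str]],
    token_limit: int,
    tokenizer: Callable[[str], List[str]],
    initial_token_count: int = 0,
) -> List[Tuple[str, str]]:
    """Rough sketch of the truncation rule in the updated ChatMemoryBuffer.get()."""
    count = len(history)

    def total(n: int) -> int:
        # prefix tokens are charged up front, history tokens on top
        text = " ".join(content for _role, content in history[len(history) - n:])
        return initial_token_count + len(tokenizer(text))

    token_count = total(count)
    while token_count > token_limit and count > 1:
        count -= 1
        if history[len(history) - count][0] == "assistant":
            # the returned window must not start on an assistant message,
            # so drop that one as well
            count -= 1
        token_count = total(count)

    # even a single message plus the prefix may not fit; return nothing then (sketch behaviour)
    return history[len(history) - count:] if token_count <= token_limit else []


# e.g. with a whitespace "tokenizer" and one token already spent on the prefix:
select_history(
    [("user", "hi"), ("assistant", "hello there"), ("user", "tell me more")],
    token_limit=4,
    tokenizer=str.split,
    initial_token_count=1,
)  # -> [("user", "tell me more")]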
4 changes: 2 additions & 2 deletions llama_index/memory/types.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional
+from typing import Any, List, Optional
 
 from llama_index.bridge.pydantic import BaseModel
 from llama_index.llms.base import LLM, ChatMessage
@@ -21,7 +21,7 @@ def from_defaults(
         """Create a chat memory from defaults."""
 
     @abstractmethod
-    def get(self) -> List[ChatMessage]:
+    def get(self, **kwargs: Any) -> List[ChatMessage]:
         """Get chat history."""
 
     @abstractmethod
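
The abstract BaseMemory.get() signature is widened to accept **kwargs so concrete memories can take implementation-specific arguments, such as ChatMemoryBuffer's new initial_token_count, without breaking the shared interface or tripping mypy. A minimal sketch of that contract with hypothetical classes:

from abc import ABC, abstractmethod
from typing import Any, List


# Hypothetical minimal reproduction of the widened contract; not the library classes.
class Memory(ABC):
    @abstractmethod
    def get(self, **kwargs: Any) -> List[str]:
        """Get chat history."""


class BufferMemory(Memory):
    def __init__(self) -> None:
        self.history: List[str] = []

    # A subclass may add keyword arguments (here, a token budget already spent
    # on prefix messages) while still satisfying the base get(**kwargs) signature.
    def get(self, initial_token_count: int = 0, **kwargs: Any) -> List[str]:
        return self.history


BufferMemory().get(initial_token_count=42)  # accepted
BufferMemory().get()                        # still accepted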
