Skip to content

Commit

Permalink
Feat/sliding context window (#1042)
Browse files Browse the repository at this point in the history
* patching for non-gpt model

* removal of json_object tool name assignment

* fixed issue for smaller models due to instructions prompt

* fixing for ollama llama3 models

* WIP: generated summary from documents split, could also create memgpt approach

* WIP: need tests but user inputted summarization strategy implemented - handling context window exceeding errors

* rm extra line

* removed type ignores

* added tests

* handling n to summarize prompt

* code cleanup, using click for cli asker

* rm not used class

* better refactor

* reverted poetry lock

* reverted poetry.locl

* improved context window exceeding exception class
  • Loading branch information
lorenzejay committed Aug 1, 2024
1 parent c93b85a commit 8118b7b
Show file tree
Hide file tree
Showing 7 changed files with 572 additions and 3 deletions.
127 changes: 125 additions & 2 deletions src/crewai/agents/executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import threading
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
import click


from langchain.agents import AgentExecutor
from langchain.agents.agent import ExceptionTool
Expand All @@ -11,12 +13,21 @@
from langchain_core.utils.input import get_color_mapping
from pydantic import InstanceOf

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain

from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.tools_handler import ToolsHandler


from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
from crewai.utilities import I18N
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.logger import Logger


class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
Expand All @@ -40,6 +51,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
system_template: Optional[str] = None
prompt_template: Optional[str] = None
response_template: Optional[str] = None
_logger: Logger = Logger(verbose_level=2)
_fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"

def _call(
self,
Expand Down Expand Up @@ -131,7 +144,7 @@ def _iter_next_step(
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)

# Call the LLM to see what to do.
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
Expand Down Expand Up @@ -185,6 +198,27 @@ def _iter_next_step(
yield AgentStep(action=output, observation=observation)
return

except Exception as e:
if LLMContextLengthExceededException(str(e))._is_context_limit_error(
str(e)
):
output = self._handle_context_length_error(
intermediate_steps, run_manager, inputs
)

if isinstance(output, AgentFinish):
yield output
elif isinstance(output, list):
for step in output:
yield step
return

yield AgentStep(
action=AgentAction("_Exception", str(e), str(e)),
observation=str(e),
)
return

# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.should_ask_for_human_input:
Expand Down Expand Up @@ -235,6 +269,7 @@ def _iter_next_step(
agent=self.crew_agent,
action=agent_action,
)

tool_calling = tool_usage.parse(agent_action.log)

if isinstance(tool_calling, ToolUsageErrorException):
Expand Down Expand Up @@ -280,3 +315,91 @@ def _handle_crew_training_output(
CrewTrainingHandler(TRAINING_DATA_FILE).append(
self.crew._train_iteration, agent_id, training_data
)

def _handle_context_length(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> List[Tuple[AgentAction, str]]:
text = intermediate_steps[0][1]
original_action = intermediate_steps[0][0]

text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n"],
chunk_size=8000,
chunk_overlap=500,
)

if self._fit_context_window_strategy == "summarize":
docs = text_splitter.create_documents([text])
self._logger.log(
"debug",
"Summarizing Content, it is recommended to use a RAG tool",
color="bold_blue",
)
summarize_chain = load_summarize_chain(
self.llm, chain_type="map_reduce", verbose=True
)
summarized_docs = []
for doc in docs:
summary = summarize_chain.invoke(
{"input_documents": [doc]}, return_only_outputs=True
)

summarized_docs.append(summary["output_text"])

formatted_results = "\n\n".join(summarized_docs)
summary_step = AgentStep(
action=AgentAction(
tool=original_action.tool,
tool_input=original_action.tool_input,
log=original_action.log,
),
observation=formatted_results,
)
summary_tuple = (summary_step.action, summary_step.observation)
return [summary_tuple]

return intermediate_steps

def _handle_context_length_error(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun],
inputs: Dict[str, str],
) -> Union[AgentFinish, List[AgentStep]]:
self._logger.log(
"debug",
"Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
color="yellow",
)
user_choice = click.confirm(
"Context length exceeded. Do you want to summarize the text to fit models context window?"
)
if user_choice:
self._logger.log(
"debug",
"Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
color="bold_blue",
)
intermediate_steps = self._handle_context_length(intermediate_steps)

output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)

if isinstance(output, AgentFinish):
return output
elif isinstance(output, AgentAction):
return [AgentStep(action=output, observation=None)]
else:
return [AgentStep(action=action, observation=None) for action in output]
else:
self._logger.log(
"debug",
"Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
color="red",
)
raise SystemExit(
"Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
)
2 changes: 1 addition & 1 deletion src/crewai/tools/tool_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
except ImportError:
agentops = None

OPENAI_BIGGER_MODELS = ["gpt-4"]
OPENAI_BIGGER_MODELS = ["gpt-4o"]


class ToolUsageErrorException(Exception):
Expand Down
4 changes: 4 additions & 0 deletions src/crewai/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from .printer import Printer
from .prompts import Prompts
from .rpm_controller import RPMController
from .exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)

__all__ = [
"Converter",
Expand All @@ -19,4 +22,5 @@
"Prompts",
"RPMController",
"YamlParser",
"LLMContextLengthExceededException",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
class LLMContextLengthExceededException(Exception):
CONTEXT_LIMIT_ERRORS = [
"maximum context length",
"context length exceeded",
"context_length_exceeded",
"context window full",
"too many tokens",
"input is too long",
"exceeds token limit",
]

def __init__(self, error_message: str):
self.original_error_message = error_message
super().__init__(self._get_error_message(error_message))

def _is_context_limit_error(self, error_message: str) -> bool:
return any(
phrase.lower() in error_message.lower()
for phrase in self.CONTEXT_LIMIT_ERRORS
)

def _get_error_message(self, error_message: str):
return (
f"LLM context length exceeded. Original error: {error_message}\n"
"Consider using a smaller input or implementing a text splitting strategy."
)
73 changes: 73 additions & 0 deletions tests/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain.tools import tool
from langchain_core.exceptions import OutputParserException
from langchain_openai import ChatOpenAI
from langchain.schema import AgentAction

from crewai import Agent, Crew, Task
from crewai.agents.cache import CacheHandler
Expand Down Expand Up @@ -1014,3 +1015,75 @@ def test_agent_max_retry_limit():
),
]
)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_context_length_exceeds_limit():
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
)
original_action = AgentAction(
tool="test_tool", tool_input="test_input", log="test_log"
)

with patch.object(
CrewAgentExecutor, "_iter_next_step", wraps=agent.agent_executor._iter_next_step
) as private_mock:
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
expected_output="The final answer",
)
agent.execute_task(
task=task,
)
private_mock.assert_called_once()
with patch("crewai.agents.executor.click") as mock_prompt:
mock_prompt.return_value = "y"
with patch.object(
CrewAgentExecutor, "_handle_context_length"
) as mock_handle_context:
mock_handle_context.side_effect = ValueError(
"Context length limit exceeded"
)

long_input = "This is a very long input. " * 10000

# Attempt to handle context length, expecting the mocked error
with pytest.raises(ValueError) as excinfo:
agent.agent_executor._handle_context_length(
[(original_action, long_input)]
)

assert "Context length limit exceeded" in str(excinfo.value)
mock_handle_context.assert_called_once()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_context_length_exceeds_limit_cli_no():
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
)
task = Task(description="test task", agent=agent, expected_output="test output")

with patch.object(
CrewAgentExecutor, "_iter_next_step", wraps=agent.agent_executor._iter_next_step
) as private_mock:
task = Task(
description="The final answer is 42. But don't give it yet, instead keep using the `get_final_answer` tool.",
expected_output="The final answer",
)
agent.execute_task(
task=task,
)
private_mock.assert_called_once()
with patch("crewai.agents.executor.click") as mock_prompt:
mock_prompt.return_value = "n"
pytest.raises(SystemExit)
with patch.object(
CrewAgentExecutor, "_handle_context_length"
) as mock_handle_context:
mock_handle_context.assert_not_called()
Loading

0 comments on commit 8118b7b

Please sign in to comment.