diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 505c5249b8..bfe9808134 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -8,7 +8,7 @@ from langchain_core.agents import AgentAction from langchain_core.callbacks import BaseCallbackHandler from langchain_openai import ChatOpenAI -from pydantic import Field, InstanceOf, model_validator +from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler, CrewAgentExecutor, CrewAgentParser from crewai.agents.agent_builder.base_agent import BaseAgent @@ -55,8 +55,11 @@ class Agent(BaseAgent): tools: Tools at agents disposal step_callback: Callback to be executed after each step of the agent execution. callbacks: A list of callback functions from the langchain library that are triggered during the agent's execution process + allow_code_execution: Enable code execution for the agent. + max_retry_limit: Maximum number of retries for an agent to execute a task when an error occurs. """ + _times_executed: int = PrivateAttr(default=0) max_execution_time: Optional[int] = Field( default=None, description="Maximum execution time for an agent to execute a task", @@ -97,6 +100,10 @@ class Agent(BaseAgent): allow_code_execution: Optional[bool] = Field( default=False, description="Enable code execution for the agent." ) + max_retry_limit: int = Field( + default=2, + description="Maximum number of retries for an agent to execute a task when an error occurs.", + ) def __init__(__pydantic_self__, **data): config = data.pop("config", {}) @@ -185,13 +192,19 @@ def execute_task( else: task_prompt = self._use_trained_data(task_prompt=task_prompt) - result = self.agent_executor.invoke( - { - "input": task_prompt, - "tool_names": self.agent_executor.tools_names, - "tools": self.agent_executor.tools_description, - } - )["output"] + try: + result = self.agent_executor.invoke( + { + "input": task_prompt, + "tool_names": self.agent_executor.tools_names, + "tools": self.agent_executor.tools_description, + } + )["output"] + except Exception as e: + self._times_executed += 1 + if self._times_executed > self.max_retry_limit: + raise e + self.execute_task(task, context, tools) if self.max_rpm: self._rpm_controller.stop_rpm_counter() diff --git a/tests/agent_test.py b/tests/agent_test.py index d8e04c110b..8ffbe591a3 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -963,3 +963,54 @@ def test_agent_use_trained_data(crew_training_handler): crew_training_handler.assert_has_calls( [mock.call(), mock.call("trained_agents_data.pkl"), mock.call().load()] ) + + +def test_agent_max_retry_limit(): + agent = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + max_retry_limit=1, + ) + + task = Task( + agent=agent, + description="Say the word: Hi", + expected_output="The word: Hi", + human_input=True, + ) + + error_message = "Error happening while sending prompt to model." + with patch.object( + CrewAgentExecutor, "invoke", wraps=agent.agent_executor.invoke + ) as invoke_mock: + invoke_mock.side_effect = Exception(error_message) + + assert agent._times_executed == 0 + assert agent.max_retry_limit == 1 + + with pytest.raises(Exception) as e: + agent.execute_task( + task=task, + ) + assert e.value.args[0] == error_message + assert agent._times_executed == 2 + + invoke_mock.assert_has_calls( + [ + mock.call( + { + "input": "Say the word: Hi\n\nThis is the expect criteria for your final answer: The word: Hi \n you MUST return the actual complete content as the final answer, not a summary.", + "tool_names": "", + "tools": "", + } + ), + mock.call( + { + "input": "Say the word: Hi\n\nThis is the expect criteria for your final answer: The word: Hi \n you MUST return the actual complete content as the final answer, not a summary.", + "tool_names": "", + "tools": "", + } + ), + ] + )