Skip to content

Commit

Permalink
Fix input interpolation bug (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
gvieira committed Mar 22, 2024
1 parent aa0eb02 commit 128ce91
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 5 deletions.
17 changes: 14 additions & 3 deletions src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ class Agent(BaseModel):
default=None, description="Callback to be executed"
)

_original_role: str | None = None
_original_goal: str | None = None
_original_backstory: str | None = None

def __init__(__pydantic_self__, **data):
config = data.pop("config", {})
super().__init__(**config, **data)
Expand Down Expand Up @@ -282,10 +286,17 @@ def create_agent_executor(self, tools=None) -> None:

def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
if self._original_goal is None:
self._original_goal = self.goal
if self._original_backstory is None:
self._original_backstory = self.backstory

if inputs:
self.role = self.role.format(**inputs)
self.goal = self.goal.format(**inputs)
self.backstory = self.backstory.format(**inputs)
self.role = self._original_role.format(**inputs)
self.goal = self._original_goal.format(**inputs)
self.backstory = self._original_backstory.format(**inputs)

def increment_formatting_errors(self) -> None:
"""Count the formatting errors of the agent."""
Expand Down
12 changes: 10 additions & 2 deletions src/crewai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ class Config:
description="Unique identifier for the object, not set by user.",
)

_original_description: str | None = None
_original_expected_output: str | None = None

def __init__(__pydantic_self__, **data):
config = data.pop("config", {})
super().__init__(**config, **data)
Expand Down Expand Up @@ -189,9 +192,14 @@ def prompt(self) -> str:

def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Interpolate inputs into the task description and expected output."""
if self._original_description is None:
self._original_description = self.description
if self._original_expected_output is None:
self._original_expected_output = self.expected_output

if inputs:
self.description = self.description.format(**inputs)
self.expected_output = self.expected_output.format(**inputs)
self.description = self._original_description.format(**inputs)
self.expected_output = self._original_expected_output.format(**inputs)

def increment_tools_errors(self) -> None:
"""Increment the tools errors counter."""
Expand Down
18 changes: 18 additions & 0 deletions tests/agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,21 @@ def test_agent_definition_based_on_dict():
assert agent.backstory == "test backstory"
assert agent.verbose == True
assert agent.tools == []


def test_interpolate_inputs():
agent = Agent(
role="{topic} specialist",
goal="Figure {goal} out",
backstory="I am the master of {role}",
)

agent.interpolate_inputs({"topic": "AI", "goal": "life", "role": "all things"})
assert agent.role == "AI specialist"
assert agent.goal == "Figure life out"
assert agent.backstory == "I am the master of all things"

agent.interpolate_inputs({"topic": "Sales", "goal": "stuff", "role": "nothing"})
assert agent.role == "Sales specialist"
assert agent.goal == "Figure stuff out"
assert agent.backstory == "I am the master of nothing"
21 changes: 21 additions & 0 deletions tests/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,24 @@ def test_task_definition_based_on_dict():
assert task.description == config["description"]
assert task.expected_output == config["expected_output"]
assert task.agent is None


def test_interpolate_inputs():
task = Task(
description="Give me a list of 5 interesting ideas about {topic} to explore for an article, what makes them unique and interesting.",
expected_output="Bullet point list of 5 interesting ideas about {topic}.",
)

task.interpolate_inputs(inputs={"topic": "AI"})
assert (
task.description
== "Give me a list of 5 interesting ideas about AI to explore for an article, what makes them unique and interesting."
)
assert task.expected_output == "Bullet point list of 5 interesting ideas about AI."

task.interpolate_inputs(inputs={"topic": "ML"})
assert (
task.description
== "Give me a list of 5 interesting ideas about ML to explore for an article, what makes them unique and interesting."
)
assert task.expected_output == "Bullet point list of 5 interesting ideas about ML."

0 comments on commit 128ce91

Please sign in to comment.