Skip to content

Commit

Permalink
Fix context_variable being emptied when tool call function fails
Browse files Browse the repository at this point in the history
  • Loading branch information
StreetLamb committed Dec 30, 2024
1 parent c17f569 commit 66c968e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
4 changes: 3 additions & 1 deletion rojak/workflows/agent_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ async def handle_tool_call(
f"Failed to process tool call '{name}'. "
f"Error will be sent to agent to reassess. Error: {e}"
)
result = AgentExecuteFnResult(output=str(e.cause))
result = AgentExecuteFnResult(
output=str(e.cause), context_variables=context_variables
)
tool_response = ToolResponse(tool_call_id=tool_call.id, output=result)
return AgentWorkflowResponse(output=tool_response, sender=self.agent.name)
70 changes: 70 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,76 @@ def instruct_fn(context_variables):
)


@pytest.mark.asyncio
async def test_failed_tool_call(mock_openai_client: MockOpenAIClient):
"""Context variable should be updated by first tool call only since 2nd tool call fails."""
task_queue_name = str(uuid.uuid4())

get_weather_mock = Mock()
get_air_quality_mock = Mock()

def get_weather(context_variables: dict):
get_weather_mock()
context_variables["seen"].append("get_weather")
raise Exception("Something went wrong!")

def get_air_quality(context_variables: dict):
get_air_quality_mock()
context_variables["seen"].append("get_air_quality")
return AgentExecuteFnResult(
output="Air quality is great!", context_variables=context_variables
)

messages = [
{
"role": "user",
"content": "What's the weather and air quality like in San Francisco?",
}
]

# set mock to return a response that triggers function call
mock_openai_client.set_sequential_responses(
[
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[{"name": "get_air_quality", "args": {}}],
),
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[{"name": "get_weather", "args": {}}],
),
create_mock_response(
{"role": "assistant", "content": DEFAULT_RESPONSE_CONTENT}
),
]
)

async with await WorkflowEnvironment.start_time_skipping() as env:
agent = OpenAIAgent(
name="Test Agent", functions=["get_weather", "get_air_quality"]
)
openai_activities = OpenAIAgentActivities(
OpenAIAgentOptions(
client=mock_openai_client, all_functions=[get_weather, get_air_quality]
)
)
rojak = Rojak(client=env.client, task_queue=task_queue_name)
worker = await rojak.create_worker([openai_activities])
async with worker:
context_variables = {"seen": ["test"]}
response = await rojak.run(
id=str(uuid.uuid4()),
agent=agent,
messages=messages,
context_variables=context_variables,
)
get_weather_mock.assert_called()
get_air_quality_mock.assert_called_once()
assert response.context_variables["seen"] == ["test", "get_air_quality"]
assert response.messages[-1].role == "assistant"
assert response.messages[-1].content == DEFAULT_RESPONSE_CONTENT


@pytest.mark.asyncio
async def test_multiple_tool_calls(mock_openai_client: MockOpenAIClient):
task_queue_name = str(uuid.uuid4())
Expand Down

0 comments on commit 66c968e

Please sign in to comment.