Skip to content

Commit

Permalink
should_hide_tools function added to client_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
marklysze committed Jun 19, 2024
1 parent b43ec8d commit c01fb70
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 2 deletions.
54 changes: 54 additions & 0 deletions autogen/oai/client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,57 @@ def validate_parameter(
param_value = default_value

return param_value


def should_hide_tools(messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], hide_tools_param: str) -> bool:
    """
    Determines if tools should be hidden. This function is used to hide tools when they have been run,
    minimising the chance of the LLM choosing them when they shouldn't.

    Parameters:
        messages (List[Dict[str, Any]]): List of messages. Assistant messages that request tools carry a
            "tool_calls" list (each entry having an "id" and a "function" with a "name"); messages that
            carry a tool result have a "tool_call_id" referring back to one of those ids.
        tools (List[Dict[str, Any]]): List of tools, each with a "function" dictionary containing a "name".
            May be None or empty, in which case there is nothing to hide.
        hide_tools_param (str): "hide_tools" parameter value. Can be "if_all_run" (hide tools if all tools
            have been run), "if_any_run" (hide tools if any of the tools have been run), "never" (never
            hide tools). Default is "never".

    Returns:
        bool: Indicates whether the tools should be excluded from the response create request

    Raises:
        TypeError: If tools are supplied and hide_tools_param is not one of the supported values.

    Example Usage:
    ```python
        # Deciding whether to pass the tools through to the create request
        messages = params.get("messages", [])
        tools = params.get("tools", None)
        hide_tools = should_hide_tools(messages, tools, params["hide_tools"])
    ```
    """

    # Nothing to hide when hiding is disabled or no tools were supplied (None or empty).
    if hide_tools_param == "never" or not tools:
        return False

    if hide_tools_param == "if_any_run":
        # Return True if any tool_call_id exists, indicating a tool call has been executed. False otherwise.
        return any("tool_call_id" in message for message in messages)

    if hide_tools_param == "if_all_run":
        # Return True if all tools have been executed at least once. False otherwise.

        # Names of the tools that have not yet been seen with a result
        pending_tool_names = {tool["function"]["name"] for tool in tools}

        # Maps each requested tool call id to the name of the function it called
        tool_name_by_call_id = {}

        for message in messages:
            # An assistant message may request several tools at once (parallel tool calls), so register
            # every entry rather than just the first. `or []` tolerates "tool_calls": None.
            for tool_call in message.get("tool_calls") or []:
                tool_name_by_call_id[tool_call["id"]] = tool_call["function"]["name"]

            tool_call_id = message.get("tool_call_id")
            if tool_call_id is not None:
                # A result whose originating call is not in the history (e.g. truncated messages) cannot be
                # attributed to a tool, so it is ignored instead of raising a KeyError.
                pending_tool_names.discard(tool_name_by_call_id.get(tool_call_id))

        # Return True if all tools have been called at least once (accounted for)
        return not pending_tool_names

    raise TypeError(
        f"hide_tools_param is not a valid value ['if_all_run','if_any_run','never'], got '{hide_tools_param}'"
    )
178 changes: 176 additions & 2 deletions test/oai/test_client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

import autogen
from autogen.oai.client_utils import validate_parameter
from autogen.oai.client_utils import should_hide_tools, validate_parameter


def test_validate_parameter():
Expand Down Expand Up @@ -132,5 +132,179 @@ def test_validate_parameter():
assert validate_parameter({}, "max_tokens", int, True, 512, (0, None), None) == 512


def test_should_hide_tools():
    """Check should_hide_tools against chess conversations with none, one, or all of the tools run."""

    first_legal_moves_id = "call_abcde56o5jlugh9uekgo84c6"
    second_legal_moves_id = "call_p1fla56o5jlugh9uekgo84c6"
    make_move_id = "call_lcow1j0ehuhrcr3aakdmd9ju"
    legal_moves_reply = "Possible moves are: g1h3,g1f3,b1c3,b1a3,h2h3,g2g3,f2f3,e2e3,d2d3,c2c3,b2b3,a2a3,h2h4,g2g4,f2f4,e2e4,d2d4,c2c4,b2b4,a2a4"

    def opening():
        # Fresh system/user prompt pair that starts every conversation
        return [
            {"content": "You are a chess program and are playing for player white.", "role": "system"},
            {"content": "Let's play chess! Make a move.", "role": "user"},
        ]

    def tool_request(call_id, name, arguments="{}"):
        # Assistant message asking for a single tool to be run
        return {
            "tool_calls": [
                {
                    "id": call_id,
                    "function": {"arguments": arguments, "name": name},
                    "type": "function",
                }
            ],
            "content": None,
            "role": "assistant",
        }

    def tool_result(call_id, content):
        # Message carrying the result of the tool call with the given id
        return {"tool_call_id": call_id, "role": "user", "content": content}

    # Tools were requested three times but no results were ever returned
    no_tools_called_messages = opening() + [
        tool_request(first_legal_moves_id, "get_legal_moves"),
        tool_request(second_legal_moves_id, "get_legal_moves"),
        tool_request(make_move_id, "make_move", '{"move":"g1f3"}'),
    ]

    # Only get_legal_moves has a result; make_move was requested but not run
    one_tool_called_messages = opening() + [
        tool_request(first_legal_moves_id, "get_legal_moves"),
        tool_result(first_legal_moves_id, legal_moves_reply),
        tool_request(make_move_id, "make_move", '{"move":"g1f3"}'),
    ]

    # Both tools have results (get_legal_moves twice, make_move once)
    messages = opening() + [
        tool_request(first_legal_moves_id, "get_legal_moves"),
        tool_result(first_legal_moves_id, legal_moves_reply),
        tool_request(second_legal_moves_id, "get_legal_moves"),
        tool_result(second_legal_moves_id, legal_moves_reply),
        tool_request(make_move_id, "make_move", '{"move":"g1f3"}'),
        tool_result(make_move_id, "Moved knight (♘) from g1 to f3."),
    ]

    # Tool lists: an empty one and one with both chess tools
    no_tools = []
    all_tools = [
        {
            "type": "function",
            "function": {
                "description": "Call this tool to make a move after you have the list of legal moves.",
                "name": "make_move",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "move": {"type": "string", "description": "A move in UCI format. (e.g. e2e4 or e7e5 or e7e8q)"}
                    },
                    "required": ["move"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "description": "Call this tool to make a move after you have the list of legal moves.",
                "name": "get_legal_moves",
                "parameters": {"type": "object", "properties": {}, "required": []},
            },
        },
    ]

    # Each row: (conversation, tools, hide_tools value, whether the tools are expected to be hidden)
    expectations = [
        # No tools supplied, should not hide for any hide_tools value
        (messages, no_tools, "if_all_run", False),
        (messages, no_tools, "if_any_run", False),
        (messages, no_tools, "never", False),
        # Has run tools but never hide, should be false
        (messages, all_tools, "never", False),
        # Has run tools, should be true if all or any
        (messages, all_tools, "if_all_run", True),
        (messages, all_tools, "if_any_run", True),
        # Hasn't run any tools, should be false for all
        (no_tools_called_messages, all_tools, "if_all_run", False),
        (no_tools_called_messages, all_tools, "if_any_run", False),
        (no_tools_called_messages, all_tools, "never", False),
        # Has run one of the two tools, should be true only for 'if_any_run'
        (one_tool_called_messages, all_tools, "if_all_run", False),
        (one_tool_called_messages, all_tools, "if_any_run", True),
        (one_tool_called_messages, all_tools, "never", False),
    ]
    for conversation, tool_list, hide_tools_value, expected in expectations:
        assert bool(should_hide_tools(conversation, tool_list, hide_tools_value)) == expected

    # Parameter validation
    with pytest.raises(TypeError):
        assert not should_hide_tools(one_tool_called_messages, all_tools, "not_a_valid_value")


# Allows this test module to be executed directly as a script, calling the tests without the pytest runner.
if __name__ == "__main__":
    test_validate_parameter()
    # test_validate_parameter()
    test_should_hide_tools()

0 comments on commit c01fb70

Please sign in to comment.