From fec7c93e4a02a9d7291691ae388ea0a210de735a Mon Sep 17 00:00:00 2001 From: jtoy Date: Wed, 22 May 2024 12:53:39 -0700 Subject: [PATCH] add warning if duplicate function is registered (#2159) * add warning if duplicate function is registereed * check _function_map and llm_config * check function_map and llm_config * use register_function and llm_config * cleanups * cleanups * warning test * warning test * more test coverage * use a fake config * formatting * formatting --------- Co-authored-by: Jason Co-authored-by: Eric Zhu --- autogen/agentchat/conversable_agent.py | 9 +++- test/agentchat/test_conversable_agent.py | 59 ++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 89b2dd94345..c3394a96bb6 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -2406,6 +2406,8 @@ def register_function(self, function_map: Dict[str, Union[Callable, None]]): self._assert_valid_name(name) if func is None and name not in self._function_map.keys(): warnings.warn(f"The function {name} to remove doesn't exist", name) + if name in self._function_map: + warnings.warn(f"Function '{name}' is being overridden.", UserWarning) self._function_map.update(function_map) self._function_map = {k: v for k, v in self._function_map.items() if v is not None} @@ -2442,6 +2444,9 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None) self._assert_valid_name(func_sig["name"]) if "functions" in self.llm_config.keys(): + if any(func["name"] == func_sig["name"] for func in self.llm_config["functions"]): + warnings.warn(f"Function '{func_sig['name']}' is being overridden.", UserWarning) + self.llm_config["functions"] = [ func for func in self.llm_config["functions"] if func.get("name") != func_sig["name"] ] + [func_sig] @@ -2481,7 +2486,9 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None): f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}" ) self._assert_valid_name(tool_sig["function"]["name"]) - if "tools" in self.llm_config.keys(): + if "tools" in self.llm_config: + if any(tool["function"]["name"] == tool_sig["function"]["name"] for tool in self.llm_config["tools"]): + warnings.warn(f"Function '{tool_sig['function']['name']}' is being overridden.", UserWarning) self.llm_config["tools"] = [ tool for tool in self.llm_config["tools"] diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index b81a897b47c..3c2e79beb13 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -1403,6 +1403,64 @@ def test_http_client(): ) +def test_adding_duplicate_function_warning(): + + config_base = [{"base_url": "http://0.0.0.0:8000", "api_key": "NULL"}] + + agent = autogen.ConversableAgent( + "jtoy", + llm_config={"config_list": config_base}, + ) + + def sample_function(): + pass + + agent.register_function( + function_map={ + "sample_function": sample_function, + } + ) + agent.update_function_signature( + { + "name": "foo", + }, + is_remove=False, + ) + agent.update_tool_signature( + { + "type": "function", + "function": { + "name": "yo", + }, + }, + is_remove=False, + ) + + with pytest.warns(UserWarning, match="Function 'sample_function' is being overridden."): + agent.register_function( + function_map={ + "sample_function": sample_function, + } + ) + with pytest.warns(UserWarning, match="Function 'foo' is being overridden."): + agent.update_function_signature( + { + "name": "foo", + }, + is_remove=False, + ) + with pytest.warns(UserWarning, match="Function 'yo' is being overridden."): + agent.update_tool_signature( + { + "type": "function", + "function": { + "name": "yo", + }, + }, + is_remove=False, + ) + + if __name__ == "__main__": # test_trigger() # test_context() @@ -1414,4 +1472,5 @@ def test_http_client(): # test_process_before_send() # test_message_func() test_summary() + test_adding_duplicate_function_warning() # test_function_registration_e2e_sync()