From 3cbcc45bd13790db3429549c12cbaf85e5b39416 Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Mon, 26 Feb 2024 17:26:50 +0000
Subject: [PATCH 1/8] Ability to update_model on conversable agents

---
 autogen/agentchat/conversable_agent.py | 26 ++++++++++++++
 autogen/oai/client.py                  | 50 +++++++++++++++++++++++---
 2 files changed, 72 insertions(+), 4 deletions(-)

diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index b31c8ce786d..dd9fcf7b1d0 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -129,6 +129,7 @@ def __init__(
         self._name = name
         # a dictionary of conversations, default value is list
         self._oai_messages = defaultdict(list)
+        self._update_model_metadata = defaultdict(list)
         self._oai_system_message = [{"content": system_message, "role": "system"}]
         self._description = description if description is not None else system_message
         self._is_termination_msg = (
@@ -1066,6 +1067,31 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser
                 flush=True,
             )
 
+    def update_model(self, preference_data: List[Dict[str, Any]], agent: Agent, **kwargs) -> Dict[str, Any]:
+        """Update the model using the preference data and the conversation history.
+
+        Args:
+            preference_data (List[Dict]): a list of dictionaries containing the preference data.
+            agent (Agent): the agent to update the model.
+            **kwargs: additional keyword arguments for the update model function.
+
+        Returns:
+            Dict: a dictionary containing the update model statistics.
+
+        Raises:
+            ValueError: If no OpenAIWrapper client is found.
+            ValueError: If multiple model clients are registered.
+            NotImplementedError: If update_model is not implemented for the underlying client.
+        """
+        if self.client is None:
+            raise ValueError("No OpenAIWrapper client is found.")
+        messages = self._oai_messages[agent]
+        update_model_stats = self.client.update_model(preference_data, messages, **kwargs)
+        self._update_model_metadata[agent].append(
+            {"messages": messages, "preference_data": preference_data, "update_stats": update_model_stats}
+        )
+        return update_model_stats
+
     def generate_oai_reply(
         self,
         messages: Optional[List[Dict]] = None,
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 59e59815330..b27e07964fe 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -83,8 +83,7 @@ class Message(Protocol):
         choices: List[Choice]
         model: str
 
-    def create(self, **params: Any) -> ModelClientResponseProtocol:
-        ...  # pragma: no cover
+    def create(self, **params: Any) -> ModelClientResponseProtocol: ...  # pragma: no cover
 
     def message_retrieval(
         self, response: ModelClientResponseProtocol
@@ -97,14 +96,30 @@ def message_retrieval(
         """
         ...  # pragma: no cover
 
-    def cost(self, response: ModelClientResponseProtocol) -> float:
-        ...  # pragma: no cover
+    def cost(self, response: ModelClientResponseProtocol) -> float: ...  # pragma: no cover
 
     @staticmethod
     def get_usage(response: ModelClientResponseProtocol) -> Dict:
         """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
         ...  # pragma: no cover
 
+    def update_model(
+        self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Optional method to learn from the preference data, if the model supports learning. Can be missing.
+
+        Learn from the preference data.
+
+        Args:
+            preference_data: The preference data.
+            inference_messages: The messages used for inference.
+            **kwargs: other arguments.
+
+        Returns:
+            Learning stats.
+        """
+        ...  # pragma: no cover
+
 
 class PlaceHolderClient:
     def __init__(self, config):
@@ -503,6 +518,33 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
         ]
         return params
 
+    def update_model(
+        self, preference_data: List[Any], inference_messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Learn from the preference data.
+
+        update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages.
+
+        Args:
+            preference_data: The preference data.
+            inference_messages: The messages used for inference.
+            **kwargs: other arguments.
+
+        Returns:
+            Learning stats.
+
+        Raises:
+            ValueError: If multiple model clients are registered.
+            NotImplementedError: If update_model is not implemented for the client.
+        """
+        if len(self._clients) != 1:
+            raise ValueError("update_model is not supported for multiple model clients.")
+        client = self._clients[0]
+        if hasattr(client, "update_model") and callable(getattr(client, "update_model")):
+            return client.update_model(preference_data, inference_messages, **kwargs)
+        else:
+            raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.")
+
     def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
         """Make a completion for a given config using available clients.
         Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.

From 41699b951b4478ca9f11a8a5a3cdfc49bbeacc7a Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Mon, 26 Feb 2024 17:34:36 +0000
Subject: [PATCH 2/8] formatting

---
 autogen/agentchat/conversable_agent.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index dd9fcf7b1d0..2f10e66c6b9 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -1069,7 +1069,7 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser
 
     def update_model(self, preference_data: List[Dict[str, Any]], agent: Agent, **kwargs) -> Dict[str, Any]:
         """Update the model using the preference data and the conversation history.
-
+
         Args:
             preference_data (List[Dict]): a list of dictionaries containing the preference data.
             agent (Agent): the agent to update the model.

From c27228955e8acb470ad36f1cd19a6eae200ca976 Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Mon, 26 Feb 2024 17:36:58 +0000
Subject: [PATCH 3/8] formatting

---
 autogen/oai/client.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index b27e07964fe..35febbab7ad 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -83,7 +83,8 @@ class Message(Protocol):
         choices: List[Choice]
         model: str
 
-    def create(self, **params: Any) -> ModelClientResponseProtocol: ...  # pragma: no cover
+    def create(self, **params: Any) -> ModelClientResponseProtocol:
+        ...  # pragma: no cover
 
     def message_retrieval(
         self, response: ModelClientResponseProtocol
@@ -96,7 +97,8 @@ def message_retrieval(
         """
         ...  # pragma: no cover
 
-    def cost(self, response: ModelClientResponseProtocol) -> float: ...  # pragma: no cover
+    def cost(self, response: ModelClientResponseProtocol) -> float:
+        ...  # pragma: no cover
 
     @staticmethod
     def get_usage(response: ModelClientResponseProtocol) -> Dict:

From 335e8310305e717f4f53b14333f4c5027efeac15 Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Tue, 5 Mar 2024 17:21:40 -0500
Subject: [PATCH 4/8] move code from conversable agent into samples/tools and add testing and README

---
 .github/workflows/samples-tools-tests.yml   |  45 ++++
 autogen/agentchat/conversable_agent.py      |  26 --
 autogen/oai/client.py                       |   4 +-
 samples/tools/finetuning/README.md          |  81 ++++++
 .../tools/finetuning/finetuning/__init__.py |   3 +
 .../conversable_agent_update_model.py       |  33 +++
 .../test_conversable_agent_update_model.py  | 243 ++++++++++++++++++
 7 files changed, 407 insertions(+), 28 deletions(-)
 create mode 100644 .github/workflows/samples-tools-tests.yml
 create mode 100644 samples/tools/finetuning/README.md
 create mode 100644 samples/tools/finetuning/finetuning/__init__.py
 create mode 100644 samples/tools/finetuning/finetuning/conversable_agent_update_model.py
 create mode 100644 samples/tools/finetuning/tests/test_conversable_agent_update_model.py

diff --git a/.github/workflows/samples-tools-tests.yml b/.github/workflows/samples-tools-tests.yml
new file mode 100644
index 00000000000..a28c3b544ef
--- /dev/null
+++ b/.github/workflows/samples-tools-tests.yml
@@ -0,0 +1,45 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: SamplesToolsTests
+
+on:
+  pull_request:
+    branches: ["main"]
+    paths:
+      - "autogen/**"
+      - "samples/tools/**"
+      - ".github/workflows/samples-tools-tests.yml"
+      - "setup.py"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+permissions: {}
+jobs:
+  SamplesToolsFineTuningTests:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+        python-version: ["3.9", "3.10", "3.11"]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install packages and dependencies for all tests
+        run: |
+          python -m pip install --upgrade pip wheel
+          pip install pytest
+      - name: Set AUTOGEN_USE_DOCKER based on OS
+        shell: bash
+        run: |
+          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+          fi
+      - name: Test finetuning tools
+        run: |
+          pytest samples/tools/finetuning/tests/
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 11922482ab9..6439c2099bb 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -130,7 +130,6 @@ def __init__(
         self._name = name
         # a dictionary of conversations, default value is list
         self._oai_messages = defaultdict(list)
-        self._update_model_metadata = defaultdict(list)
         self._oai_system_message = [{"content": system_message, "role": "system"}]
         self._description = description if description is not None else system_message
         self._is_termination_msg = (
@@ -1142,31 +1141,6 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser
                 flush=True,
             )
 
-    def update_model(self, preference_data: List[Dict[str, Any]], agent: Agent, **kwargs) -> Dict[str, Any]:
-        """Update the model using the preference data and the conversation history.
-
-        Args:
-            preference_data (List[Dict]): a list of dictionaries containing the preference data.
-            agent (Agent): the agent to update the model.
-            **kwargs: additional keyword arguments for the update model function.
-
-        Returns:
-            Dict: a dictionary containing the update model statistics.
-
-        Raises:
-            ValueError: If no OpenAIWrapper client is found.
-            ValueError: If multiple model clients are registered.
-            NotImplementedError: If update_model is not implemented for the underlying client.
-        """
-        if self.client is None:
-            raise ValueError("No OpenAIWrapper client is found.")
-        messages = self._oai_messages[agent]
-        update_model_stats = self.client.update_model(preference_data, messages, **kwargs)
-        self._update_model_metadata[agent].append(
-            {"messages": messages, "preference_data": preference_data, "update_stats": update_model_stats}
-        )
-        return update_model_stats
-
     def generate_oai_reply(
         self,
         messages: Optional[List[Dict]] = None,
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 35febbab7ad..bbe530cb31f 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -108,16 +108,16 @@ def get_usage(response: ModelClientResponseProtocol) -> Dict:
     def update_model(
         self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
     ) -> Dict[str, Any]:
-        """Optional method to learn from the preference data, if the model supports learning. Can be missing.
+        """Optional method to learn from the preference data, if the model supports learning. Can be omitted.
 
         Learn from the preference data.
 
         Args:
             preference_data: The preference data.
             inference_messages: The messages used for inference.
             **kwargs: other arguments.
 
         Returns:
-            Learning stats.
+            Dict of learning stats.
         """
         ...  # pragma: no cover
diff --git a/samples/tools/finetuning/README.md b/samples/tools/finetuning/README.md
new file mode 100644
index 00000000000..47793296a46
--- /dev/null
+++ b/samples/tools/finetuning/README.md
@@ -0,0 +1,81 @@
+# Tools for fine-tuning the local models that power agents
+
+This directory aims to contain tools for fine-tuning the local models that power agents.
+
+## Fine tune a custom model client
+
+AutoGen supports the use of custom models to power agents [see blog post here](https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models). This directory contains a tool to provide feedback to that model, that can be used to fine-tune the model.
+
+The creator of the Custom Model Client will have to decide what kind of data is going to be fed back and how it will be used to fine-tune the model. This tool is designed to be flexible and allow for a wide variety of feedback mechanisms.
+
+Custom Model Client will have to implement the method:
+
+```python
+def update_model(
+    self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
+) -> Dict[str, Any]:
+    """Optional method to learn from the preference data, if the model supports learning. Can be omitted.
+
+    Learn from the preference data.
+
+    Args:
+        preference_data: The preference data.
+        inference_messages: The messages used for inference.
+        **kwargs: other arguments.
+
+    Returns:
+        Dict of learning stats.
+ """ +``` + +The function provided in the file `conversable_agent_update_model.py` is called by passing these arguments: + +- the agent whose model is to be updated +- the preference data +- the agent who's conversation is being used to provide the inference messages + +The function will call the `update_model` method of the model client and will return a dictionary containing the update stats, messages, and preference data, like so: + +```python +{ + "update_stats": , + "messages": , + "preference_data": +} +``` + +An example of how to use this tool is shown below: + +```python +from finetuning.conversable_agent_update_model import update_model + +assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + human_input_mode="NEVER", + llm_config={ + "config_list": [], + }, +) + +assistant.register_model_client(model_client_cls=) + +user_proxy = UserProxyAgent( + "user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + code_execution_config=False, + llm_config=False, +) + +res = user_proxy.initiate_chat(assistant, message="the message") +response_content = res.summary + +# Evaluate the summary here and provide feedback. Pretending I am going to perform DPO on the response. + +# preference_data will be passed on as-is to the custom model client's update_model implementation +# so it should be in the format that the custom model client expects and is completely up to the author of the custom model client +preference_data = [("this is what the response should have been like", response_content)] + +update_model_stats = update_model(assistant, preference_data, user_proxy) +``` diff --git a/samples/tools/finetuning/finetuning/__init__.py b/samples/tools/finetuning/finetuning/__init__.py new file mode 100644 index 00000000000..8de66c2b813 --- /dev/null +++ b/samples/tools/finetuning/finetuning/__init__.py @@ -0,0 +1,3 @@ +from .conversable_agent_update_model import update_model + +__all__ = ["update_model"] diff --git a/samples/tools/finetuning/finetuning/conversable_agent_update_model.py b/samples/tools/finetuning/finetuning/conversable_agent_update_model.py new file mode 100644 index 00000000000..ba3e712f63e --- /dev/null +++ b/samples/tools/finetuning/finetuning/conversable_agent_update_model.py @@ -0,0 +1,33 @@ +from autogen import ConversableAgent, Agent +from typing import Any, Dict, List + + +def update_model( + update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs +) -> Dict[str, Any]: + """Update the model using the preference data and the conversation history. + + Args: + update_agent (ConversableAgent): the agent who's model will be updated. + preference_data (List[Dict]): a list of dictionaries containing the preference data. + other_agent (Agent): the agent who's conversation history will be used to update the model. + **kwargs: additional keyword arguments for the update model function. + + Returns: + Dict: a dictionary containing the update stats, messages, and preference data, like so: + { + "update_stats": update_model_stats, + "messages": messages, + "preference_data": preference_data + } + + Raises: + ValueError: If no OpenAIWrapper client is found. + ValueError: If multiple model clients are registered. + NotImplementedError: If update_model is not implemented for the underlying client. 
+ """ + if update_agent.client is None: + raise ValueError("No OpenAIWrapper client is found.") + messages = update_agent._oai_messages[other_agent] + update_model_stats = update_agent.client.update_model(preference_data, messages, **kwargs) + return {"update_stats": update_model_stats, "messages": messages, "preference_data": preference_data} diff --git a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py new file mode 100644 index 00000000000..135ff964a16 --- /dev/null +++ b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py @@ -0,0 +1,243 @@ +import pytest +from autogen import AssistantAgent, UserProxyAgent +from finetuning import update_model +from typing import Dict + +import sys + +sys.path.append("samples/tools/finetuning") + +try: + from openai import OpenAI +except ImportError: + skip = True +else: + skip = False + +TEST_CUSTOM_RESPONSE = "This is a custom response." +TEST_LOCAL_MODEL_NAME = "local_model_name" + + +def test_custom_model_client(): + TEST_LOSS = 0.5 + + class CustomModel: + def __init__(self, config: Dict): + self.model = config["model"] + self.model_name = config["model"] + + def create(self, params): + from types import SimpleNamespace + + response = SimpleNamespace() + # need to follow Client.ClientResponseProtocol + response.choices = [] + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = TEST_CUSTOM_RESPONSE + response.choices.append(choice) + response.model = self.model + return response + + def message_retrieval(self, response): + return [response.choices[0].message.content] + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + response.cost = 0 + return 0 + + @staticmethod + def get_usage(response) -> Dict: + return {} + + def update_model(self, preference_data, messages, **kwargs): + return {"loss": TEST_LOSS} + + config_list = [ + { + "model": TEST_LOCAL_MODEL_NAME, + "model_client_cls": "CustomModel", + } + ] + + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + human_input_mode="NEVER", + llm_config={ + "config_list": config_list, + }, + ) + assistant.register_model_client(model_client_cls=CustomModel) + user_proxy = UserProxyAgent( + "user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + code_execution_config=False, + llm_config=False, + ) + + res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True) + response_content = res.summary + + assert response_content == TEST_CUSTOM_RESPONSE + preference_data = [("this is what the response should have been like", response_content)] + update_model_stats = update_model(assistant, preference_data, user_proxy) + assert update_model_stats["update_stats"]["loss"] == TEST_LOSS + + +def test_update_model_without_client_raises_error(): + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + human_input_mode="NEVER", + max_consecutive_auto_reply=0, + llm_config=False, + code_execution_config=False, + ) + + user_proxy = UserProxyAgent( + "user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + code_execution_config=False, + llm_config=False, + ) + + user_proxy.initiate_chat(assistant, message="2+2=", silent=True) + with pytest.raises(ValueError): + update_model(assistant, [], user_proxy) + + +def test_custom_model_update_func_missing_raises_error(): + class CustomModel: + def __init__(self, config: Dict): + 
+            self.model = config["model"]
+            self.model_name = config["model"]
+
+        def create(self, params):
+            from types import SimpleNamespace
+
+            response = SimpleNamespace()
+            # need to follow Client.ClientResponseProtocol
+            response.choices = []
+            choice = SimpleNamespace()
+            choice.message = SimpleNamespace()
+            choice.message.content = TEST_CUSTOM_RESPONSE
+            response.choices.append(choice)
+            response.model = self.model
+            return response
+
+        def message_retrieval(self, response):
+            return [response.choices[0].message.content]
+
+        def cost(self, response) -> float:
+            """Calculate the cost of the response."""
+            response.cost = 0
+            return 0
+
+        @staticmethod
+        def get_usage(response) -> Dict:
+            return {}
+
+    config_list = [
+        {
+            "model": TEST_LOCAL_MODEL_NAME,
+            "model_client_cls": "CustomModel",
+        }
+    ]
+
+    assistant = AssistantAgent(
+        "assistant",
+        system_message="You are a helpful assistant.",
+        human_input_mode="NEVER",
+        llm_config={
+            "config_list": config_list,
+        },
+    )
+    assistant.register_model_client(model_client_cls=CustomModel)
+    user_proxy = UserProxyAgent(
+        "user_proxy",
+        human_input_mode="NEVER",
+        max_consecutive_auto_reply=1,
+        code_execution_config=False,
+        llm_config=False,
+    )
+
+    res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
+    response_content = res.summary
+
+    assert response_content == TEST_CUSTOM_RESPONSE
+
+    with pytest.raises(NotImplementedError):
+        update_model(assistant, [], user_proxy)
+
+
+def test_multiple_model_clients_raises_error():
+    class CustomModel:
+        def __init__(self, config: Dict):
+            self.model = config["model"]
+            self.model_name = config["model"]
+
+        def create(self, params):
+            from types import SimpleNamespace
+
+            response = SimpleNamespace()
+            # need to follow Client.ClientResponseProtocol
+            response.choices = []
+            choice = SimpleNamespace()
+            choice.message = SimpleNamespace()
+            choice.message.content = TEST_CUSTOM_RESPONSE
+            response.choices.append(choice)
+            response.model = self.model
+            return response
+
+        def message_retrieval(self, response):
+            return [response.choices[0].message.content]
+
+        def cost(self, response) -> float:
+            """Calculate the cost of the response."""
+            response.cost = 0
+            return 0
+
+        @staticmethod
+        def get_usage(response) -> Dict:
+            return {}
+
+        def update_model(self, preference_data, messages, **kwargs):
+            return {}
+
+    config_list = [
+        {
+            "model": TEST_LOCAL_MODEL_NAME,
+            "model_client_cls": "CustomModel",
+        },
+        {
+            "model": TEST_LOCAL_MODEL_NAME,
+            "model_client_cls": "CustomModel",
+        },
+    ]
+
+    assistant = AssistantAgent(
+        "assistant",
+        system_message="You are a helpful assistant.",
+        human_input_mode="NEVER",
+        llm_config={
+            "config_list": config_list,
+        },
+    )
+    assistant.register_model_client(model_client_cls=CustomModel)
+    assistant.register_model_client(model_client_cls=CustomModel)
+    user_proxy = UserProxyAgent(
+        "user_proxy",
+        human_input_mode="NEVER",
+        max_consecutive_auto_reply=1,
+        code_execution_config=False,
+        llm_config=False,
+    )
+
+    user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
+
+    with pytest.raises(ValueError):
+        update_model(assistant, [], user_proxy)

From b72c2ee3603fe9a5b4cdfdc41f0721d457d39b1b Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Tue, 5 Mar 2024 17:35:24 -0500
Subject: [PATCH 5/8] forgot install step

---
 .github/workflows/samples-tools-tests.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.github/workflows/samples-tools-tests.yml b/.github/workflows/samples-tools-tests.yml
index a28c3b544ef..3bbd2a723cd 100644
--- a/.github/workflows/samples-tools-tests.yml
+++ b/.github/workflows/samples-tools-tests.yml
@@ -33,6 +33,7 @@ jobs:
       - name: Install packages and dependencies for all tests
         run: |
           python -m pip install --upgrade pip wheel
+          pip install -e .
           pip install pytest
       - name: Set AUTOGEN_USE_DOCKER based on OS
        shell: bash

From f33bc009379ae05571279189aea0e5800bfd091b Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Tue, 5 Mar 2024 17:45:11 -0500
Subject: [PATCH 6/8] fix

---
 .../test_conversable_agent_update_model.py | 41 +++++--------
 1 file changed, 10 insertions(+), 31 deletions(-)

diff --git a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
index 135ff964a16..222b242f261 100644
--- a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
+++ b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
@@ -1,12 +1,13 @@
 import pytest
 from autogen import AssistantAgent, UserProxyAgent
-from finetuning import update_model
-from typing import Dict
 
 import sys
 
 sys.path.append("samples/tools/finetuning")
 
+from finetuning import update_model  # noqa: E402
+from typing import Dict  # noqa: E402
+
 try:
     from openai import OpenAI
 except ImportError:
@@ -54,20 +55,13 @@ def get_usage(response) -> Dict:
         def update_model(self, preference_data, messages, **kwargs):
             return {"loss": TEST_LOSS}
 
-    config_list = [
-        {
-            "model": TEST_LOCAL_MODEL_NAME,
-            "model_client_cls": "CustomModel",
-        }
-    ]
+    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"}]
 
     assistant = AssistantAgent(
         "assistant",
         system_message="You are a helpful assistant.",
         human_input_mode="NEVER",
-        llm_config={
-            "config_list": config_list,
-        },
+        llm_config={"config_list": config_list},
     )
     assistant.register_model_client(model_client_cls=CustomModel)
     user_proxy = UserProxyAgent(
@@ -141,20 +135,13 @@ def get_usage(response) -> Dict:
     def get_usage(response) -> Dict:
         return {}
 
-    config_list = [
-        {
-            "model": TEST_LOCAL_MODEL_NAME,
-            "model_client_cls": "CustomModel",
-        }
-    ]
+    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"}]
 
     assistant = AssistantAgent(
         "assistant",
         system_message="You are a helpful assistant.",
         human_input_mode="NEVER",
-        llm_config={
-            "config_list": config_list,
-        },
+        llm_config={"config_list": config_list},
     )
     assistant.register_model_client(model_client_cls=CustomModel)
     user_proxy = UserProxyAgent(
@@ -208,24 +195,16 @@ def update_model(self, preference_data, messages, **kwargs):
         return {}
 
     config_list = [
-        {
-            "model": TEST_LOCAL_MODEL_NAME,
-            "model_client_cls": "CustomModel",
-        },
-        {
-            "model": TEST_LOCAL_MODEL_NAME,
-            "model_client_cls": "CustomModel",
-        },
+        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"},
+        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"},
     ]
 
     assistant = AssistantAgent(
         "assistant",
         system_message="You are a helpful assistant.",
         human_input_mode="NEVER",
-        llm_config={
-            "config_list": config_list,
-        },
+        llm_config={"config_list": config_list},
     )
     assistant.register_model_client(model_client_cls=CustomModel)
     assistant.register_model_client(model_client_cls=CustomModel)
     user_proxy = UserProxyAgent(

From efbb1de970b7be1db9cb13b55b1f4695381ebe5e Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Wed, 6 Mar 2024 01:36:45 -0500
Subject: [PATCH 7/8] leave core lib unchanged and move everything to samples/tools

---
 autogen/oai/client.py                       | 44 ---------
 samples/tools/finetuning/README.md          | 20 ++--
 .../tools/finetuning/finetuning/__init__.py |  2 +-
 .../conversable_agent_update_model.py       | 33 -------
 .../finetuning/finetuning/update_model.py   | 93 +++++++++++++++
 .../test_conversable_agent_update_model.py  | 36 +++----
 6 files changed, 125 insertions(+), 103 deletions(-)
 delete mode 100644 samples/tools/finetuning/finetuning/conversable_agent_update_model.py
 create mode 100644 samples/tools/finetuning/finetuning/update_model.py

diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index bbe530cb31f..59e59815330 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -105,23 +105,6 @@ def get_usage(response: ModelClientResponseProtocol) -> Dict:
         """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
         ...  # pragma: no cover
 
-    def update_model(
-        self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
-    ) -> Dict[str, Any]:
-        """Optional method to learn from the preference data, if the model supports learning. Can be omitted.
-
-        Learn from the preference data.
-
-        Args:
-            preference_data: The preference data.
-            inference_messages: The messages used for inference.
-            **kwargs: other arguments.
-
-        Returns:
-            Dict of learning stats.
-        """
-        ...  # pragma: no cover
-
 
 class PlaceHolderClient:
     def __init__(self, config):
@@ -520,33 +503,6 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
         ]
         return params
 
-    def update_model(
-        self, preference_data: List[Any], inference_messages: List[Dict[str, Any]], **kwargs: Any
-    ) -> Dict[str, Any]:
-        """Learn from the preference data.
-
-        update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages.
-
-        Args:
-            preference_data: The preference data.
-            inference_messages: The messages used for inference.
-            **kwargs: other arguments.
-
-        Returns:
-            Learning stats.
-
-        Raises:
-            ValueError: If multiple model clients are registered.
-            NotImplementedError: If update_model is not implemented for the client.
-        """
-        if len(self._clients) != 1:
-            raise ValueError("update_model is not supported for multiple model clients.")
-        client = self._clients[0]
-        if hasattr(client, "update_model") and callable(getattr(client, "update_model")):
-            return client.update_model(preference_data, inference_messages, **kwargs)
-        else:
-            raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.")
-
     def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
         """Make a completion for a given config using available clients.
         Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
diff --git a/samples/tools/finetuning/README.md b/samples/tools/finetuning/README.md
index 47793296a46..2ef931c82bf 100644
--- a/samples/tools/finetuning/README.md
+++ b/samples/tools/finetuning/README.md
@@ -8,7 +8,7 @@ AutoGen supports the use of custom models to power agents [see blog post here](h
 
 The creator of the Custom Model Client will have to decide what kind of data is going to be fed back and how it will be used to fine-tune the model. This tool is designed to be flexible and allow for a wide variety of feedback mechanisms.
 
-Custom Model Client will have to implement the method:
+Custom Model Client will have to follow the protocol `UpdateableModelClient` defined in `update_model.py`, which is a subclass of `ModelClient` and adds the following method:
 
 ```python
 def update_model(
@@ -20,7 +20,7 @@ def update_model(
 
     Args:
         preference_data: The preference data.
-        inference_messages: The messages used for inference.
+        inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
         **kwargs: other arguments.
 
     Returns:
@@ -28,26 +28,32 @@ def update_model(
     """
 ```
 
-The function provided in the file `conversable_agent_update_model.py` is called by passing these arguments:
+The function provided in the file `update_model.py` is called by passing these arguments:
 
 - the agent whose model is to be updated
 - the preference data
-- the agent who's conversation is being used to provide the inference messages
+- the agent whose conversation is being used to provide the inference messages
 
-The function will call the `update_model` method of the model client and will return a dictionary containing the update stats, messages, and preference data, like so:
+The function will find the conversation thread that occurred between the "update agent" and the "other agent", and call the `update_model` method of the model client. It will return a dictionary containing the update stats, inference messages, and preference data:
 
 ```python
 {
     "update_stats": <the dictionary returned by the custom model client implementation>,
-    "messages": <the messages that were in the conversation>,
+    "inference_messages": <the messages that were used during inference>,
     "preference_data": <the preference data that was passed in>
 }
 ```
 
+**NOTES**:
+
+`inference_messages` will contain messages that were passed into the custom model client when `create` was called and a response was needed from the model. It is up to the author of the custom model client to decide which parts of the conversation are needed and how to use this data to fine-tune the model.
+
+If a conversation has been long-running before `update_model` is called, then the `inference_messages` will contain a conversation thread that was used for multiple inference steps. It is again up to the author of the custom model client to decide which parts of the conversation correspond to the preference data and how to use this data to fine-tune the model.
+
 An example of how to use this tool is shown below:
 
 ```python
-from finetuning.conversable_agent_update_model import update_model
+from finetuning.update_model import update_model
 
 assistant = AssistantAgent(
     "assistant",
diff --git a/samples/tools/finetuning/finetuning/__init__.py b/samples/tools/finetuning/finetuning/__init__.py
index 8de66c2b813..3f1f928d57c 100644
--- a/samples/tools/finetuning/finetuning/__init__.py
+++ b/samples/tools/finetuning/finetuning/__init__.py
@@ -1,3 +1,3 @@
-from .conversable_agent_update_model import update_model
+from .update_model import update_model
 
 __all__ = ["update_model"]
diff --git a/samples/tools/finetuning/finetuning/conversable_agent_update_model.py b/samples/tools/finetuning/finetuning/conversable_agent_update_model.py
deleted file mode 100644
index ba3e712f63e..00000000000
--- a/samples/tools/finetuning/finetuning/conversable_agent_update_model.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from autogen import ConversableAgent, Agent
-from typing import Any, Dict, List
-
-
-def update_model(
-    update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs
-) -> Dict[str, Any]:
-    """Update the model using the preference data and the conversation history.
-
-    Args:
-        update_agent (ConversableAgent): the agent who's model will be updated.
-        preference_data (List[Dict]): a list of dictionaries containing the preference data.
-        other_agent (Agent): the agent who's conversation history will be used to update the model.
-        **kwargs: additional keyword arguments for the update model function.
-
-    Returns:
-        Dict: a dictionary containing the update stats, messages, and preference data, like so:
-        {
-            "update_stats": update_model_stats,
-            "messages": messages,
-            "preference_data": preference_data
-        }
-
-    Raises:
-        ValueError: If no OpenAIWrapper client is found.
-        ValueError: If multiple model clients are registered.
-        NotImplementedError: If update_model is not implemented for the underlying client.
-    """
-    if update_agent.client is None:
-        raise ValueError("No OpenAIWrapper client is found.")
-    messages = update_agent._oai_messages[other_agent]
-    update_model_stats = update_agent.client.update_model(preference_data, messages, **kwargs)
-    return {"update_stats": update_model_stats, "messages": messages, "preference_data": preference_data}
diff --git a/samples/tools/finetuning/finetuning/update_model.py b/samples/tools/finetuning/finetuning/update_model.py
new file mode 100644
index 00000000000..df9fc6496b7
--- /dev/null
+++ b/samples/tools/finetuning/finetuning/update_model.py
@@ -0,0 +1,93 @@
+from autogen import ConversableAgent, Agent, OpenAIWrapper, ModelClient
+from typing import Any, Dict, List, Protocol
+
+
+class UpdateableModelClient(ModelClient, Protocol):
+    def update_model(
+        self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Optional method to learn from the preference data, if the model supports learning. Can be omitted.
+
+        Learn from the preference data.
+
+        Args:
+            preference_data: The preference data.
+            inference_messages: The messages used for inference.
+            **kwargs: other arguments.
+
+        Returns:
+            Dict of learning stats.
+        """
+        ...  # pragma: no cover
+
+
+def _client_wrapper_update_model(
+    oai_wrapper_client: OpenAIWrapper,
+    preference_data: List[Any],
+    inference_messages: List[Dict[str, Any]],
+    **kwargs: Any,
+) -> Dict[str, Any]:
+    """Learn from the preference data.
+
+    update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages.
+
+    Args:
+        oai_wrapper_client: The OpenAIWrapper client.
+        preference_data: The preference data.
+        inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
+        **kwargs: other arguments.
+
+    Returns:
+        Learning stats.
+
+    Raises:
+        ValueError: If multiple model clients are registered.
+        NotImplementedError: If update_model is not implemented for the client.
+    """
+
+    clients = oai_wrapper_client._clients
+
+    if len(clients) != 1:
+        raise ValueError("update_model is not supported for multiple model clients.")
+    client = clients[0]
+    if hasattr(client, "update_model") and callable(getattr(client, "update_model")):
+        return client.update_model(preference_data, inference_messages, **kwargs)
+    else:
+        raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.")
+
+
+def update_model(
+    update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs
+) -> Dict[str, Any]:
+    """Update the model using the preference data and the conversation history.
+
+    Args:
+        update_agent (ConversableAgent): the agent whose model will be updated.
+        preference_data (List[Dict]): a list of dictionaries containing the preference data.
+        other_agent (Agent): the agent whose conversation history will be used to update the model.
+        **kwargs: additional keyword arguments for the update model function.
+
+    Returns:
+        Dict: a dictionary containing the update stats, inference_messages, and preference data, like so:
+        {
+            "update_stats": update_model_stats,
+            "inference_messages": inference_messages,
+            "preference_data": preference_data
+        }
+
+    Raises:
+        ValueError: If no OpenAIWrapper client is found.
+        ValueError: If multiple model clients are registered.
+        NotImplementedError: If update_model is not implemented for the underlying client.
+    """
+    if update_agent.client is None:
+        raise ValueError("No OpenAIWrapper client is found.")
+    inference_messages = update_agent._oai_messages[other_agent]
+    update_model_stats = _client_wrapper_update_model(
+        update_agent.client, preference_data, inference_messages, **kwargs
+    )
+    return {
+        "update_stats": update_model_stats,
+        "inference_messages": inference_messages,
+        "preference_data": preference_data,
+    }
diff --git a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
index 222b242f261..042e98a6333 100644
--- a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
+++ b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
@@ -1,28 +1,25 @@
 import pytest
 from autogen import AssistantAgent, UserProxyAgent
-
 import sys
+import os
 
 sys.path.append("samples/tools/finetuning")
 
 from finetuning import update_model  # noqa: E402
 from typing import Dict  # noqa: E402
 
-try:
-    from openai import OpenAI
-except ImportError:
-    skip = True
-else:
-    skip = False
+sys.path.append("test")
+from conftest import skip_openai  # noqa: E402
 
 TEST_CUSTOM_RESPONSE = "This is a custom response."
 TEST_LOCAL_MODEL_NAME = "local_model_name"
 
 
+@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_custom_model_client():
     TEST_LOSS = 0.5
 
-    class CustomModel:
+    class UpdatableCustomModel:
         def __init__(self, config: Dict):
             self.model = config["model"]
             self.model_name = config["model"]
@@ -55,15 +52,15 @@ def test_custom_model_client():
         def update_model(self, preference_data, messages, **kwargs):
             return {"loss": TEST_LOSS}
 
-    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"}]
+    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}]
 
     assistant = AssistantAgent(
         "assistant",
         system_message="You are a helpful assistant.",
         human_input_mode="NEVER",
         llm_config={"config_list": config_list},
     )
-    assistant.register_model_client(model_client_cls=CustomModel)
+    assistant.register_model_client(model_client_cls=UpdatableCustomModel)
     user_proxy = UserProxyAgent(
         "user_proxy",
         human_input_mode="NEVER",
@@ -81,6 +78,7 @@ def test_custom_model_client():
     assert update_model_stats["update_stats"]["loss"] == TEST_LOSS
 
 
+@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_update_model_without_client_raises_error():
     assistant = AssistantAgent(
         "assistant",
@@ -104,8 +102,9 @@ def test_update_model_without_client_raises_error():
         update_model(assistant, [], user_proxy)
 
 
+@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_custom_model_update_func_missing_raises_error():
-    class CustomModel:
+    class UpdatableCustomModel:
         def __init__(self, config: Dict):
             self.model = config["model"]
             self.model_name = config["model"]
@@ -135,15 +134,15 @@ def get_usage(response) -> Dict:
     def get_usage(response) -> Dict:
         return {}
 
-    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"}]
+    config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}]
 
     assistant = AssistantAgent(
         "assistant",
         system_message="You are a helpful assistant.",
         human_input_mode="NEVER",
         llm_config={"config_list": config_list},
     )
-    assistant.register_model_client(model_client_cls=CustomModel)
+    assistant.register_model_client(model_client_cls=UpdatableCustomModel)
     user_proxy = UserProxyAgent(
         "user_proxy",
         human_input_mode="NEVER",
@@ -161,8 +160,9 @@ def test_custom_model_update_func_missing_raises_error():
         update_model(assistant, [], user_proxy)
 
 
+@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_multiple_model_clients_raises_error():
-    class CustomModel:
+    class UpdatableCustomModel:
         def __init__(self, config: Dict):
             self.model = config["model"]
             self.model_name = config["model"]
@@ -196,27 +196,27 @@ def update_model(self, preference_data, messages, **kwargs):
         return {}
 
     config_list = [
-        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"},
-        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "CustomModel"},
+        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"},
+        {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"},
     ]
 
     assistant = AssistantAgent(
         "assistant",
         system_message="You are a helpful assistant.",
         human_input_mode="NEVER",
         llm_config={"config_list": config_list},
     )
-    assistant.register_model_client(model_client_cls=CustomModel)
-    assistant.register_model_client(model_client_cls=CustomModel)
+    assistant.register_model_client(model_client_cls=UpdatableCustomModel)
+    assistant.register_model_client(model_client_cls=UpdatableCustomModel)
     user_proxy = UserProxyAgent(
         "user_proxy",
         human_input_mode="NEVER",
         max_consecutive_auto_reply=1,
         code_execution_config=False,
         llm_config=False,
     )
 
     user_proxy.initiate_chat(assistant, message="2+2=", silent=True)
 
     with pytest.raises(ValueError):
         update_model(assistant, [], user_proxy)

From 62c5bb9685bc98aeaf492064cad228387e843f5c Mon Sep 17 00:00:00 2001
From: olgavrou
Date: Thu, 7 Mar 2024 19:30:57 -0500
Subject: [PATCH 8/8] remove skip openai

---
 .../finetuning/tests/test_conversable_agent_update_model.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
index 042e98a6333..56f9474a268 100644
--- a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
+++ b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
@@ -1,7 +1,6 @@
 import pytest
 from autogen import AssistantAgent, UserProxyAgent
 import sys
-import os
 
 sys.path.append("samples/tools/finetuning")
 
@@ -9,13 +8,11 @@
 from typing import Dict  # noqa: E402
 
 sys.path.append("test")
-from conftest import skip_openai  # noqa: E402
 
 TEST_CUSTOM_RESPONSE = "This is a custom response."
 TEST_LOCAL_MODEL_NAME = "local_model_name"
 
 
-@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_custom_model_client():
     TEST_LOSS = 0.5
 
@@ -78,7 +75,6 @@ def update_model(self, preference_data, messages, **kwargs):
     assert update_model_stats["update_stats"]["loss"] == TEST_LOSS
 
 
-@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_update_model_without_client_raises_error():
     assistant = AssistantAgent(
         "assistant",
@@ -102,7 +98,6 @@ def test_update_model_without_client_raises_error():
         update_model(assistant, [], user_proxy)
 
 
-@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_custom_model_update_func_missing_raises_error():
     class UpdatableCustomModel:
         def __init__(self, config: Dict):
@@ -160,7 +155,6 @@ def get_usage(response) -> Dict:
         update_model(assistant, [], user_proxy)
 
 
-@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
 def test_multiple_model_clients_raises_error():
     class UpdatableCustomModel:
         def __init__(self, config: Dict):
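
Editor's note: a quick illustration of the API this series ends up with (not part of the patches themselves). Per the final README and tests, a custom model client opts into fine-tuning by implementing `update_model`; the tests stub it out with a fixed loss. The sketch below is minimal and assumes a DPO-style workflow where each `preference_data` entry is a `(preferred_response, actual_response)` pair, as in the README example. The `_training_buffer` attribute and the returned stats key are illustrative, not part of the PR, and the other protocol methods (`create`, `message_retrieval`, `cost`, `get_usage`) are omitted; they would be implemented as in the tests above.

```python
from typing import Any, Dict, List


class UpdatableCustomModel:
    """Sketch of the learning half of an updatable custom model client."""

    def __init__(self, config: Dict[str, Any]):
        self.model = config["model"]
        # Hypothetical buffer of prompt/chosen/rejected training records.
        self._training_buffer: List[Dict[str, str]] = []

    def update_model(
        self, preference_data: List[Any], inference_messages: List[Dict[str, Any]], **kwargs: Any
    ) -> Dict[str, Any]:
        # Treat the last user message of the conversation thread as the prompt;
        # which parts of inference_messages matter is left to the client author.
        prompt = next(
            (m["content"] for m in reversed(inference_messages) if m.get("role") == "user"),
            "",
        )
        for preferred, actual in preference_data:
            self._training_buffer.append({"prompt": prompt, "chosen": preferred, "rejected": actual})
        # A real client would run a fine-tuning step here and report real stats.
        return {"buffered_examples": len(self._training_buffer)}
```

With such a client registered, the README's `update_model(assistant, preference_data, user_proxy)` call would return a dict whose `"update_stats"` entry is `{"buffered_examples": 1}` after one round of feedback.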