diff --git a/.github/workflows/samples-tools-tests.yml b/.github/workflows/samples-tools-tests.yml
new file mode 100644
index 00000000000..3bbd2a723cd
--- /dev/null
+++ b/.github/workflows/samples-tools-tests.yml
@@ -0,0 +1,46 @@
+# This workflow will install Python dependencies and run tests with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: SamplesToolsTests
+
+on:
+  pull_request:
+    branches: ["main"]
+    paths:
+      - "autogen/**"
+      - "samples/tools/**"
+      - ".github/workflows/samples-tools-tests.yml"
+      - "setup.py"
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+permissions: {}
+jobs:
+  SamplesToolsFineTuningTests:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+        python-version: ["3.9", "3.10", "3.11"]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install packages and dependencies for all tests
+        run: |
+          python -m pip install --upgrade pip wheel
+          pip install -e .
+          pip install pytest
+      - name: Set AUTOGEN_USE_DOCKER based on OS
+        shell: bash
+        run: |
+          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
+            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+          fi
+      - name: Test finetuning tools
+        run: |
+          pytest samples/tools/finetuning/tests/
diff --git a/samples/tools/finetuning/README.md b/samples/tools/finetuning/README.md
new file mode 100644
index 00000000000..2ef931c82bf
--- /dev/null
+++ b/samples/tools/finetuning/README.md
@@ -0,0 +1,87 @@
+# Tools for fine-tuning the local models that power agents
+
+This directory contains tools for fine-tuning the local models that power agents.
+
+## Fine-tune a custom model client
+
+AutoGen supports the use of custom models to power agents ([see the blog post here](https://microsoft.github.io/autogen/blog/2024/01/26/Custom-Models)). This directory contains a tool for providing feedback to such a model, which can then be used to fine-tune it.
+
+The creator of the custom model client has to decide what kind of data will be fed back and how it will be used to fine-tune the model. The tool is designed to be flexible and to allow for a wide variety of feedback mechanisms.
+
+The custom model client will have to follow the `UpdateableModelClient` protocol defined in `update_model.py`, which is a subclass of `ModelClient` and adds the following method:
+
+```python
+def update_model(
+    self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
+) -> Dict[str, Any]:
+    """Optional method to learn from the preference data, if the model supports learning. Can be omitted.
+
+    Args:
+        preference_data: The preference data.
+        inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
+        **kwargs: other arguments.
+
+    Returns:
+        Dict of learning stats.
+    """
+```
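+For illustration, a minimal client satisfying this protocol might look like the sketch below. It mirrors the stub clients used in this directory's tests; the class name, canned response, and fixed loss are hypothetical placeholders, and a real client would load a local model and run actual inference and fine-tuning steps:
+
+```python
+from types import SimpleNamespace
+from typing import Any, Dict, List
+
+
+class ExampleUpdatableModelClient:  # hypothetical name, not part of this tool
+    def __init__(self, config: Dict, **kwargs: Any):
+        self.model = config["model"]
+        # a real client would load the local model and tokenizer here
+
+    def create(self, params: Dict):
+        # build a response that follows the client response protocol;
+        # a real client would run inference on the local model instead
+        response = SimpleNamespace()
+        choice = SimpleNamespace()
+        choice.message = SimpleNamespace()
+        choice.message.content = "a canned response"
+        response.choices = [choice]
+        response.model = self.model
+        return response
+
+    def message_retrieval(self, response) -> List[str]:
+        return [choice.message.content for choice in response.choices]
+
+    def cost(self, response) -> float:
+        response.cost = 0
+        return 0
+
+    @staticmethod
+    def get_usage(response) -> Dict:
+        return {}
+
+    def update_model(
+        self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        # a real client would run a fine-tuning step (e.g. a DPO update) over
+        # the preference data paired with the inference messages
+        return {"loss": 0.0}
+```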
+The function provided in the file `update_model.py` is called by passing these arguments:
+
+- the agent whose model is to be updated
+- the preference data
+- the agent whose conversation is being used to provide the inference messages
+
+The function will find the conversation thread that occurred between the "update agent" and the "other agent", and call the `update_model` method of the model client. It will return a dictionary containing the update stats, the inference messages, and the preference data:
+
+```python
+{
+    "update_stats": <the dict returned by the custom model client's update_model>,
+    "inference_messages": <the messages that were used during inference>,
+    "preference_data": <the preference data that was passed in>
+}
+```
+
+**NOTES**:
+
+`inference_messages` will contain the messages that were passed into the custom model client when `create` was called and a response was needed from the model. It is up to the author of the custom model client to decide which parts of the conversation are needed and how to use this data to fine-tune the model.
+
+If a conversation has been long-running before `update_model` is called, then `inference_messages` will contain a conversation thread that was used for multiple inference steps. It is again up to the author of the custom model client to decide which parts of the conversation correspond to the preference data and how to use this data to fine-tune the model.
+
+An example of how to use this tool is shown below:
+
+```python
+from autogen import AssistantAgent, UserProxyAgent
+
+from finetuning.update_model import update_model
+
+assistant = AssistantAgent(
+    "assistant",
+    system_message="You are a helpful assistant.",
+    human_input_mode="NEVER",
+    llm_config={
+        "config_list": [<the config for your custom model client>],
+    },
+)
+
+assistant.register_model_client(model_client_cls=<your custom model client class>)
+
+user_proxy = UserProxyAgent(
+    "user_proxy",
+    human_input_mode="NEVER",
+    max_consecutive_auto_reply=1,
+    code_execution_config=False,
+    llm_config=False,
+)
+
+res = user_proxy.initiate_chat(assistant, message="the message")
+response_content = res.summary
+
+# Evaluate the summary here and provide feedback. Here we pretend that we are
+# going to perform DPO on the response.
+
+# preference_data is passed on as-is to the custom model client's update_model
+# implementation, so its format is entirely up to the author of the custom model client.
+preference_data = [("this is what the response should have been like", response_content)]
+
+update_model_stats = update_model(assistant, preference_data, user_proxy)
+```
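+
+The value of `"update_stats"` is whatever the custom model client's `update_model` returned. Assuming the client reports a `loss` entry, as in the hypothetical sketch earlier, the caller could log it like this:
+
+```python
+# update_stats holds the dict returned by the custom model client's update_model
+print(f"fine-tuning loss: {update_model_stats['update_stats']['loss']}")
+```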
diff --git a/samples/tools/finetuning/finetuning/__init__.py b/samples/tools/finetuning/finetuning/__init__.py
new file mode 100644
index 00000000000..3f1f928d57c
--- /dev/null
+++ b/samples/tools/finetuning/finetuning/__init__.py
@@ -0,0 +1,3 @@
+from .update_model import update_model
+
+__all__ = ["update_model"]
diff --git a/samples/tools/finetuning/finetuning/update_model.py b/samples/tools/finetuning/finetuning/update_model.py
new file mode 100644
index 00000000000..df9fc6496b7
--- /dev/null
+++ b/samples/tools/finetuning/finetuning/update_model.py
@@ -0,0 +1,93 @@
+from autogen import ConversableAgent, Agent, OpenAIWrapper, ModelClient
+from typing import Any, Dict, List, Protocol
+
+
+class UpdateableModelClient(ModelClient, Protocol):
+    def update_model(
+        self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Optional method to learn from the preference data, if the model supports learning. Can be omitted.
+
+        Args:
+            preference_data: The preference data.
+            inference_messages: The messages used for inference.
+            **kwargs: other arguments.
+
+        Returns:
+            Dict of learning stats.
+        """
+        ...  # pragma: no cover
+
+
+def _client_wrapper_update_model(
+    oai_wrapper_client: OpenAIWrapper,
+    preference_data: List[Any],
+    inference_messages: List[Dict[str, Any]],
+    **kwargs: Any,
+) -> Dict[str, Any]:
+    """Learn from the preference data.
+
+    update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages.
+
+    Args:
+        oai_wrapper_client: The OpenAIWrapper client.
+        preference_data: The preference data.
+        inference_messages: The messages that were used during inference between the agent that is being updated and another agent.
+        **kwargs: other arguments.
+
+    Returns:
+        Learning stats.
+
+    Raises:
+        ValueError: If multiple model clients are registered.
+        NotImplementedError: If update_model is not implemented for the client.
+    """
+
+    clients = oai_wrapper_client._clients
+
+    if len(clients) != 1:
+        raise ValueError("update_model is not supported for multiple model clients.")
+    client = clients[0]
+    if hasattr(client, "update_model") and callable(getattr(client, "update_model")):
+        return client.update_model(preference_data, inference_messages, **kwargs)
+    else:
+        raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.")
+
+
+def update_model(
+    update_agent: ConversableAgent, preference_data: List[Dict[str, Any]], other_agent: Agent, **kwargs: Any
+) -> Dict[str, Any]:
+    """Update the model using the preference data and the conversation history.
+
+    Args:
+        update_agent (ConversableAgent): the agent whose model will be updated.
+        preference_data (List[Dict]): a list of dictionaries containing the preference data.
+        other_agent (Agent): the agent whose conversation history will be used to update the model.
+        **kwargs: additional keyword arguments for the update model function.
+
+    Returns:
+        Dict: a dictionary containing the update stats, inference_messages, and preference data, like so:
+        {
+            "update_stats": update_model_stats,
+            "inference_messages": inference_messages,
+            "preference_data": preference_data
+        }
+
+    Raises:
+        ValueError: If no OpenAIWrapper client is found.
+        ValueError: If multiple model clients are registered.
+        NotImplementedError: If update_model is not implemented for the underlying client.
+    """
+    if update_agent.client is None:
+        raise ValueError("No OpenAIWrapper client is found.")
+    inference_messages = update_agent._oai_messages[other_agent]
+    update_model_stats = _client_wrapper_update_model(
+        update_agent.client, preference_data, inference_messages, **kwargs
+    )
+    return {
+        "update_stats": update_model_stats,
+        "inference_messages": inference_messages,
+        "preference_data": preference_data,
+    }
diff --git a/samples/tools/finetuning/tests/test_conversable_agent_update_model.py b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
new file mode 100644
index 00000000000..56f9474a268
--- /dev/null
+++ b/samples/tools/finetuning/tests/test_conversable_agent_update_model.py
@@ -0,0 +1,216 @@
+import pytest
+from autogen import AssistantAgent, UserProxyAgent
+import sys
+
+sys.path.append("samples/tools/finetuning")
+
+from finetuning import update_model  # noqa: E402
+from typing import Dict  # noqa: E402
+
+sys.path.append("test")
+
+TEST_CUSTOM_RESPONSE = "This is a custom response."
+TEST_LOCAL_MODEL_NAME = "local_model_name" + + +def test_custom_model_client(): + TEST_LOSS = 0.5 + + class UpdatableCustomModel: + def __init__(self, config: Dict): + self.model = config["model"] + self.model_name = config["model"] + + def create(self, params): + from types import SimpleNamespace + + response = SimpleNamespace() + # need to follow Client.ClientResponseProtocol + response.choices = [] + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = TEST_CUSTOM_RESPONSE + response.choices.append(choice) + response.model = self.model + return response + + def message_retrieval(self, response): + return [response.choices[0].message.content] + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + response.cost = 0 + return 0 + + @staticmethod + def get_usage(response) -> Dict: + return {} + + def update_model(self, preference_data, messages, **kwargs): + return {"loss": TEST_LOSS} + + config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}] + + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + human_input_mode="NEVER", + llm_config={"config_list": config_list}, + ) + assistant.register_model_client(model_client_cls=UpdatableCustomModel) + user_proxy = UserProxyAgent( + "user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + code_execution_config=False, + llm_config=False, + ) + + res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True) + response_content = res.summary + + assert response_content == TEST_CUSTOM_RESPONSE + preference_data = [("this is what the response should have been like", response_content)] + update_model_stats = update_model(assistant, preference_data, user_proxy) + assert update_model_stats["update_stats"]["loss"] == TEST_LOSS + + +def test_update_model_without_client_raises_error(): + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + human_input_mode="NEVER", + max_consecutive_auto_reply=0, + llm_config=False, + code_execution_config=False, + ) + + user_proxy = UserProxyAgent( + "user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + code_execution_config=False, + llm_config=False, + ) + + user_proxy.initiate_chat(assistant, message="2+2=", silent=True) + with pytest.raises(ValueError): + update_model(assistant, [], user_proxy) + + +def test_custom_model_update_func_missing_raises_error(): + class UpdatableCustomModel: + def __init__(self, config: Dict): + self.model = config["model"] + self.model_name = config["model"] + + def create(self, params): + from types import SimpleNamespace + + response = SimpleNamespace() + # need to follow Client.ClientResponseProtocol + response.choices = [] + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = TEST_CUSTOM_RESPONSE + response.choices.append(choice) + response.model = self.model + return response + + def message_retrieval(self, response): + return [response.choices[0].message.content] + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + response.cost = 0 + return 0 + + @staticmethod + def get_usage(response) -> Dict: + return {} + + config_list = [{"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}] + + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + human_input_mode="NEVER", + llm_config={"config_list": config_list}, + ) + 
assistant.register_model_client(model_client_cls=UpdatableCustomModel) + user_proxy = UserProxyAgent( + "user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + code_execution_config=False, + llm_config=False, + ) + + res = user_proxy.initiate_chat(assistant, message="2+2=", silent=True) + response_content = res.summary + + assert response_content == TEST_CUSTOM_RESPONSE + + with pytest.raises(NotImplementedError): + update_model(assistant, [], user_proxy) + + +def test_multiple_model_clients_raises_error(): + class UpdatableCustomModel: + def __init__(self, config: Dict): + self.model = config["model"] + self.model_name = config["model"] + + def create(self, params): + from types import SimpleNamespace + + response = SimpleNamespace() + # need to follow Client.ClientResponseProtocol + response.choices = [] + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = TEST_CUSTOM_RESPONSE + response.choices.append(choice) + response.model = self.model + return response + + def message_retrieval(self, response): + return [response.choices[0].message.content] + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + response.cost = 0 + return 0 + + @staticmethod + def get_usage(response) -> Dict: + return {} + + def update_model(self, preference_data, messages, **kwargs): + return {} + + config_list = [ + {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}, + {"model": TEST_LOCAL_MODEL_NAME, "model_client_cls": "UpdatableCustomModel"}, + ] + + assistant = AssistantAgent( + "assistant", + system_message="You are a helpful assistant.", + human_input_mode="NEVER", + llm_config={"config_list": config_list}, + ) + assistant.register_model_client(model_client_cls=UpdatableCustomModel) + assistant.register_model_client(model_client_cls=UpdatableCustomModel) + user_proxy = UserProxyAgent( + "user_proxy", + human_input_mode="NEVER", + max_consecutive_auto_reply=1, + code_execution_config=False, + llm_config=False, + ) + + user_proxy.initiate_chat(assistant, message="2+2=", silent=True) + + with pytest.raises(ValueError): + update_model(assistant, [], user_proxy)