diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index b31c8ce786d..dd9fcf7b1d0 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -129,6 +129,7 @@ def __init__(
         self._name = name
         # a dictionary of conversations, default value is list
         self._oai_messages = defaultdict(list)
+        self._update_model_metadata = defaultdict(list)
         self._oai_system_message = [{"content": system_message, "role": "system"}]
         self._description = description if description is not None else system_message
         self._is_termination_msg = (
@@ -1066,6 +1067,31 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser
                 flush=True,
             )
 
+    def update_model(self, preference_data: List[Dict[str, Any]], agent: Agent, **kwargs) -> Dict[str, Any]:
+        """Update the model using the preference data and the conversation history.
+
+        Args:
+            preference_data (List[Dict]): a list of dictionaries containing the preference data.
+            agent (Agent): the agent to update the model.
+            **kwargs: additional keyword arguments for the update model function.
+
+        Returns:
+            Dict: a dictionary containing the update model statistics.
+
+        Raises:
+            ValueError: If no OpenAIWrapper client is found.
+            ValueError: If multiple model clients are registered.
+            NotImplementedError: If update_model is not implemented for the underlying client.
+        """
+        if self.client is None:
+            raise ValueError("No OpenAIWrapper client is found.")
+        messages = self._oai_messages[agent]
+        update_model_stats = self.client.update_model(preference_data, messages, **kwargs)
+        self._update_model_metadata[agent].append(
+            {"messages": messages, "preference_data": preference_data, "update_stats": update_model_stats}
+        )
+        return update_model_stats
+
     def generate_oai_reply(
         self,
         messages: Optional[List[Dict]] = None,
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index 59e59815330..b27e07964fe 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -83,8 +83,7 @@ class Message(Protocol):
         choices: List[Choice]
         model: str
 
-    def create(self, **params: Any) -> ModelClientResponseProtocol:
-        ...  # pragma: no cover
+    def create(self, **params: Any) -> ModelClientResponseProtocol: ...  # pragma: no cover
 
     def message_retrieval(
         self, response: ModelClientResponseProtocol
@@ -97,14 +96,30 @@ def message_retrieval(
         """
         ...  # pragma: no cover
 
-    def cost(self, response: ModelClientResponseProtocol) -> float:
-        ...  # pragma: no cover
+    def cost(self, response: ModelClientResponseProtocol) -> float: ...  # pragma: no cover
 
     @staticmethod
     def get_usage(response: ModelClientResponseProtocol) -> Dict:
         """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
         ...  # pragma: no cover
 
+    def update_model(
+        self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Optional method to learn from the preference data, if the model supports learning. Can be missing.
+
+        Learn from the preference data.
+
+        Args:
+            preference_data: The preference data.
+            inference_messages: The messages used for inference.
+            **kwargs: other arguments.
+
+        Returns:
+            Learning stats.
+        """
+        ...  # pragma: no cover
+
 
 class PlaceHolderClient:
     def __init__(self, config):
@@ -503,6 +518,33 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
         ]
         return params
 
+    def update_model(
+        self, preference_data: List[Any], inference_messages: List[Dict[str, Any]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Learn from the preference data.
+
+        update_model is not supported for multiple model clients as it would be ambiguous which client was responsible for the inference messages.
+
+        Args:
+            preference_data: The preference data.
+            inference_messages: The messages used for inference.
+            **kwargs: other arguments.
+
+        Returns:
+            Learning stats.
+
+        Raises:
+            ValueError: If multiple model clients are registered.
+            NotImplementedError: If update_model is not implemented for the client.
+        """
+        if len(self._clients) != 1:
+            raise ValueError("update_model is not supported for multiple model clients.")
+        client = self._clients[0]
+        if hasattr(client, "update_model") and callable(getattr(client, "update_model")):
+            return client.update_model(preference_data, inference_messages, **kwargs)
+        else:
+            raise NotImplementedError(f"update_model is not implemented for {client.__class__.__name__}.")
+
     def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
         """Make a completion for a given config using available clients.
 
         Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.