Skip to content

Commit

Permalink
Ability to update_model on conversable agents
Browse files Browse the repository at this point in the history
  • Loading branch information
olgavrou committed Feb 26, 2024
1 parent 8ec1c3e commit 3cbcc45
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 4 deletions.
26 changes: 26 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
self._name = name
# a dictionary of conversations, default value is list
self._oai_messages = defaultdict(list)
self._update_model_metadata = defaultdict(list)
self._oai_system_message = [{"content": system_message, "role": "system"}]
self._description = description if description is not None else system_message
self._is_termination_msg = (
Expand Down Expand Up @@ -1066,6 +1067,31 @@ def clear_history(self, recipient: Optional[Agent] = None, nr_messages_to_preser
flush=True,
)

def update_model(self, preference_data: List[Dict[str, Any]], agent: Agent, **kwargs) -> Dict[str, Any]:
    """Update the underlying model from preference data and the conversation held with `agent`.

    Args:
        preference_data (List[Dict]): a list of dictionaries containing the preference data.
        agent (Agent): the agent whose conversation history is fed into the update.
        **kwargs: additional keyword arguments forwarded to the client's update function.

    Returns:
        Dict: statistics reported by the model update step.

    Raises:
        ValueError: If no OpenAIWrapper client is found.
        ValueError: If multiple model clients are registered.
        NotImplementedError: If update_model is not implemented for the underlying client.
    """
    if self.client is None:
        raise ValueError("No OpenAIWrapper client is found.")
    conversation = self._oai_messages[agent]
    # Delegate the actual learning step to the OpenAIWrapper client.
    stats = self.client.update_model(preference_data, conversation, **kwargs)
    record = {
        "messages": conversation,
        "preference_data": preference_data,
        "update_stats": stats,
    }
    # Keep a per-agent trail of every update performed, alongside the data used.
    self._update_model_metadata[agent].append(record)
    return stats

def generate_oai_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down
50 changes: 46 additions & 4 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ class Message(Protocol):
choices: List[Choice]
model: str

def create(self, **params: Any) -> ModelClientResponseProtocol:
... # pragma: no cover
def create(self, **params: Any) -> ModelClientResponseProtocol: ... # pragma: no cover

def message_retrieval(
self, response: ModelClientResponseProtocol
Expand All @@ -97,14 +96,30 @@ def message_retrieval(
"""
... # pragma: no cover

def cost(self, response: ModelClientResponseProtocol) -> float:
... # pragma: no cover
def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover

@staticmethod
def get_usage(response: ModelClientResponseProtocol) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
... # pragma: no cover

def update_model(
    self, preference_data: List[Dict[str, Any]], inference_messages: List[Dict[str, Any]], **kwargs: Any
) -> Dict[str, Any]:
    """Optional method to learn from preference data, if the model supports learning.

    Implementations may omit this method entirely; callers probe for its
    presence (via hasattr) before invoking it.

    Args:
        preference_data: The preference data to learn from.
        inference_messages: The messages that were used for inference.
        **kwargs: Additional implementation-specific arguments.

    Returns:
        A dictionary of learning statistics.
    """
    ... # pragma: no cover


class PlaceHolderClient:
def __init__(self, config):
Expand Down Expand Up @@ -503,6 +518,33 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs:
]
return params

def update_model(
    self, preference_data: List[Any], inference_messages: List[Dict[str, Any]], **kwargs: Any
) -> Dict[str, Any]:
    """Learn from the preference data.

    update_model is not supported for multiple model clients as it would be
    ambiguous which client was responsible for the inference messages.

    Args:
        preference_data: The preference data.
        inference_messages: The messages used for inference.
        **kwargs: other arguments.

    Returns:
        Learning stats.

    Raises:
        ValueError: If multiple model clients are registered.
        NotImplementedError: If update_model is not implemented for the client.
    """
    # Also rejects zero registered clients, matching the exact-one requirement.
    if len(self._clients) != 1:
        raise ValueError("update_model is not supported for multiple model clients.")
    (sole_client,) = self._clients
    # update_model is an optional protocol method; probe before calling.
    update_fn = getattr(sole_client, "update_model", None)
    if not callable(update_fn):
        raise NotImplementedError(f"update_model is not implemented for {sole_client.__class__.__name__}.")
    return update_fn(preference_data, inference_messages, **kwargs)

def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol:
"""Make a completion for a given config using available clients.
Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs.
Expand Down

0 comments on commit 3cbcc45

Please sign in to comment.