diff --git a/.github/workflows/contrib-tests.yml b/.github/workflows/contrib-tests.yml
index 25bcef05b71d..27a616b77190 100644
--- a/.github/workflows/contrib-tests.yml
+++ b/.github/workflows/contrib-tests.yml
@@ -45,15 +45,14 @@ jobs:
- name: Install packages and dependencies for RetrieveChat
run: |
pip install -e .[retrievechat]
- pip uninstall -y openai
- name: Test RetrieveChat
run: |
- pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py
+ pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install coverage>=5.3
- coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib
+ coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
@@ -82,16 +81,15 @@ jobs:
- name: Install packages and dependencies for Compression
run: |
pip install -e .
- pip uninstall -y openai
- name: Test Compression
if: matrix.python-version != '3.10' # diversify the python versions
run: |
- pytest test/agentchat/contrib/test_compressible_agent.py
+ pytest test/agentchat/contrib/test_compressible_agent.py --skip-openai
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/test_compressible_agent.py
+ coverage run -a -m pytest test/agentchat/contrib/test_compressible_agent.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
@@ -120,16 +118,15 @@ jobs:
- name: Install packages and dependencies for GPTAssistantAgent
run: |
pip install -e .
- pip uninstall -y openai
- name: Test GPTAssistantAgent
if: matrix.python-version != '3.11' # diversify the python versions
run: |
- pytest test/agentchat/contrib/test_gpt_assistant.py
+ pytest test/agentchat/contrib/test_gpt_assistant.py --skip-openai
- name: Coverage
if: matrix.python-version == '3.11'
run: |
pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/test_gpt_assistant.py
+ coverage run -a -m pytest test/agentchat/contrib/test_gpt_assistant.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.11'
@@ -158,16 +155,15 @@ jobs:
- name: Install packages and dependencies for Teachability
run: |
pip install -e .[teachable]
- pip uninstall -y openai
- - name: Test Teachability
+ - name: Test TeachableAgent
if: matrix.python-version != '3.9' # diversify the python versions
run: |
- pytest test/agentchat/contrib/test_teachable_agent.py
+ pytest test/agentchat/contrib/test_teachable_agent.py --skip-openai
- name: Coverage
if: matrix.python-version == '3.9'
run: |
pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/test_teachable_agent.py
+ coverage run -a -m pytest test/agentchat/contrib/test_teachable_agent.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.9'
@@ -196,15 +192,14 @@ jobs:
- name: Install packages and dependencies for LMM
run: |
pip install -e .[lmm]
- pip uninstall -y openai
- name: Test LMM and LLaVA
run: |
- pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
+ pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py --skip-openai
- name: Coverage
if: matrix.python-version == '3.10'
run: |
pip install coverage>=5.3
- coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py
+ coverage run -a -m pytest test/agentchat/contrib/test_img_utils.py test/agentchat/contrib/test_lmm.py test/agentchat/contrib/test_llava.py --skip-openai
coverage xml
- name: Upload coverage to Codecov
if: matrix.python-version == '3.10'
diff --git a/autogen/_pydantic.py b/autogen/_pydantic.py
index ef0cad66e743..89dbc4fd2910 100644
--- a/autogen/_pydantic.py
+++ b/autogen/_pydantic.py
@@ -4,7 +4,7 @@
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import get_origin
-__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema")
+__all__ = ("JsonSchemaValue", "model_dump", "model_dump_json", "type2schema", "evaluate_forwardref")
PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.")
diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index ed548af58105..1b08ade80ebf 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -855,24 +855,23 @@ def generate_tool_calls_reply(
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
- if "tool_calls" in message and message["tool_calls"]:
- tool_calls = message["tool_calls"]
- tool_returns = []
- for tool_call in tool_calls:
- id = tool_call["id"]
- function_call = tool_call.get("function", {})
- func = self._function_map.get(function_call.get("name", None), None)
- if asyncio.coroutines.iscoroutinefunction(func):
- continue
- _, func_return = self.execute_function(function_call)
- tool_returns.append(
- {
- "tool_call_id": id,
- "role": "tool",
- "name": func_return.get("name", ""),
- "content": func_return.get("content", ""),
- }
- )
+ tool_returns = []
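+        # execute each tool call synchronously; coroutine functions are skipped here
+        # and left for a_generate_tool_calls_reply to await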
+ for tool_call in message.get("tool_calls", []):
+ id = tool_call["id"]
+ function_call = tool_call.get("function", {})
+ func = self._function_map.get(function_call.get("name", None), None)
+ if asyncio.coroutines.iscoroutinefunction(func):
+ continue
+ _, func_return = self.execute_function(function_call)
+ tool_returns.append(
+ {
+ "tool_call_id": id,
+ "role": "tool",
+ "name": func_return.get("name", ""),
+ "content": func_return.get("content", ""),
+ }
+ )
+ if tool_returns:
return True, {
"role": "tool",
"tool_responses": tool_returns,
@@ -908,14 +907,12 @@ async def a_generate_tool_calls_reply(
func = self._function_map.get(tool_call.get("function", {}).get("name", None), None)
if func and asyncio.coroutines.iscoroutinefunction(func):
async_tool_calls.append(self._a_execute_tool_call(tool_call))
- if len(async_tool_calls) > 0:
+ if async_tool_calls:
tool_returns = await asyncio.gather(*async_tool_calls)
return True, {
"role": "tool",
"tool_responses": tool_returns,
- "content": "\n\n".join(
- [self._str_for_tool_response(tool_return["content"]) for tool_return in tool_returns]
- ),
+ "content": "\n\n".join([self._str_for_tool_response(tool_return) for tool_return in tool_returns]),
}
return False, None
@@ -1019,7 +1016,9 @@ def check_termination_and_human_reply(
]
)
- response = {"role": "user", "content": reply, "tool_responses": tool_returns}
+ response = {"role": "user", "content": reply}
+ if tool_returns:
+ response["tool_responses"] = tool_returns
return True, response
@@ -1127,7 +1126,10 @@ async def a_check_termination_and_human_reply(
]
)
- response = {"role": "user", "content": reply, "tool_responses": tool_returns}
+ response = {"role": "user", "content": reply}
+ if tool_returns:
+ response["tool_responses"] = tool_returns
+
return True, response
# increment the consecutive_auto_reply_counter
diff --git a/autogen/function_utils.py b/autogen/function_utils.py
index f289d9e4d2e6..efb50060b391 100644
--- a/autogen/function_utils.py
+++ b/autogen/function_utils.py
@@ -73,7 +73,7 @@ def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
return get_typed_annotation(annotation, globalns)
-def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type, str], Type]]:
+def get_param_annotations(typed_signature: inspect.Signature) -> Dict[int, Union[Annotated[Type[Any], str], Type[Any]]]:
"""Get the type annotations of the parameters of a function
Args:
@@ -111,7 +111,7 @@ class ToolFunction(BaseModel):
def get_parameter_json_schema(
- k: str, v: Union[Annotated[Type, str], Type], default_values: Dict[str, Any]
+ k: str, v: Union[Annotated[Type[Any], str], Type[Any]], default_values: Dict[str, Any]
) -> JsonSchemaValue:
"""Get a JSON schema for a parameter as defined by the OpenAI API
@@ -124,10 +124,14 @@ def get_parameter_json_schema(
        A Pydantic model for the parameter
"""
- def type2description(k: str, v: Union[Annotated[Type, str], Type]) -> str:
+ def type2description(k: str, v: Union[Annotated[Type[Any], str], Type[Any]]) -> str:
# handles Annotated
if hasattr(v, "__metadata__"):
- return v.__metadata__[0]
+ retval = v.__metadata__[0]
+ if isinstance(retval, str):
+ return retval
+ else:
+ raise ValueError(f"Invalid description {retval} for parameter {k}, should be a string.")
else:
return k
@@ -166,7 +170,9 @@ def get_default_values(typed_signature: inspect.Signature) -> Dict[str, Any]:
def get_parameters(
- required: List[str], param_annotations: Dict[str, Union[Annotated[Type, str], Type]], default_values: Dict[str, Any]
+ required: List[str],
+ param_annotations: Dict[str, Union[Annotated[Type[Any], str], Type[Any]]],
+ default_values: Dict[str, Any],
) -> Parameters:
"""Get the parameters of a function as defined by the OpenAI API
@@ -278,7 +284,7 @@ def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Paramet
return model_dump(function)
-def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type], BaseModel]]:
+def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[T, Type[Any]], BaseModel]]:
"""Get a function to load a parameter if it is a Pydantic model
Args:
@@ -319,7 +325,7 @@ def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
# a function that loads the parameters before calling the original function
@functools.wraps(func)
- def load_parameters_if_needed(*args, **kwargs):
+ def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
# load the BaseModels if needed
for k, f in kwargs_mapping.items():
kwargs[k] = f(kwargs[k], param_annotations[k])
@@ -327,7 +333,19 @@ def load_parameters_if_needed(*args, **kwargs):
# call the original function
return func(*args, **kwargs)
- return load_parameters_if_needed
+ @functools.wraps(func)
+ async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
+ # load the BaseModels if needed
+ for k, f in kwargs_mapping.items():
+ kwargs[k] = f(kwargs[k], param_annotations[k])
+
+ # call the original function
+ return await func(*args, **kwargs)
+
+ if inspect.iscoroutinefunction(func):
+ return _a_load_parameters_if_needed
+ else:
+ return _load_parameters_if_needed
def serialize_to_str(x: Any) -> str:
diff --git a/autogen/oai/client.py b/autogen/oai/client.py
index feaf56a76096..1bdfd835d1e2 100644
--- a/autogen/oai/client.py
+++ b/autogen/oai/client.py
@@ -2,21 +2,36 @@
import os
import sys
-from typing import List, Optional, Dict, Callable, Union
+from typing import Any, List, Optional, Dict, Callable, Tuple, Union
import logging
import inspect
from flaml.automl.logger import logger_formatter
-from pydantic import ValidationError
+
+from pydantic import BaseModel
+
+from autogen.oai import completion
from autogen.oai.openai_utils import get_key, OAI_PRICE1K
from autogen.token_count_utils import count_token
+from autogen._pydantic import model_dump
TOOL_ENABLED = False
try:
import openai
+except ImportError:
+ ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
+ OpenAI = object
+else:
+ # raises exception if openai>=1 is installed and something is wrong with imports
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
+ from openai.resources import Completions
from openai.types.chat import ChatCompletion
- from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
+ from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
+ from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaToolCall,
+ ChoiceDeltaToolCallFunction,
+ ChoiceDeltaFunctionCall,
+ )
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage
import diskcache
@@ -24,9 +39,7 @@
if openai.__version__ >= "1.1.0":
TOOL_ENABLED = True
ERROR = None
-except ImportError:
- ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
- OpenAI = object
+
logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
@@ -41,10 +54,10 @@ class OpenAIWrapper:
cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
- total_usage_summary: Dict = None
- actual_usage_summary: Dict = None
+ total_usage_summary: Optional[Dict[str, Any]] = None
+ actual_usage_summary: Optional[Dict[str, Any]] = None
- def __init__(self, *, config_list: List[Dict] = None, **base_config):
+ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any):
"""
Args:
config_list: a list of config dicts to override the base_config.
@@ -81,7 +94,9 @@ def __init__(self, *, config_list: List[Dict] = None, **base_config):
logger.warning("openai client was provided with an empty config_list, which may not be intended.")
if config_list:
config_list = [config.copy() for config in config_list] # make a copy before modifying
- self._clients = [self._client(config, openai_config) for config in config_list] # could modify the config
+ self._clients: List[OpenAI] = [
+ self._client(config, openai_config) for config in config_list
+ ] # could modify the config
self._config_list = [
{**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}
for config in config_list
@@ -90,7 +105,9 @@ def __init__(self, *, config_list: List[Dict] = None, **base_config):
self._clients = [self._client(extra_kwargs, openai_config)]
self._config_list = [extra_kwargs]
- def _process_for_azure(self, config: Dict, extra_kwargs: Dict, segment: str = "default"):
+ def _process_for_azure(
+ self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
+ ) -> None:
# deal with api_version
query_segment = f"{segment}_query"
headers_segment = f"{segment}_headers"
@@ -123,20 +140,20 @@ def _process_for_azure(self, config: Dict, extra_kwargs: Dict, segment: str = "d
if not base_url.endswith(suffix):
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix
- def _separate_openai_config(self, config):
+ def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into openai_config and extra_kwargs."""
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
self._process_for_azure(openai_config, extra_kwargs)
return openai_config, extra_kwargs
- def _separate_create_config(self, config):
+ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into create_config and extra_kwargs."""
create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
return create_config, extra_kwargs
- def _client(self, config, openai_config):
+ def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
"""Create a client with the given config to override openai_config,
after removing extra kwargs.
"""
@@ -148,21 +165,21 @@ def _client(self, config, openai_config):
@classmethod
def instantiate(
cls,
- template: str | Callable | None,
- context: Optional[Dict] = None,
+ template: Optional[Union[str, Callable[[Dict[str, Any]], str]]],
+ context: Optional[Dict[str, Any]] = None,
allow_format_str_template: Optional[bool] = False,
- ):
+ ) -> Optional[str]:
if not context or template is None:
- return template
+ return template # type: ignore [return-value]
if isinstance(template, str):
return template.format(**context) if allow_format_str_template else template
return template(context)
- def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
+ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Prime the create_config with additional_kwargs."""
# Validate the config
- prompt = create_config.get("prompt")
- messages = create_config.get("messages")
+ prompt: Optional[str] = create_config.get("prompt")
+ messages: Optional[List[Dict[str, Any]]] = create_config.get("messages")
if (prompt is None) == (messages is None):
raise ValueError("Either prompt or messages should be in create config but not both.")
context = extra_kwargs.get("context")
@@ -185,11 +202,11 @@ def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> D
}
if m.get("content")
else m
- for m in messages
+ for m in messages # type: ignore [union-attr]
]
return params
- def create(self, **config):
+ def create(self, **config: Any) -> ChatCompletion:
"""Make a completion for a given config using openai's clients.
Besides the kwargs allowed in openai's client, we allow the following additional kwargs.
The config in each client will be overridden by the config.
@@ -239,11 +256,11 @@ def yes_or_no_filter(context, response):
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
# Try to get the response from cache
key = get_key(params)
- response = cache.get(key, None)
+ response: ChatCompletion = cache.get(key, None)
if response is not None:
try:
- response.cost
+ response.cost # type: ignore [attr-defined]
except AttributeError:
# update attribute if cost is not calculated
response.cost = self.cost(response)
@@ -264,7 +281,7 @@ def yes_or_no_filter(context, response):
if error_code == "content_filter":
# raise the error for content_filter
raise
- logger.debug(f"config {i} failed", exc_info=1)
+ logger.debug(f"config {i} failed", exc_info=True)
if i == last:
raise
else:
@@ -284,9 +301,129 @@ def yes_or_no_filter(context, response):
response.pass_filter = pass_filter
return response
continue # filter is not passed; try the next config
+ raise RuntimeError("Should not reach here.")
+
+ @staticmethod
+ def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int:
+ """Update the dict from the chunk.
+
+        Reads `chunk.field` and, if present, updates `d[field]` accordingly.
+
+ Args:
+ chunk: The chunk.
+ d: The dict to be updated in place.
+ field: The field.
+
+ Returns:
+            The number of completion tokens added (1 if the field was present in the chunk, otherwise 0).
+
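+        Example (illustrative sketch; `delta` stands in for a streamed pydantic chunk with a string field):
+
+            d: Dict[str, Any] = {}
+            OpenAIWrapper._update_dict_from_chunk(delta, d, "content")
+            # string values are concatenated across chunks; other values overwrite
+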
+ """
+ completion_tokens = 0
+ assert isinstance(d, dict), d
+ if hasattr(chunk, field) and getattr(chunk, field) is not None:
+ new_value = getattr(chunk, field)
+ if isinstance(new_value, list) or isinstance(new_value, dict):
+ raise NotImplementedError(
+ f"Field {field} is a list or dict, which is currently not supported. "
+                "Only strings and numbers are supported."
+ )
+ if field not in d:
+ d[field] = ""
+ if isinstance(new_value, str):
+ d[field] += getattr(chunk, field)
+ else:
+ d[field] = new_value
+ completion_tokens = 1
+
+ return completion_tokens
+
+ @staticmethod
+ def _update_function_call_from_chunk(
+ function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall],
+ full_function_call: Optional[Dict[str, Any]],
+ completion_tokens: int,
+ ) -> Tuple[Dict[str, Any], int]:
+ """Update the function call from the chunk.
+
+ Args:
+ function_call_chunk: The function call chunk.
+ full_function_call: The full function call.
+ completion_tokens: The number of completion tokens.
+
+ Returns:
+ The updated full function call and the updated number of completion tokens.
+
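+        Example (illustrative sketch; `streamed_function_chunks` is a hypothetical iterable of partial deltas):
+
+            full_function_call, completion_tokens = None, 0
+            for function_call_chunk in streamed_function_chunks:
+                full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
+                    function_call_chunk, full_function_call, completion_tokens
+                )
+            # full_function_call accumulates {"name": ..., "arguments": ...} across chunks
+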
+ """
+ # Handle function call
+ if function_call_chunk:
+ if full_function_call is None:
+ full_function_call = {}
+ for field in ["name", "arguments"]:
+ completion_tokens += OpenAIWrapper._update_dict_from_chunk(
+ function_call_chunk, full_function_call, field
+ )
- def _completions_create(self, client, params):
- completions = client.chat.completions if "messages" in params else client.completions
+ if full_function_call:
+ return full_function_call, completion_tokens
+ else:
+ raise RuntimeError("Function call is not found, this should not happen.")
+
+ @staticmethod
+ def _update_tool_calls_from_chunk(
+ tool_calls_chunk: ChoiceDeltaToolCall,
+ full_tool_call: Optional[Dict[str, Any]],
+ completion_tokens: int,
+ ) -> Tuple[Dict[str, Any], int]:
+ """Update the tool call from the chunk.
+
+ Args:
+            tool_calls_chunk: The tool call chunk.
+ full_tool_call: The full tool call.
+ completion_tokens: The number of completion tokens.
+
+ Returns:
+ The updated full tool call and the updated number of completion tokens.
+
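+        Example (illustrative sketch; `chunks_for_one_index` is a hypothetical iterable of deltas sharing one index):
+
+            full_tool_call, completion_tokens = None, 0
+            for tool_calls_chunk in chunks_for_one_index:
+                full_tool_call, completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
+                    tool_calls_chunk, full_tool_call, completion_tokens
+                )
+            # full_tool_call gathers "index", "id", "type" and a nested "function" dict
+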
+ """
+ # future proofing for when tool calls other than function calls are supported
+ if tool_calls_chunk.type and tool_calls_chunk.type != "function":
+ raise NotImplementedError(
+ f"Tool call type {tool_calls_chunk.type} is currently not supported. "
+ "Only function calls are supported."
+ )
+
+ # Handle tool call
+ assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call
+ if tool_calls_chunk:
+ if full_tool_call is None:
+ full_tool_call = {}
+ for field in ["index", "id", "type"]:
+ completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field)
+
+ if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function:
+ if "function" not in full_tool_call:
+ full_tool_call["function"] = None
+
+ full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
+ tool_calls_chunk.function, full_tool_call["function"], completion_tokens
+ )
+
+ if full_tool_call:
+ return full_tool_call, completion_tokens
+ else:
+ raise RuntimeError("Tool call is not found, this should not happen.")
+
+ def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCompletion:
+ """Create a completion for a given config using openai's client.
+
+ Args:
+ client: The openai client.
+ params: The params for the completion.
+
+ Returns:
+ The completion.
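+
+        Example (illustrative; hypothetical params for a streaming chat completion):
+
+            params = {"model": "gpt-4", "stream": True, "messages": [{"role": "user", "content": "Hi"}]}
+            response = self._completions_create(client, params)
+            # streamed content and tool-call deltas are reassembled into a ChatCompletion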
+ """
+ completions: Completions = client.chat.completions if "messages" in params else client.completions # type: ignore [attr-defined]
# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
response_contents = [""] * params.get("n", 1)
@@ -297,31 +434,52 @@ def _completions_create(self, client, params):
print("\033[32m", end="")
# Prepare for potential function call
- full_function_call = None
+ full_function_call: Optional[Dict[str, Any]] = None
+ full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None
+
# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
- function_call_chunk = choice.delta.function_call
+ tool_calls_chunks = choice.delta.tool_calls
finish_reasons[choice.index] = choice.finish_reason
+                        # todo: remove this block once function calls are removed from the API;
+                        # the code works with or without it, but test_chat_functions_stream fails without it
+ # begin block
+ function_call_chunk = (
+ choice.delta.function_call if hasattr(choice.delta, "function_call") else None
+ )
# Handle function call
if function_call_chunk:
- if hasattr(function_call_chunk, "name") and function_call_chunk.name:
- if full_function_call is None:
- full_function_call = {"name": "", "arguments": ""}
- full_function_call["name"] += function_call_chunk.name
- completion_tokens += 1
- if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments:
- full_function_call["arguments"] += function_call_chunk.arguments
- completion_tokens += 1
- if choice.finish_reason == "function_call":
- # Need something here? I don't think so.
- pass
- if not content:
- continue
- # End handle function call
+ # Handle function call
+ if function_call_chunk:
+ full_function_call, completion_tokens = self._update_function_call_from_chunk(
+ function_call_chunk, full_function_call, completion_tokens
+ )
+ if not content:
+ continue
+ # end block
+
+ # Handle tool calls
+ if tool_calls_chunks:
+ for tool_calls_chunk in tool_calls_chunks:
+ # the current tool call to be reconstructed
+ ix = tool_calls_chunk.index
+ if full_tool_calls is None:
+ full_tool_calls = []
+ if ix >= len(full_tool_calls):
+ # in case ix is not sequential
+ full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1)
+
+ full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk(
+ tool_calls_chunk, full_tool_calls[ix], completion_tokens
+ )
+ if not content:
+ continue
+
+ # End handle tool calls
# If content is present, print it to the terminal and update response variables
if content is not None:
@@ -329,7 +487,8 @@ def _completions_create(self, client, params):
response_contents[choice.index] += content
completion_tokens += 1
else:
- print()
+ # print()
+ pass
# Reset the terminal text color
print("\033[0m\n")
@@ -356,17 +515,23 @@ def _completions_create(self, client, params):
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
- role="assistant", content=response_contents[i], function_call=full_function_call
+ role="assistant",
+ content=response_contents[i],
+ function_call=full_function_call,
+ tool_calls=full_tool_calls,
),
logprobs=None,
)
else:
# OpenAI versions below 1.5.0
- choice = Choice(
+ choice = Choice( # type: ignore [call-arg]
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
- role="assistant", content=response_contents[i], function_call=full_function_call
+ role="assistant",
+ content=response_contents[i],
+ function_call=full_function_call,
+ tool_calls=full_tool_calls,
),
)
@@ -379,7 +544,7 @@ def _completions_create(self, client, params):
return response
- def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
+ def _update_usage_summary(self, response: Union[ChatCompletion, Completion], use_cache: bool) -> None:
"""Update the usage summary.
Usage is calculated no matter filter is passed or not.
@@ -391,17 +556,17 @@ def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache
usage.completion_tokens = 0 if usage.completion_tokens is None else usage.completion_tokens
usage.total_tokens = 0 if usage.total_tokens is None else usage.total_tokens
except (AttributeError, AssertionError):
- logger.debug("Usage attribute is not found in the response.", exc_info=1)
+ logger.debug("Usage attribute is not found in the response.", exc_info=True)
return
- def update_usage(usage_summary):
+ def update_usage(usage_summary: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if usage_summary is None:
- usage_summary = {"total_cost": response.cost}
+ usage_summary = {"total_cost": response.cost} # type: ignore [union-attr]
else:
- usage_summary["total_cost"] += response.cost
+ usage_summary["total_cost"] += response.cost # type: ignore [union-attr]
usage_summary[response.model] = {
- "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,
+ "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost, # type: ignore [union-attr]
"prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0) + usage.prompt_tokens,
"completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0)
+ usage.completion_tokens,
@@ -416,7 +581,7 @@ def update_usage(usage_summary):
def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
"""Print the usage summary."""
- def print_usage(usage_summary, usage_type="total"):
+ def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None:
word_from_type = "including" if usage_type == "total" else "excluding"
if usage_summary is None:
print("No actual cost incurred (all completions are using cache).", flush=True)
@@ -475,20 +640,20 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
model = response.model
if model not in OAI_PRICE1K:
# TODO: add logging to warn that the model is not found
- logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=1)
+ logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True)
return 0
- n_input_tokens = response.usage.prompt_tokens
- n_output_tokens = response.usage.completion_tokens
+ n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr]
+ n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr]
tmp_price1K = OAI_PRICE1K[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
- return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
- return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
+ return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return]
+ return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator]
@classmethod
def extract_text_or_completion_object(
- cls, response: ChatCompletion | Completion
+ cls, response: Union[ChatCompletion, Completion]
) -> Union[List[str], List[ChatCompletionMessage]]:
"""Extract the text or ChatCompletion objects from a completion or chat response.
@@ -500,18 +665,18 @@ def extract_text_or_completion_object(
"""
choices = response.choices
if isinstance(response, Completion):
- return [choice.text for choice in choices]
+ return [choice.text for choice in choices] # type: ignore [union-attr]
if TOOL_ENABLED:
- return [
- choice.message
- if choice.message.function_call is not None or choice.message.tool_calls is not None
- else choice.message.content
+ return [ # type: ignore [return-value]
+ choice.message # type: ignore [union-attr]
+ if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr]
+ else choice.message.content # type: ignore [union-attr]
for choice in choices
]
else:
- return [
- choice.message if choice.message.function_call is not None else choice.message.content
+ return [ # type: ignore [return-value]
+ choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr]
for choice in choices
]
diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py
index eb0b84cf7ed4..525a1027cc09 100644
--- a/autogen/oai/openai_utils.py
+++ b/autogen/oai/openai_utils.py
@@ -3,7 +3,7 @@
import os
import tempfile
from pathlib import Path
-from typing import Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Set, Union
from dotenv import find_dotenv, load_dotenv
@@ -50,7 +50,7 @@
}
-def get_key(config):
+def get_key(config: Dict[str, Any]) -> str:
"""Get a unique identifier of a configuration.
Args:
@@ -145,8 +145,8 @@ def config_list_openai_aoai(
exclude (str, optional): The API type to exclude from the configuration list. Can be 'openai' or 'aoai'. Defaults to None.
Returns:
- List[Dict]: A list of configuration dictionaries. Each dictionary contains keys for 'api_key', 'base_url', 'api_type',
- and 'api_version'.
+ List[Dict]: A list of configuration dictionaries. Each dictionary contains keys for 'api_key',
+ and optionally 'base_url', 'api_type', and 'api_version'.
Raises:
FileNotFoundError: If the specified key files are not found and the corresponding API key is not set in the environment variables.
@@ -241,7 +241,6 @@ def config_list_openai_aoai(
# Assuming OpenAI API_KEY in os.environ["OPENAI_API_KEY"]
api_keys=os.environ.get("OPENAI_API_KEY", "").split("\n"),
base_urls=base_urls,
- # "api_type": "open_ai",
)
if exclude != "openai"
else []
@@ -366,23 +365,23 @@ def filter_config(config_list, filter_dict):
```
# Example configuration list with various models and API types
configs = [
- {'model': 'gpt-3.5-turbo', 'api_type': 'openai'},
- {'model': 'gpt-4', 'api_type': 'openai'},
+ {'model': 'gpt-3.5-turbo'},
+ {'model': 'gpt-4'},
{'model': 'gpt-3.5-turbo', 'api_type': 'azure'},
]
# Define filter criteria to select configurations for the 'gpt-3.5-turbo' model
- # that are also using the 'openai' API type
+ # that are also using the 'azure' API type
filter_criteria = {
'model': ['gpt-3.5-turbo'], # Only accept configurations for 'gpt-3.5-turbo'
- 'api_type': ['openai'] # Only accept configurations for 'openai' API type
+ 'api_type': ['azure'] # Only accept configurations for 'azure' API type
}
# Apply the filter to the configuration list
filtered_configs = filter_config(configs, filter_criteria)
# The resulting `filtered_configs` will be:
- # [{'model': 'gpt-3.5-turbo', 'api_type': 'openai'}]
+ # [{'model': 'gpt-3.5-turbo', 'api_type': 'azure', ...}]
```
Note:
@@ -426,10 +425,10 @@ def config_list_from_json(
Example:
```
# Suppose we have an environment variable 'CONFIG_JSON' with the following content:
- # '[{"model": "gpt-3.5-turbo", "api_type": "openai"}, {"model": "gpt-4", "api_type": "openai"}]'
+ # '[{"model": "gpt-3.5-turbo", "api_type": "azure"}, {"model": "gpt-4"}]'
# We can retrieve a filtered list of configurations like this:
- filter_criteria = {"api_type": ["openai"], "model": ["gpt-3.5-turbo"]}
+ filter_criteria = {"model": ["gpt-3.5-turbo"]}
configs = config_list_from_json('CONFIG_JSON', filter_dict=filter_criteria)
# The 'configs' variable will now contain only the configurations that match the filter criteria.
```
@@ -472,14 +471,12 @@ def get_config(
config = get_config(
api_key="sk-abcdef1234567890",
base_url="https://api.openai.com",
- api_type="openai",
api_version="v1"
)
# The 'config' variable will now contain:
# {
# "api_key": "sk-abcdef1234567890",
# "base_url": "https://api.openai.com",
- # "api_type": "openai",
# "api_version": "v1"
# }
```
diff --git a/autogen/version.py b/autogen/version.py
index fe404ae570d5..01ef12070dc3 100644
--- a/autogen/version.py
+++ b/autogen/version.py
@@ -1 +1 @@
-__version__ = "0.2.5"
+__version__ = "0.2.6"
diff --git a/notebook/agentchat_function_call_async.ipynb b/notebook/agentchat_function_call_async.ipynb
index 098457792fe7..27f977a5b1fc 100644
--- a/notebook/agentchat_function_call_async.ipynb
+++ b/notebook/agentchat_function_call_async.ipynb
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"id": "9fb85afb",
"metadata": {},
"outputs": [
@@ -134,40 +134,46 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
- "\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool Call (call_thUjscBN349eGd6xh3XrVT18): timer *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
- "\u001b[32m******************************************\u001b[0m\n",
+ "\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
- "\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
+ "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
+ "\n",
+ "\u001b[32m***** Response from calling tool \"timer\" *****\u001b[0m\n",
"Timer is done!\n",
- "\u001b[32m**************************************************\u001b[0m\n",
+ "\u001b[32m**********************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
- "\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): stopwatch *****\u001b[0m\n",
"Arguments: \n",
"{\"num_seconds\":\"5\"}\n",
- "\u001b[32m**********************************************\u001b[0m\n",
+ "\u001b[32m**************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[35m\n",
">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
- "\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
+ "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
+ "\n",
+ "\u001b[32m***** Response from calling tool \"stopwatch\" *****\u001b[0m\n",
"Stopwatch is done!\n",
- "\u001b[32m******************************************************\u001b[0m\n",
+ "\u001b[32m**************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
+ "Both the timer and the stopwatch for 5 seconds have been completed. \n",
+ "\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
@@ -239,7 +245,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"id": "2472f95c",
"metadata": {},
"outputs": [],
@@ -274,105 +280,20 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"id": "e2c9267a",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
- "\n",
- "\n",
- "1) Create a timer for 5 seconds.\n",
- "2) a stopwatch for 5 seconds.\n",
- "3) Pretty print the result as md.\n",
- "4) when 1-3 are done, terminate the group chat\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
- "\n",
- "\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
- "\n",
- "\u001b[32m***** Suggested function Call: timer *****\u001b[0m\n",
- "Arguments: \n",
- "{\"num_seconds\":\"5\"}\n",
- "\u001b[32m******************************************\u001b[0m\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[35m\n",
- ">>>>>>>> EXECUTING ASYNC FUNCTION timer...\u001b[0m\n",
- "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
- "\n",
- "\u001b[32m***** Response from calling function \"timer\" *****\u001b[0m\n",
- "Timer is done!\n",
- "\u001b[32m**************************************************\u001b[0m\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
- "\n",
- "\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[33mchatbot\u001b[0m (to chat_manager):\n",
- "\n",
- "\u001b[32m***** Suggested function Call: stopwatch *****\u001b[0m\n",
- "Arguments: \n",
- "{\"num_seconds\":\"5\"}\n",
- "\u001b[32m**********************************************\u001b[0m\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[35m\n",
- ">>>>>>>> EXECUTING FUNCTION stopwatch...\u001b[0m\n",
- "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
- "\n",
- "\u001b[32m***** Response from calling function \"stopwatch\" *****\u001b[0m\n",
- "Stopwatch is done!\n",
- "\u001b[32m******************************************************\u001b[0m\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[33mMarkdown_agent\u001b[0m (to chat_manager):\n",
- "\n",
- "The results are as follows:\n",
- "\n",
- "- Timer: Completed after `5 seconds`.\n",
- "- Stopwatch: Recorded time of `5 seconds`.\n",
- "\n",
- "**Timer and Stopwatch Summary:**\n",
- "Both the timer and stopwatch were set for `5 seconds` and have now concluded successfully. \n",
- "\n",
- "Now, let's proceed to terminate the group chat as requested.\n",
- "\u001b[32m***** Suggested function Call: terminate_group_chat *****\u001b[0m\n",
- "Arguments: \n",
- "{\"message\":\"All tasks have been completed. The group chat will now be terminated. Goodbye!\"}\n",
- "\u001b[32m*********************************************************\u001b[0m\n",
- "\n",
- "--------------------------------------------------------------------------------\n",
- "\u001b[35m\n",
- ">>>>>>>> EXECUTING FUNCTION terminate_group_chat...\u001b[0m\n",
- "\u001b[33muser_proxy\u001b[0m (to chat_manager):\n",
- "\n",
- "\u001b[32m***** Response from calling function \"terminate_group_chat\" *****\u001b[0m\n",
- "[GROUPCHAT_TERMINATE] All tasks have been completed. The group chat will now be terminated. Goodbye!\n",
- "\u001b[32m*****************************************************************\u001b[0m\n",
- "\n",
- "--------------------------------------------------------------------------------\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "await user_proxy.a_initiate_chat( # noqa: F704\n",
- " manager,\n",
- " message=\"\"\"\n",
- "1) Create a timer for 5 seconds.\n",
- "2) a stopwatch for 5 seconds.\n",
- "3) Pretty print the result as md.\n",
- "4) when 1-3 are done, terminate the group chat\"\"\",\n",
- ")"
+ "# todo: remove comment after fixing https://github.com/microsoft/autogen/issues/1205\n",
+ "# await user_proxy.a_initiate_chat( # noqa: F704\n",
+ "# manager,\n",
+ "# message=\"\"\"\n",
+ "# 1) Create a timer for 5 seconds.\n",
+ "# 2) a stopwatch for 5 seconds.\n",
+ "# 3) Pretty print the result as md.\n",
+ "# 4) when 1-3 are done, terminate the group chat\"\"\",\n",
+ "# )"
]
},
{
diff --git a/notebook/agentchat_function_call_currency_calculator.ipynb b/notebook/agentchat_function_call_currency_calculator.ipynb
index cdde6ec67a17..e3fb14b2abdb 100644
--- a/notebook/agentchat_function_call_currency_calculator.ipynb
+++ b/notebook/agentchat_function_call_currency_calculator.ipynb
@@ -6,7 +6,7 @@
"id": "ae1f50ec",
"metadata": {},
"source": [
- ""
+ ""
]
},
{
@@ -15,7 +15,7 @@
"id": "9a71fa36",
"metadata": {},
"source": [
- "# Auto Generated Agent Chat: Task Solving with Provided Tools as Functions\n",
+ "# Currency Calculator: Task Solving with Provided Tools as Functions\n",
"\n",
"AutoGen offers conversable agents powered by LLM, tool, or human, which can be used to perform tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation. Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n",
"\n",
@@ -167,7 +167,7 @@
" quote_currency: Annotated[CurrencySymbol, \"Quote currency\"] = \"EUR\",\n",
") -> str:\n",
" quote_amount = exchange_rate(base_currency, quote_currency) * base_amount\n",
- " return f\"{quote_amount} {quote_currency}\""
+ " return f\"{quote_amount} {quote_currency}\"\n"
]
},
{
@@ -262,13 +262,9 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
- "\u001b[32m***** Suggested tool Call (call_2mZCDF9fe8WJh6SveIwdGGEy): currency_calculator *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
- "{\n",
- " \"base_amount\": 123.45,\n",
- " \"base_currency\": \"USD\",\n",
- " \"quote_currency\": \"EUR\"\n",
- "}\n",
+ "{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
@@ -276,9 +272,11 @@
">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
- "\u001b[32m***** Response from calling function \"currency_calculator\" *****\u001b[0m\n",
+ "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
+ "\n",
+ "\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
"112.22727272727272 EUR\n",
- "\u001b[32m****************************************************************\u001b[0m\n",
+ "\u001b[32m************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
@@ -424,15 +422,9 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
- "\u001b[32m***** Suggested tool Call (call_MLtsPcVJXhdpvDPNNxfTB3OB): currency_calculator *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool Call (call_0VuU2rATuOgYrGmcBnXzPXlh): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
- "{\n",
- " \"base\": {\n",
- " \"currency\": \"EUR\",\n",
- " \"amount\": 112.23\n",
- " },\n",
- " \"quote_currency\": \"USD\"\n",
- "}\n",
+ "{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
@@ -440,14 +432,16 @@
">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
- "\u001b[32m***** Response from calling function \"currency_calculator\" *****\u001b[0m\n",
+ "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
+ "\n",
+ "\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
"{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
- "\u001b[32m****************************************************************\u001b[0m\n",
+ "\u001b[32m************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
- "112.23 Euros is approximately 123.45 US Dollars.\n",
+ "112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
@@ -488,15 +482,9 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
- "\u001b[32m***** Suggested tool Call (call_WrBjnoLeXilBPuj9nTJLM5wh): currency_calculator *****\u001b[0m\n",
+ "\u001b[32m***** Suggested tool Call (call_A6lqMu7s5SyDvftTSeQTtPcj): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
- "{\n",
- " \"base\": {\n",
- " \"currency\": \"USD\",\n",
- " \"amount\": 123.45\n",
- " },\n",
- " \"quote_currency\": \"EUR\"\n",
- "}\n",
+ "{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
@@ -504,9 +492,11 @@
">>>>>>>> EXECUTING FUNCTION currency_calculator...\u001b[0m\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
- "\u001b[32m***** Response from calling function \"currency_calculator\" *****\u001b[0m\n",
+ "\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
+ "\n",
+ "\u001b[32m***** Response from calling tool \"currency_calculator\" *****\u001b[0m\n",
"{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
- "\u001b[32m****************************************************************\u001b[0m\n",
+ "\u001b[32m************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
@@ -560,7 +550,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.6"
+ "version": "3.10.13"
}
},
"nbformat": 4,
diff --git a/test/agentchat/contrib/test_qdrant_retrievechat.py b/test/agentchat/contrib/test_qdrant_retrievechat.py
index 1d3c5afd6af9..1a18e78a4ded 100644
--- a/test/agentchat/contrib/test_qdrant_retrievechat.py
+++ b/test/agentchat/contrib/test_qdrant_retrievechat.py
@@ -4,6 +4,9 @@
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
from autogen import config_list_from_json
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
+from conftest import skip_openai # noqa: E402
+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
@@ -22,17 +25,17 @@
try:
import openai
-
- OPENAI_INSTALLED = True
except ImportError:
- OPENAI_INSTALLED = False
+ skip = True
+else:
+ skip = False or skip_openai
test_dir = os.path.join(os.path.dirname(__file__), "../..", "test_files")
@pytest.mark.skipif(
- sys.platform in ["darwin", "win32"] or not QDRANT_INSTALLED or not OPENAI_INSTALLED,
- reason="do not run on MacOS or windows or dependency is not installed",
+ sys.platform in ["darwin", "win32"] or not QDRANT_INSTALLED or skip,
+ reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_retrievechat():
conversations = {}
diff --git a/test/agentchat/contrib/test_retrievechat.py b/test/agentchat/contrib/test_retrievechat.py
index 574e3571b626..eeda1dc4891d 100644
--- a/test/agentchat/contrib/test_retrievechat.py
+++ b/test/agentchat/contrib/test_retrievechat.py
@@ -3,6 +3,9 @@
import sys
import autogen
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
+from conftest import skip_openai # noqa: E402
+
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
@@ -16,15 +19,15 @@
)
import chromadb
from chromadb.utils import embedding_functions as ef
-
- skip_test = False
except ImportError:
- skip_test = True
+ skip = True
+else:
+ skip = False or skip_openai
@pytest.mark.skipif(
- sys.platform in ["darwin", "win32"] or skip_test,
- reason="do not run on MacOS or windows or dependency is not installed",
+ sys.platform in ["darwin", "win32"] or skip,
+ reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_retrievechat():
conversations = {}
@@ -69,8 +72,8 @@ def test_retrievechat():
@pytest.mark.skipif(
- sys.platform in ["darwin", "win32"] or skip_test,
- reason="do not run on MacOS or windows or dependency is not installed",
+ sys.platform in ["darwin", "win32"] or skip,
+ reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_retrieve_config(caplog):
# test warning message when no docs_path is provided
diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py
index 0d647cbcc3d4..04bb95794fd3 100644
--- a/test/agentchat/test_conversable_agent.py
+++ b/test/agentchat/test_conversable_agent.py
@@ -1,12 +1,18 @@
+import asyncio
import copy
+import sys
+import time
from typing import Any, Callable, Dict, Literal
+import unittest
import pytest
from unittest.mock import patch
from pydantic import BaseModel, Field
from typing_extensions import Annotated
+import autogen
from autogen.agentchat import ConversableAgent, UserProxyAgent
+from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST
from conftest import skip_openai
try:
@@ -445,6 +451,8 @@ def currency_calculator(
== '{"currency":"EUR","amount":100.1}'
)
+ assert not asyncio.coroutines.iscoroutinefunction(currency_calculator)
+
@pytest.mark.asyncio
async def test__wrap_function_async():
@@ -481,6 +489,8 @@ async def currency_calculator(
== '{"currency":"EUR","amount":100.1}'
)
+ assert asyncio.coroutines.iscoroutinefunction(currency_calculator)
+
def get_origin(d: Dict[str, Callable[..., Any]]) -> Dict[str, Callable[..., Any]]:
return {k: v._origin for k, v in d.items()}
@@ -624,6 +634,161 @@ async def exec_sh(script: Annotated[str, "Valid shell script to execute."]):
assert get_origin(user_proxy_1.function_map) == expected_function_map
+@pytest.mark.skipif(
+ skip or not sys.version.startswith("3.10"),
+ reason="do not run if openai is not installed or py!=3.10",
+)
+def test_function_registration_e2e_sync() -> None:
+ config_list = autogen.config_list_from_json(
+ OAI_CONFIG_LIST,
+ filter_dict={
+ "model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
+ },
+ file_location=KEY_LOC,
+ )
+
+ llm_config = {
+ "config_list": config_list,
+ }
+
+ coder = autogen.AssistantAgent(
+ name="chatbot",
+ system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
+ llm_config=llm_config,
+ )
+
+ # create a UserProxyAgent instance named "user_proxy"
+ user_proxy = autogen.UserProxyAgent(
+ name="user_proxy",
+ system_message="A proxy for the user for executing code.",
+ is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
+ human_input_mode="NEVER",
+ max_consecutive_auto_reply=10,
+ code_execution_config={"work_dir": "coding"},
+ )
+
+ # define functions according to the function description
+ timer_mock = unittest.mock.MagicMock()
+ stopwatch_mock = unittest.mock.MagicMock()
+
+    # An example sync function
+ @user_proxy.register_for_execution()
+ @coder.register_for_llm(description="create a timer for N seconds")
+ def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
+ print("timer is running")
+ for i in range(int(num_seconds)):
+ print(".", end="")
+ time.sleep(0.01)
+ print()
+
+ timer_mock(num_seconds=num_seconds)
+ return "Timer is done!"
+
+ # An example sync function
+ @user_proxy.register_for_execution()
+ @coder.register_for_llm(description="create a stopwatch for N seconds")
+ def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
+ print("stopwatch is running")
+ # assert False, "stopwatch's alive!"
+ for i in range(int(num_seconds)):
+ print(".", end="")
+ time.sleep(0.01)
+ print()
+
+ stopwatch_mock(num_seconds=num_seconds)
+ return "Stopwatch is done!"
+
+    # start the conversation; this test exercises the synchronous path,
+    # so initiate_chat is called directly (no 'await' needed)
+    user_proxy.initiate_chat(
+ coder,
+ message="Create a timer for 2 seconds and then a stopwatch for 3 seconds.",
+ )
+
+ timer_mock.assert_called_once_with(num_seconds="2")
+ stopwatch_mock.assert_called_once_with(num_seconds="3")
+
+
+@pytest.mark.skipif(
+ skip or not sys.version.startswith("3.10"),
+ reason="do not run if openai is not installed or py!=3.10",
+)
+@pytest.mark.asyncio()
+async def test_function_registration_e2e_async() -> None:
+ config_list = autogen.config_list_from_json(
+ OAI_CONFIG_LIST,
+ filter_dict={
+ "model": ["gpt-4", "gpt-4-0314", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-v0314"],
+ },
+ file_location=KEY_LOC,
+ )
+
+ llm_config = {
+ "config_list": config_list,
+ }
+
+ coder = autogen.AssistantAgent(
+ name="chatbot",
+ system_message="For coding tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
+ llm_config=llm_config,
+ )
+
+ # create a UserProxyAgent instance named "user_proxy"
+ user_proxy = autogen.UserProxyAgent(
+ name="user_proxy",
+ system_message="A proxy for the user for executing code.",
+ is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
+ human_input_mode="NEVER",
+ max_consecutive_auto_reply=10,
+ code_execution_config={"work_dir": "coding"},
+ )
+
+ # define functions according to the function description
+ timer_mock = unittest.mock.MagicMock()
+ stopwatch_mock = unittest.mock.MagicMock()
+
+ # An example async function
+ @user_proxy.register_for_execution()
+ @coder.register_for_llm(description="create a timer for N seconds")
+ async def timer(num_seconds: Annotated[str, "Number of seconds in the timer."]) -> str:
+ print("timer is running")
+ for i in range(int(num_seconds)):
+ print(".", end="")
+ await asyncio.sleep(0.01)
+ print()
+
+ timer_mock(num_seconds=num_seconds)
+ return "Timer is done!"
+
+ # An example sync function
+ @user_proxy.register_for_execution()
+ @coder.register_for_llm(description="create a stopwatch for N seconds")
+ def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]) -> str:
+ print("stopwatch is running")
+ for i in range(int(num_seconds)):
+ print(".", end="")
+ time.sleep(0.01)
+ print()
+
+ stopwatch_mock(num_seconds=num_seconds)
+ return "Stopwatch is done!"
+
+ # start the conversation
+ # 'await' suspends the current coroutine until the awaited call completes.
+ # Without 'await', calling an async function merely returns a coroutine object without executing it.
+ await user_proxy.a_initiate_chat(
+ coder,
+ message="Create a timer for 4 seconds and then a stopwatch for 5 seconds.",
+ )
+
+ timer_mock.assert_called_once_with(num_seconds="4")
+ stopwatch_mock.assert_called_once_with(num_seconds="5")
+
+
@pytest.mark.skipif(
skip,
reason="do not run if skipping openai",
diff --git a/test/oai/test_client_stream.py b/test/oai/test_client_stream.py
index 284df0954660..6a20c4ffa21a 100644
--- a/test/oai/test_client_stream.py
+++ b/test/oai/test_client_stream.py
@@ -1,3 +1,6 @@
+import json
+from typing import Any, Dict, List, Literal, Optional, Union
+from unittest.mock import MagicMock
import pytest
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
import sys
@@ -13,12 +16,21 @@
else:
skip = False or skip_openai
+ # raises an exception if openai>=1 is installed and something is wrong with the imports;
+ # otherwise the tests will be skipped
+ from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaFunctionCall,
+ ChoiceDeltaToolCall,
+ ChoiceDeltaToolCallFunction,
+ )
+ from openai.types.chat.chat_completion import ChatCompletionMessage # type: ignore [attr-defined]
+
KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_aoai_chat_completion_stream():
+def test_aoai_chat_completion_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
@@ -31,7 +43,7 @@ def test_aoai_chat_completion_stream():
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_chat_completion_stream():
+def test_chat_completion_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
@@ -43,8 +55,147 @@ def test_chat_completion_stream():
print(client.extract_text_or_completion_object(response))
+# no need for OpenAI, works with any model
+def test__update_dict_from_chunk() -> None:
+ # dictionaries and lists are not supported
+ mock = MagicMock()
+ empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
+ for c in empty_collections:
+ mock.c = c
+ with pytest.raises(NotImplementedError):
+ OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")
+
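+ # scalar values (ints, bools, floats) simply overwrite the previous value in the dict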
+ org_d: Dict[str, Any] = {}
+ for i, v in enumerate([0, 1, False, True, 0.0, 1.0]):
+ field = "abcdefghijklmnopqrstuvwxyz"[i]
+ setattr(mock, field, v)
+
+ d = org_d.copy()
+ OpenAIWrapper._update_dict_from_chunk(mock, d, field)
+
+ org_d[field] = v
+ assert d == org_d
+
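+ # string values are concatenated across chunks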
+ mock.s = "beginning"
+ OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+ assert d["s"] == "beginning"
+
+ mock.s = " and"
+ OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+ assert d["s"] == "beginning and"
+
+ mock.s = " end"
+ OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+ assert d["s"] == "beginning and end"
+
+
+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+def test__update_function_call_from_chunk() -> None:
+ function_call_chunks = [
+ ChoiceDeltaFunctionCall(arguments=None, name="get_current_weather"),
+ ChoiceDeltaFunctionCall(arguments='{"', name=None),
+ ChoiceDeltaFunctionCall(arguments="location", name=None),
+ ChoiceDeltaFunctionCall(arguments='":"', name=None),
+ ChoiceDeltaFunctionCall(arguments="San", name=None),
+ ChoiceDeltaFunctionCall(arguments=" Francisco", name=None),
+ ChoiceDeltaFunctionCall(arguments='"}', name=None),
+ ]
+ expected = {"name": "get_current_weather", "arguments": '{"location":"San Francisco"}'}
+
+ full_function_call = None
+ completion_tokens = 0
+ for function_call_chunk in function_call_chunks:
+ full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
+ function_call_chunk=function_call_chunk,
+ full_function_call=full_function_call,
+ completion_tokens=completion_tokens,
+ )
+
+ print(f"{full_function_call=}")
+ print(f"{completion_tokens=}")
+
+ assert full_function_call == expected
+ assert completion_tokens == len(function_call_chunks)
+
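+ # constructing a ChatCompletionMessage validates the reconstructed function call (pydantic raises on invalid input)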
+ ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None)
+
+
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_chat_functions_stream():
+def test__update_tool_calls_from_chunk() -> None:
+ tool_calls_chunks = [
+ ChoiceDeltaToolCall(
+ index=0,
+ id="call_D2HOWGMekmkxXu9Ix3DUqJRv",
+ function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
+ type="function",
+ ),
+ ChoiceDeltaToolCall(
+ index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "S', name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="an F", name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="ranci", name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="sco, C", name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A"}', name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=1,
+ id="call_22HgJep4nwoKU3UOr96xaLmd",
+ function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"),
+ type="function",
+ ),
+ ChoiceDeltaToolCall(
+ index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "N', name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ew Y", name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ork, ", name=None), type=None
+ ),
+ ChoiceDeltaToolCall(
+ index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='NY"}', name=None), type=None
+ ),
+ ]
+
+ full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None]
+ completion_tokens = 0
+ for tool_calls_chunk in tool_calls_chunks:
+ index = tool_calls_chunk.index
+ full_tool_calls[index], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk(
+ tool_calls_chunk=tool_calls_chunk,
+ full_tool_call=full_tool_calls[index],
+ completion_tokens=completion_tokens,
+ )
+
+ print(f"{full_tool_calls=}")
+ print(f"{completion_tokens=}")
+
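+ # constructing a ChatCompletionMessage validates the reconstructed tool calls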
+ ChatCompletionMessage(role="assistant", tool_calls=full_tool_calls, content=None)
+
+
+# TODO: remove when OpenAI removes functions from the API
+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+def test_chat_functions_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
@@ -76,8 +227,63 @@ def test_chat_functions_stream():
print(client.extract_text_or_completion_object(response))
+# test for tool support instead of the deprecated function calls
+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+def test_chat_tools_stream() -> None:
+ config_list = config_list_from_json(
+ env_or_file=OAI_CONFIG_LIST,
+ file_location=KEY_LOC,
+ filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
+ )
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA",
+ },
+ },
+ "required": ["location"],
+ },
+ },
+ },
+ ]
+ print(f"{config_list=}")
+ client = OpenAIWrapper(config_list=config_list)
+ response = client.create(
+ # the intention is to trigger two tool invocations as a response to a single message
+ messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}],
+ tools=tools,
+ stream=True,
+ )
+ print(f"{response=}")
+ print(f"{type(response)=}")
+ print(f"{client.extract_text_or_completion_object(response)=}")
+ # check response
+ choices = response.choices
+ assert isinstance(choices, list)
+ assert len(choices) == 1
+ choice = choices[0]
+ assert choice.finish_reason == "tool_calls"
+ message = choice.message
+ tool_calls = message.tool_calls
+ assert isinstance(tool_calls, list)
+ assert len(tool_calls) == 2
+ arguments = [tool_call.function.arguments for tool_call in tool_calls]
+ locations = [json.loads(argument)["location"] for argument in arguments]
+ print(f"{locations=}")
+ assert any("San Francisco" in location for location in locations)
+ assert any("New York" in location for location in locations)
+
+
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_completion_stream():
+def test_completion_stream() -> None:
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
diff --git a/test/test_function_utils.py b/test/test_function_utils.py
index 422aaf9bbed0..705453e42700 100644
--- a/test/test_function_utils.py
+++ b/test/test_function_utils.py
@@ -1,3 +1,4 @@
+import asyncio
import inspect
import unittest.mock
from typing import Dict, List, Literal, Optional, Tuple
@@ -355,7 +356,7 @@ def test_get_load_param_if_needed_function() -> None:
assert actual == expected, actual
-def test_load_basemodels_if_needed() -> None:
+def test_load_basemodels_if_needed_sync() -> None:
@load_basemodels_if_needed
def f(
base: Annotated[Currency, "Base currency"],
@@ -363,6 +364,8 @@ def f(
) -> Tuple[Currency, CurrencySymbol]:
return base, quote_currency
+ assert not asyncio.coroutines.iscoroutinefunction(f)
+
actual = f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
assert isinstance(actual[0], Currency)
assert actual[0].amount == 123.45
@@ -370,6 +373,24 @@ def f(
assert actual[1] == "EUR"
+@pytest.mark.asyncio
+async def test_load_basemodels_if_needed_async() -> None:
+ @load_basemodels_if_needed
+ async def f(
+ base: Annotated[Currency, "Base currency"],
+ quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
+ ) -> Tuple[Currency, CurrencySymbol]:
+ return base, quote_currency
+
+ assert asyncio.coroutines.iscoroutinefunction(f)
+
+ actual = await f(base={"currency": "USD", "amount": 123.45}, quote_currency="EUR")
+ assert isinstance(actual[0], Currency)
+ assert actual[0].amount == 123.45
+ assert actual[0].currency == "USD"
+ assert actual[1] == "EUR"
+
+
def test_serialize_to_json():
assert serialize_to_str("abc") == "abc"
assert serialize_to_str(123) == "123"
diff --git a/website/docs/Use-Cases/agent_chat.md b/website/docs/Use-Cases/agent_chat.md
index 03f59484f8d3..000780c193a5 100644
--- a/website/docs/Use-Cases/agent_chat.md
+++ b/website/docs/Use-Cases/agent_chat.md
@@ -54,29 +54,75 @@ or Pydantic models:
The following example illustrates the process of registering a custom function for currency exchange calculation that uses type hints and standard Python datatypes:
+1. First, we import the necessary libraries and configure the models using the [`autogen.config_list_from_json`](../FAQ#set-your-api-endpoints) function:
+
``` python
from typing import Literal
+
+from pydantic import BaseModel, Field
from typing_extensions import Annotated
-from somewhere import exchange_rate
-# the agents are instances of AssistantAgent and UserProxyAgent
-from myagents import chatbot, user_proxy
+import autogen
+
+config_list = autogen.config_list_from_json(
+ "OAI_CONFIG_LIST",
+ filter_dict={
+ "model": ["gpt-4", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
+ },
+)
+```
+
+2. We create an assistant agent and a user proxy. The assistant will suggest which functions to call, and the user proxy will execute the proposed function calls:
+
+``` python
+llm_config = {
+ "config_list": config_list,
+ "timeout": 120,
+}
+
+chatbot = autogen.AssistantAgent(
+ name="chatbot",
+ system_message="For currency exchange tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
+ llm_config=llm_config,
+)
+
+# create a UserProxyAgent instance named "user_proxy"
+user_proxy = autogen.UserProxyAgent(
+ name="user_proxy",
+ is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
+ human_input_mode="NEVER",
+ max_consecutive_auto_reply=10,
+)
+```
+
+3. We define the function `currency_calculator` and decorate it with two decorators:
+ - [`@user_proxy.register_for_execution()`](../reference/agentchat/conversable_agent#register_for_execution), which adds the function `currency_calculator` to `user_proxy.function_map`, and
+ - [`@chatbot.register_for_llm`](../reference/agentchat/conversable_agent#register_for_llm), which adds a generated JSON schema of the function to the `llm_config` of `chatbot`.
+
+``` python
CurrencySymbol = Literal["USD", "EUR"]
-# registers the function for execution (updates function map)
+
+def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:
+ if base_currency == quote_currency:
+ return 1.0
+ elif base_currency == "USD" and quote_currency == "EUR":
+ return 1 / 1.1
+ elif base_currency == "EUR" and quote_currency == "USD":
+ return 1.1
+ else:
+ raise ValueError(f"Unknown currencies {base_currency}, {quote_currency}")
+
+
@user_proxy.register_for_execution()
-# creates JSON schema from type hints and registers the function to llm_config
@chatbot.register_for_llm(description="Currency exchange calculator.")
-# python function with type hints
def currency_calculator(
- # Annotated type is used for attaching description to the parameter
- base_amount: Annotated[float, "Amount of currency in base_currency"],
- # default values of parameters will be propagated to the LLM
- base_currency: Annotated[CurrencySymbol, "Base currency"] = "USD",
- quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
-) -> str: # return type must be either str, BaseModel or serializable by json.dumps()
- quote_amount = exchange_rate(base_currency, quote_currency) * base_amount
- return f"{quote_amount} {quote_currency}"
+ base_amount: Annotated[float, "Amount of currency in base_currency"],
+ base_currency: Annotated[CurrencySymbol, "Base currency"] = "USD",
+ quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
+) -> str:
+ quote_amount = exchange_rate(base_currency, quote_currency) * base_amount
+ return f"{quote_amount} {quote_currency}"
```
Notice the use of [Annotated](https://docs.python.org/3/library/typing.html?highlight=annotated#typing.Annotated) to specify the type and the description of each parameter. The return value of the function must be either a string or serializable to a string using [`json.dumps()`](https://docs.python.org/3/library/json.html#json.dumps) or a [`Pydantic` model dump to JSON](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump_json) (both versions 1.x and 2.x are supported).
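For the non-string return case, here is a minimal, hypothetical sketch of a registered function that returns a Pydantic model (the `Balance` model and `get_balance` function are invented for illustration; they reuse the `CurrencySymbol` type and the agents defined above). The returned model is dumped to a JSON string automatically:

``` python
from pydantic import BaseModel

class Balance(BaseModel):
    amount: float
    currency: CurrencySymbol

@user_proxy.register_for_execution()
@chatbot.register_for_llm(description="Look up an account balance.")
def get_balance(
    account_id: Annotated[str, "Account identifier"],
) -> Balance:
    # a made-up constant balance, just to illustrate a BaseModel return value
    return Balance(amount=123.45, currency="USD")
```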
@@ -99,7 +145,14 @@ You can check the JSON schema generated by the decorator `chatbot.llm_config["to
'description': 'Quote currency'}},
'required': ['base_amount']}}}]
```
-Agents can now use the function as follows:
+4. Agents can now use the function as follows:
+```python
+user_proxy.initiate_chat(
+ chatbot,
+ message="How much is 123.45 USD in EUR?",
+)
+```
+Output:
```
user_proxy (to chatbot):
@@ -139,12 +192,6 @@ encoded string automatically.
The following example shows how we could rewrite our currency exchange calculator example:
``` python
-from typing import Literal
-from typing_extensions import Annotated
-from pydantic import BaseModel, Field
-from somewhere import exchange_rate
-from myagents import chatbot, user_proxy
-
# defines a Pydantic model
class Currency(BaseModel):
# parameter of type CurrencySymbol
diff --git a/website/yarn.lock b/website/yarn.lock
index 3d7a9d9b6045..d28ccc6ddca1 100644
--- a/website/yarn.lock
+++ b/website/yarn.lock
@@ -4390,9 +4390,9 @@ flux@^4.0.1:
fbjs "^3.0.1"
follow-redirects@^1.0.0, follow-redirects@^1.14.7:
- version "1.15.2"
- resolved "https://registry.npmmirror.com/follow-redirects/-/follow-redirects-1.15.2.tgz#b460864144ba63f2681096f274c4e57026da2c13"
- integrity sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==
+ version "1.15.4"
+ resolved "https://registry.yarnpkg.com/follow-redirects/-/follow-redirects-1.15.4.tgz#cdc7d308bf6493126b17ea2191ea0ccf3e535adf"
+ integrity sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==
fork-ts-checker-webpack-plugin@^6.0.5:
version "6.5.2"