diff --git a/autogen/oai/client.py b/autogen/oai/client.py index fff480120337..65ad14254091 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -11,7 +11,7 @@ from autogen.oai import completion -from autogen.oai.openai_utils import get_key, OAI_PRICE1K +from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K from autogen.token_count_utils import count_token from autogen._pydantic import model_dump @@ -21,9 +21,10 @@ except ImportError: ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") OpenAI = object + AzureOpenAI = object else: # raises exception if openai>=1 is installed and something is wrong with imports - from openai import OpenAI, APIError, __version__ as OPENAIVERSION + from openai import OpenAI, AzureOpenAI, APIError, __version__ as OPENAIVERSION from openai.resources import Completions from openai.types.chat import ChatCompletion from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined] @@ -52,8 +53,18 @@ class OpenAIWrapper: """A wrapper class for openai client.""" cache_path_root: str = ".cache" - extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"} + extra_kwargs = { + "cache_seed", + "filter_func", + "allow_format_str_template", + "context", + "api_version", + "api_type", + "tags", + } openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) + aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs) + openai_kwargs = openai_kwargs | aopenai_kwargs total_usage_summary: Optional[Dict[str, Any]] = None actual_usage_summary: Optional[Dict[str, Any]] = None @@ -105,46 +116,10 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base self._clients = [self._client(extra_kwargs, openai_config)] self._config_list = [extra_kwargs] - def _process_for_azure( - self, config: Dict[str, Any], 
extra_kwargs: Dict[str, Any], segment: str = "default" - ) -> None: - # deal with api_version - query_segment = f"{segment}_query" - headers_segment = f"{segment}_headers" - api_version = extra_kwargs.get("api_version") - if api_version is not None and query_segment not in config: - config[query_segment] = {"api-version": api_version} - if segment == "default": - # remove the api_version from extra_kwargs - extra_kwargs.pop("api_version") - if segment == "extra": - return - # deal with api_type - api_type = extra_kwargs.get("api_type") - if api_type is not None and api_type.startswith("azure") and headers_segment not in config: - api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY")) - config[headers_segment] = {"api-key": api_key} - # remove the api_type from extra_kwargs - extra_kwargs.pop("api_type") - # deal with model - model = extra_kwargs.get("model") - if model is None: - return - if "gpt-3.5" in model: - # hack for azure gpt-3.5 - extra_kwargs["model"] = model = model.replace("gpt-3.5", "gpt-35") - base_url = config.get("base_url") - if base_url is None: - raise ValueError("to use azure openai api, base_url must be specified.") - suffix = f"/openai/deployments/{model}" - if not base_url.endswith(suffix): - config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix - def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Separate the config into openai_config and extra_kwargs.""" openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs} extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs} - self._process_for_azure(openai_config, extra_kwargs) return openai_config, extra_kwargs def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -156,10 +131,22 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any def _client(self, config: Dict[str, Any], openai_config: 
Dict[str, Any]) -> OpenAI: """Create a client with the given config to override openai_config, after removing extra kwargs. + + For Azure models/deployment names there's a convenience modification of model removing dots in + its value (Azure deployment names can't have dots). I.e. if you have Azure deployment name + "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot + from the name and create a client that connects to "gpt-35-turbo" Azure deployment. """ openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} - self._process_for_azure(openai_config, config) - client = OpenAI(**openai_config) + api_type = config.get("api_type") + if api_type is not None and api_type.startswith("azure"): + openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model")) + if openai_config["azure_deployment"] is not None: + openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") + openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) + client = AzureOpenAI(**openai_config) + else: + client = OpenAI(**openai_config) return client @classmethod @@ -242,8 +229,9 @@ def yes_or_no_filter(context, response): full_config = {**config, **self._config_list[i]} # separate the config into create_config and extra_kwargs create_config, extra_kwargs = self._separate_create_config(full_config) - # process for azure - self._process_for_azure(create_config, extra_kwargs, "extra") + api_type = extra_kwargs.get("api_type") + if api_type and api_type.startswith("azure") and "model" in create_config: + create_config["model"] = create_config["model"].replace(".", "") # construct the create params params = self._construct_create_params(create_config, extra_kwargs) # get the cache_seed, filter_func and context diff --git a/test/oai/test_client.py b/test/oai/test_client.py index 26a05396160c..7f561187d491 
100644 --- a/test/oai/test_client.py +++ b/test/oai/test_client.py @@ -31,10 +31,15 @@ def test_aoai_chat_completion(): filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]}, ) client = OpenAIWrapper(config_list=config_list) - # for config in config_list: - # print(config) - # client = OpenAIWrapper(**config) - # response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + print(response) + print(client.extract_text_or_completion_object(response)) + + # test dialect + config = config_list[0] + config["azure_deployment"] = config["model"] + config["azure_endpoint"] = config.pop("base_url") + client = OpenAIWrapper(**config) response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) print(response) print(client.extract_text_or_completion_object(response)) @@ -93,21 +98,23 @@ def test_chat_completion(): def test_completion(): config_list = config_list_openai_aoai(KEY_LOC) client = OpenAIWrapper(config_list=config_list) - response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct") + model = "gpt-3.5-turbo-instruct" + response = client.create(prompt="1+1=", model=model) print(response) print(client.extract_text_or_completion_object(response)) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @pytest.mark.parametrize( - "cache_seed, model", + "cache_seed", [ - (None, "gpt-3.5-turbo-instruct"), - (42, "gpt-3.5-turbo-instruct"), + None, + 42, ], ) -def test_cost(cache_seed, model): +def test_cost(cache_seed): config_list = config_list_openai_aoai(KEY_LOC) + model = "gpt-3.5-turbo-instruct" client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed) response = client.create(prompt="1+3=", model=model) print(response.cost) @@ -117,7 +124,8 @@ def test_cost(cache_seed, model): def test_usage_summary(): config_list = config_list_openai_aoai(KEY_LOC) client = 
OpenAIWrapper(config_list=config_list) - response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=None) + model = "gpt-3.5-turbo-instruct" + response = client.create(prompt="1+3=", model=model, cache_seed=None) # usage should be recorded assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0" @@ -138,15 +146,15 @@ def test_usage_summary(): assert client.total_usage_summary is None, "total_usage_summary should be None" # actual usage and all usage should be different - response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=42) + response = client.create(prompt="1+3=", model=model, cache_seed=42) assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0" assert client.actual_usage_summary is None, "No actual cost should be recorded" if __name__ == "__main__": - test_aoai_chat_completion() - test_oai_tool_calling_extraction() - test_chat_completion() + # test_aoai_chat_completion() + # test_oai_tool_calling_extraction() + # test_chat_completion() test_completion() - # test_cost() - test_usage_summary() + # # test_cost() + # test_usage_summary() diff --git a/test/oai/test_client_stream.py b/test/oai/test_client_stream.py index 6a20c4ffa21a..63ee782f68e3 100644 --- a/test/oai/test_client_stream.py +++ b/test/oai/test_client_stream.py @@ -286,7 +286,9 @@ def test_chat_tools_stream() -> None: def test_completion_stream() -> None: config_list = config_list_openai_aoai(KEY_LOC) client = OpenAIWrapper(config_list=config_list) - response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True) + # Azure can't have dot in model/deployment name + model = "gpt-35-turbo-instruct" if config_list[0].get("api_type") == "azure" else "gpt-3.5-turbo-instruct" + response = client.create(prompt="1+1=", model=model, stream=True) print(response) print(client.extract_text_or_completion_object(response))