Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switched to AzureOpenAI for api_type=="azure" #1232

Merged
merged 16 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 31 additions & 43 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from autogen.oai import completion

from autogen.oai.openai_utils import get_key, OAI_PRICE1K
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
from autogen.token_count_utils import count_token
from autogen._pydantic import model_dump

Expand All @@ -21,9 +21,10 @@
except ImportError:
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
OpenAI = object
AzureOpenAI = object
else:
# raises exception if openai>=1 is installed and something is wrong with imports
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
from openai import OpenAI, AzureOpenAI, APIError, __version__ as OPENAIVERSION
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
Expand Down Expand Up @@ -52,8 +53,18 @@ class OpenAIWrapper:
"""A wrapper class for openai client."""

cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version", "tags"}
extra_kwargs = {
"cache_seed",
"filter_func",
"allow_format_str_template",
"context",
"api_version",
"api_type",
"tags",
}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
aopenai_kwargs = set(inspect.getfullargspec(AzureOpenAI.__init__).kwonlyargs)
openai_kwargs = openai_kwargs | aopenai_kwargs
total_usage_summary: Optional[Dict[str, Any]] = None
actual_usage_summary: Optional[Dict[str, Any]] = None

Expand Down Expand Up @@ -105,46 +116,10 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
self._clients = [self._client(extra_kwargs, openai_config)]
self._config_list = [extra_kwargs]

def _process_for_azure(
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
) -> None:
# deal with api_version
query_segment = f"{segment}_query"
headers_segment = f"{segment}_headers"
api_version = extra_kwargs.get("api_version")
if api_version is not None and query_segment not in config:
config[query_segment] = {"api-version": api_version}
if segment == "default":
# remove the api_version from extra_kwargs
extra_kwargs.pop("api_version")
if segment == "extra":
return
# deal with api_type
api_type = extra_kwargs.get("api_type")
if api_type is not None and api_type.startswith("azure") and headers_segment not in config:
api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY"))
config[headers_segment] = {"api-key": api_key}
# remove the api_type from extra_kwargs
extra_kwargs.pop("api_type")
# deal with model
model = extra_kwargs.get("model")
if model is None:
return
if "gpt-3.5" in model:
# hack for azure gpt-3.5
extra_kwargs["model"] = model = model.replace("gpt-3.5", "gpt-35")
base_url = config.get("base_url")
if base_url is None:
raise ValueError("to use azure openai api, base_url must be specified.")
suffix = f"/openai/deployments/{model}"
if not base_url.endswith(suffix):
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix

def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into openai_config and extra_kwargs."""
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
self._process_for_azure(openai_config, extra_kwargs)
sonichi marked this conversation as resolved.
Show resolved Hide resolved
return openai_config, extra_kwargs

def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand All @@ -156,10 +131,22 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any
def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI:
"""Create a client with the given config to override openai_config,
after removing extra kwargs.

For Azure models/deployment names there's a convenience modification of model removing dots in
the it's value (Azure deploment names can't have dots). I.e. if you have Azure deployment name
"gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot
from the name and create a client that connects to "gpt-35-turbo" Azure deployment.
"""
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
self._process_for_azure(openai_config, config)
client = OpenAI(**openai_config)
api_type = config.get("api_type")
sonichi marked this conversation as resolved.
Show resolved Hide resolved
if api_type is not None and api_type.startswith("azure"):
openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model"))
if openai_config["azure_deployment"] is not None:
openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "")
openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None))
client = AzureOpenAI(**openai_config)
else:
client = OpenAI(**openai_config)
return client

@classmethod
Expand Down Expand Up @@ -242,8 +229,9 @@ def yes_or_no_filter(context, response):
full_config = {**config, **self._config_list[i]}
# separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config)
# process for azure
self._process_for_azure(create_config, extra_kwargs, "extra")
api_type = extra_kwargs.get("api_type")
if api_type and api_type.startswith("azure") and "model" in create_config:
create_config["model"] = create_config["model"].replace(".", "")
# construct the create params
params = self._construct_create_params(create_config, extra_kwargs)
# get the cache_seed, filter_func and context
Expand Down
40 changes: 24 additions & 16 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,15 @@ def test_aoai_chat_completion():
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo", "gpt-35-turbo"]},
)
client = OpenAIWrapper(config_list=config_list)
sonichi marked this conversation as resolved.
Show resolved Hide resolved
# for config in config_list:
# print(config)
# client = OpenAIWrapper(**config)
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
print(response)
print(client.extract_text_or_completion_object(response))

# test dialect
config = config_list[0]
config["azure_deployment"] = config["model"]
config["azure_endpoint"] = config.pop("base_url")
client = OpenAIWrapper(**config)
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
print(response)
print(client.extract_text_or_completion_object(response))
Expand Down Expand Up @@ -93,21 +98,23 @@ def test_chat_completion():
def test_completion():
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
model = "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+1=", model=model)
print(response)
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.parametrize(
"cache_seed, model",
"cache_seed",
[
(None, "gpt-3.5-turbo-instruct"),
(42, "gpt-3.5-turbo-instruct"),
None,
42,
],
)
def test_cost(cache_seed, model):
def test_cost(cache_seed):
config_list = config_list_openai_aoai(KEY_LOC)
model = "gpt-3.5-turbo-instruct"
client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
response = client.create(prompt="1+3=", model=model)
print(response.cost)
Expand All @@ -117,7 +124,8 @@ def test_cost(cache_seed, model):
def test_usage_summary():
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=None)
model = "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+3=", model=model, cache_seed=None)

# usage should be recorded
assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
Expand All @@ -138,15 +146,15 @@ def test_usage_summary():
assert client.total_usage_summary is None, "total_usage_summary should be None"

# actual usage and all usage should be different
response = client.create(prompt="1+3=", model="gpt-3.5-turbo-instruct", cache_seed=42)
response = client.create(prompt="1+3=", model=model, cache_seed=42)
assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
assert client.actual_usage_summary is None, "No actual cost should be recorded"


if __name__ == "__main__":
test_aoai_chat_completion()
test_oai_tool_calling_extraction()
test_chat_completion()
# test_aoai_chat_completion()
# test_oai_tool_calling_extraction()
# test_chat_completion()
test_completion()
# test_cost()
test_usage_summary()
# # test_cost()
# test_usage_summary()
4 changes: 3 additions & 1 deletion test/oai/test_client_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,9 @@ def test_chat_tools_stream() -> None:
def test_completion_stream() -> None:
config_list = config_list_openai_aoai(KEY_LOC)
client = OpenAIWrapper(config_list=config_list)
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
# Azure can't have dot in model/deployment name
model = "gpt-35-turbo-instruct" if config_list[0].get("api_type") == "azure" else "gpt-3.5-turbo-instruct"
response = client.create(prompt="1+1=", model=model, stream=True)
print(response)
print(client.extract_text_or_completion_object(response))

Expand Down
Loading