Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switched to AzureOpenAI for api_type=="azure" #1232

Merged
merged 16 commits into from
Jan 17, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 17 additions & 42 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from autogen.oai import completion

from autogen.oai.openai_utils import get_key, OAI_PRICE1K
from autogen.oai.openai_utils import DEFAULT_AZURE_API_VERSION, get_key, OAI_PRICE1K
from autogen.token_count_utils import count_token
from autogen._pydantic import model_dump

Expand All @@ -21,9 +21,10 @@
except ImportError:
ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
OpenAI = object
AzureOpenAI = object
else:
# raises exception if openai>=1 is installed and something is wrong with imports
from openai import OpenAI, APIError, __version__ as OPENAIVERSION
from openai import OpenAI, AzureOpenAI, APIError, __version__ as OPENAIVERSION
from openai.resources import Completions
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined]
Expand Down Expand Up @@ -105,46 +106,10 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base
self._clients = [self._client(extra_kwargs, openai_config)]
self._config_list = [extra_kwargs]

def _process_for_azure(
self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default"
) -> None:
# deal with api_version
query_segment = f"{segment}_query"
headers_segment = f"{segment}_headers"
api_version = extra_kwargs.get("api_version")
if api_version is not None and query_segment not in config:
config[query_segment] = {"api-version": api_version}
if segment == "default":
# remove the api_version from extra_kwargs
extra_kwargs.pop("api_version")
if segment == "extra":
return
# deal with api_type
api_type = extra_kwargs.get("api_type")
if api_type is not None and api_type.startswith("azure") and headers_segment not in config:
api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY"))
config[headers_segment] = {"api-key": api_key}
# remove the api_type from extra_kwargs
extra_kwargs.pop("api_type")
# deal with model
model = extra_kwargs.get("model")
if model is None:
return
if "gpt-3.5" in model:
# hack for azure gpt-3.5
extra_kwargs["model"] = model = model.replace("gpt-3.5", "gpt-35")
base_url = config.get("base_url")
if base_url is None:
raise ValueError("to use azure openai api, base_url must be specified.")
suffix = f"/openai/deployments/{model}"
if not base_url.endswith(suffix):
config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix

def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Separate the config into openai_config and extra_kwargs."""
openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
self._process_for_azure(openai_config, extra_kwargs)
sonichi marked this conversation as resolved.
Show resolved Hide resolved
return openai_config, extra_kwargs

def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
Expand All @@ -158,8 +123,19 @@ def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> Open
after removing extra kwargs.
"""
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
self._process_for_azure(openai_config, config)
client = OpenAI(**openai_config)
api_type = config.get("api_type")
sonichi marked this conversation as resolved.
Show resolved Hide resolved
if api_type is not None and api_type.startswith("azure"):
api_key = config.get("api_key", os.environ.get("AZURE_OPENAI_API_KEY"))
api_version = config.get("api_version", DEFAULT_AZURE_API_VERSION)
model = config.get("model")
base_url = config.get("base_url")
if base_url is None:
raise ValueError("to use azure openai api, base_url must be specified.")
sonichi marked this conversation as resolved.
Show resolved Hide resolved
client = AzureOpenAI(
azure_deployment=model, api_version=api_version, api_key=api_key, azure_endpoint=base_url
)
else:
client = OpenAI(**openai_config)
return client

@classmethod
Expand Down Expand Up @@ -242,8 +218,6 @@ def yes_or_no_filter(context, response):
full_config = {**config, **self._config_list[i]}
# separate the config into create_config and extra_kwargs
create_config, extra_kwargs = self._separate_create_config(full_config)
# process for azure
self._process_for_azure(create_config, extra_kwargs, "extra")
# construct the create params
params = self._construct_create_params(create_config, extra_kwargs)
# get the cache_seed, filter_func and context
Expand Down Expand Up @@ -540,6 +514,7 @@ def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCom
# If streaming is not enabled, send a regular chat completion request
params = params.copy()
params["stream"] = False
params.pop("api_type", None)
sonichi marked this conversation as resolved.
Show resolved Hide resolved
response = completions.create(**params)

return response
Expand Down
Loading