From 7ae06bb73f697688cc399add50b136dc7c4b8ea1 Mon Sep 17 00:00:00 2001 From: Pablo Valdunciel Sanchez Date: Tue, 24 Sep 2024 13:29:17 +0200 Subject: [PATCH] Support for Azure OpenAI API Authored-by: Pablo Valdunciel --- CHANGELOG.md | 2 ++ src/llm_interface.jl | 11 +++++++++ src/llm_openai.jl | 51 +++++++++++++++++++++++++++++++++++++++++ src/user_preferences.jl | 15 +++++++++++- 4 files changed, 78 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e1a9ddf0..3bc52536 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Support for [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference). Requires two environment variables to be set: `AZURE_OPENAI_API_KEY` and `AZURE_OPENAI_HOST` (i.e. `https://<resource-name>.openai.azure.com`). + ### Fixed ## [0.56.1] diff --git a/src/llm_interface.jl b/src/llm_interface.jl index 40669844..77fb8067 100644 --- a/src/llm_interface.jl +++ b/src/llm_interface.jl @@ -148,6 +148,17 @@ Requires two environment variables to be set: """ struct DatabricksOpenAISchema <: AbstractOpenAISchema end +""" + AzureOpenAISchema + +AzureOpenAISchema() allows the user to call the Azure OpenAI API. [API Reference](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) + +Requires two environment variables to be set: +- `AZURE_OPENAI_API_KEY`: Azure token +- `AZURE_OPENAI_HOST`: Address of the Azure resource (`"https://<resource-name>.openai.azure.com"`) +""" +struct AzureOpenAISchema <: AbstractOpenAISchema end + """ FireworksOpenAISchema diff --git a/src/llm_openai.jl b/src/llm_openai.jl index 629f64b9..2e83e81a 100644 --- a/src/llm_openai.jl +++ b/src/llm_openai.jl @@ -271,6 +271,34 @@ function OpenAI.create_chat(schema::DatabricksOpenAISchema, kwargs...) 
end end +function OpenAI.create_chat(schema::AzureOpenAISchema, + api_key::AbstractString, + model::AbstractString, + conversation; + api_version::String = "2023-03-15-preview", + http_kwargs::NamedTuple = NamedTuple(), + streamcallback::Any = nothing, + url::String = "https://<resource-name>.openai.azure.com", + kwargs...) + + # Build the corresponding provider object + provider = OpenAI.AzureProvider(; + api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY, + base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) * "/openai/deployments/$model", + api_version = api_version + ) + # Override standard OpenAI request endpoint + OpenAI.openai_request( + "chat/completions", + provider; + method = "POST", + http_kwargs = http_kwargs, + messages = conversation, + query = Dict("api-version" => provider.api_version), + streamcallback = streamcallback, + kwargs... + ) +end # Extend OpenAI create_embeddings to allow for testing function OpenAI.create_embeddings(schema::AbstractOpenAISchema, @@ -367,6 +395,29 @@ function OpenAI.create_embeddings(schema::FireworksOpenAISchema, base_url = url) OpenAI.create_embeddings(provider, docs, model; kwargs...) end +function OpenAI.create_embeddings(schema::AzureOpenAISchema, + api_key::AbstractString, + docs, + model::AbstractString; + api_version::String = "2023-03-15-preview", + url::String = "https://<resource-name>.openai.azure.com", + kwargs...) + + # Build the corresponding provider object + provider = OpenAI.AzureProvider(; + api_key = isempty(AZURE_OPENAI_API_KEY) ? api_key : AZURE_OPENAI_API_KEY, + base_url = (isempty(AZURE_OPENAI_HOST) ? url : AZURE_OPENAI_HOST) * "/openai/deployments/$model", + api_version = api_version) + # Override standard OpenAI request endpoint + OpenAI.openai_request( + "embeddings", + provider; + method = "POST", + input = docs, + query = Dict("api-version" => provider.api_version), + kwargs... 
+ ) +end ## Temporary fix -- it will be moved upstream function OpenAI.create_embeddings(provider::AbstractCustomProvider, diff --git a/src/user_preferences.jl b/src/user_preferences.jl index e8a6ed10..baff2c41 100644 --- a/src/user_preferences.jl +++ b/src/user_preferences.jl @@ -11,6 +11,8 @@ Check your preferences by calling `get_preferences(key::String)`. # Available Preferences (for `set_preferences!`) - `OPENAI_API_KEY`: The API key for the OpenAI API. See [OpenAI's documentation](https://platform.openai.com/docs/quickstart?context=python) for more information. +- `AZURE_OPENAI_API_KEY`: The API key for the Azure OpenAI API. See [Azure OpenAI's documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) for more information. +- `AZURE_OPENAI_HOST`: The host for the Azure OpenAI API. See [Azure OpenAI's documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) for more information. - `MISTRALAI_API_KEY`: The API key for the Mistral AI API. See [Mistral AI's documentation](https://docs.mistral.ai/) for more information. - `COHERE_API_KEY`: The API key for the Cohere API. See [Cohere's documentation](https://docs.cohere.com/docs/the-cohere-platform) for more information. - `DATABRICKS_API_KEY`: The API key for the Databricks Foundation Model API. See [Databricks' documentation](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html) for more information. @@ -39,6 +41,8 @@ Define your `register_model!()` calls in your `startup.jl` file to make them ava # Available ENV Variables - `OPENAI_API_KEY`: The API key for the OpenAI API. +- `AZURE_OPENAI_API_KEY`: The API key for the Azure OpenAI API. +- `AZURE_OPENAI_HOST`: The host for the Azure OpenAI API. This is the URL built as `https://.openai.azure.com`. - `MISTRALAI_API_KEY`: The API key for the Mistral AI API. - `COHERE_API_KEY`: The API key for the Cohere API. - `LOCAL_SERVER`: The URL of the local server to use for `ai*` calls. 
Defaults to `http://localhost:10897/v1`. This server is called when you call `model="local"` @@ -62,6 +66,8 @@ const PREFERENCES = nothing "Keys that are allowed to be set via `set_preferences!`" const ALLOWED_PREFERENCES = ["MISTRALAI_API_KEY", "OPENAI_API_KEY", + "AZURE_OPENAI_API_KEY", + "AZURE_OPENAI_HOST", "COHERE_API_KEY", "DATABRICKS_API_KEY", "DATABRICKS_HOST", @@ -138,6 +144,8 @@ global MODEL_IMAGE_GENERATION::String = @load_preference("MODEL_IMAGE_GENERATION # First, load from preferences, then from environment variables # Instantiate empty global variables global OPENAI_API_KEY::String = "" +global AZURE_OPENAI_API_KEY::String = "" +global AZURE_OPENAI_HOST::String = "" global MISTRALAI_API_KEY::String = "" global COHERE_API_KEY::String = "" global DATABRICKS_API_KEY::String = "" @@ -163,7 +171,12 @@ function load_api_keys!() # Note: Disable this warning by setting OPENAI_API_KEY to anything isempty(OPENAI_API_KEY) && @warn "OPENAI_API_KEY variable not set! OpenAI models will not be available - set API key directly via `PromptingTools.OPENAI_API_KEY=`!" - + global AZURE_OPENAI_API_KEY + AZURE_OPENAI_API_KEY = @load_preference("AZURE_OPENAI_API_KEY", + default=get(ENV, "AZURE_OPENAI_API_KEY", "")) + global AZURE_OPENAI_HOST + AZURE_OPENAI_HOST = @load_preference("AZURE_OPENAI_HOST", + default=get(ENV, "AZURE_OPENAI_HOST", "")) global MISTRALAI_API_KEY MISTRALAI_API_KEY = @load_preference("MISTRALAI_API_KEY", default=get(ENV, "MISTRALAI_API_KEY", ""))