Add Azure AD token provider #439

Open · wants to merge 5 commits into `main`
15 changes: 15 additions & 0 deletions README.md
@@ -186,6 +186,21 @@

where `AZURE_OPENAI_URI` is e.g. `https://custom-domain.openai.azure.com/openai/deployments/gpt-35-turbo`

##### Azure with Azure AD tokens

To use Azure AD tokens, configure the gem with a proc that returns a token:

```ruby
OpenAI.configure do |config|
config.azure_token_provider = ->() { your_code_caches_or_refreshes_token }
config.uri_base = ENV.fetch("AZURE_OPENAI_URI")
config.api_type = :azure
config.api_version = "2023-03-15-preview"
end
```

The `azure_token_provider` is called on every request, so your own code can cache the token and refresh it as needed.
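A minimal sketch of such a caching provider, assuming a hypothetical `fetch_azure_ad_token` helper that returns a fresh token string and its lifetime in seconds:

```ruby
# Caches the Azure AD token and refreshes it shortly before it expires.
# `fetch_azure_ad_token` is a hypothetical stand-in for however you obtain
# tokens (e.g. a client-credentials request); replace it with your own code.
cached_token = nil
expires_at = Time.at(0)

token_provider = lambda do
  if Time.now >= expires_at - 300 # refresh 5 minutes before expiry
    cached_token, lifetime = fetch_azure_ad_token # => ["eyJ...", 3600]
    expires_at = Time.now + lifetime
  end
  cached_token
end

OpenAI.configure do |config|
  config.azure_token_provider = token_provider
  config.uri_base = ENV.fetch("AZURE_OPENAI_URI")
  config.api_type = :azure
  config.api_version = "2023-03-15-preview"
end
```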

### Counting Tokens

OpenAI parses prompt text into [tokens](https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them), which are words or portions of words. (These tokens are unrelated to your API access_token.) Counting tokens can help you estimate your [costs](https://openai.com/pricing). It can also help you ensure your prompt text size is within the max-token limits of your model's context window, and choose an appropriate [`max_tokens`](https://platform.openai.com/docs/api-reference/chat/create#chat/create-max_tokens) completion parameter so your response will fit as well.
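If you want a local estimate without calling the API, the separate `tiktoken_ruby` gem (not part of this library) tokenizes text with OpenAI's encodings; a sketch assuming its `Tiktoken.encoding_for_model` API:

```ruby
require "tiktoken_ruby" # gem install tiktoken_ruby

enc = Tiktoken.encoding_for_model("gpt-3.5-turbo")
prompt = "What is the capital of France?"
puts enc.encode(prompt).length # rough token count for this prompt
```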
11 changes: 2 additions & 9 deletions lib/openai.rb
@@ -36,9 +36,8 @@ def call(env)
end

class Configuration
- attr_writer :access_token
attr_accessor :api_type, :api_version, :organization_id, :uri_base, :request_timeout,
-   :extra_headers
+   :extra_headers, :access_token, :azure_token_provider

DEFAULT_API_VERSION = "v1".freeze
DEFAULT_URI_BASE = "https://api.openai.com/".freeze
@@ -52,13 +51,7 @@ def initialize
@uri_base = DEFAULT_URI_BASE
@request_timeout = DEFAULT_REQUEST_TIMEOUT
@extra_headers = {}
- end
-
- def access_token
-   return @access_token if @access_token
-
-   error_text = "OpenAI access token missing! See https://github.com/alexrudall/ruby-openai#usage"
-   raise ConfigurationError, error_text
+ @azure_token_provider = nil
end
end

31 changes: 30 additions & 1 deletion lib/openai/client.rb
@@ -10,16 +10,20 @@ class Client
uri_base
request_timeout
extra_headers
azure_token_provider
].freeze
attr_reader *CONFIG_KEYS, :faraday_middleware

def initialize(config = {}, &faraday_middleware)
CONFIG_KEYS.each do |key|
# Set instance variables like api_type & access_token. Fall back to global config
# if not present.
- instance_variable_set("@#{key}", config[key] || OpenAI.configuration.send(key))
+ instance_variable_set("@#{key}",
+     config.key?(key) ? config[key] : OpenAI.configuration.send(key))
end
@faraday_middleware = faraday_middleware
validate_credential_config!
validate_azure_credential_provider!
end

def chat(parameters: {})
@@ -87,5 +91,30 @@ def beta(apis)
client.add_headers("OpenAI-Beta": apis.map { |k, v| "#{k}=#{v}" }.join(";"))
end
end

private

def validate_credential_config!
if @access_token && @azure_token_provider
raise ConfigurationError,
  "Only one of OpenAI access token or Azure token provider can be set! See https://github.com/alexrudall/ruby-openai#usage"
end

return if @access_token || @azure_token_provider

raise ConfigurationError,
  "OpenAI access token or Azure token provider missing! See https://github.com/alexrudall/ruby-openai#usage"
end

def validate_azure_credential_provider!
return if @azure_token_provider.nil?

unless @azure_token_provider.respond_to?(:to_proc)
raise ConfigurationError,
"OpenAI Azure AD token provider must be a Proc, Lambda, or respond to to_proc."
end

@azure_token_provider = @azure_token_provider&.to_proc
end
end
end
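Taken together, these validations mean each client (or the global configuration) must supply exactly one credential. An illustrative usage sketch based on this diff; `fetch_azure_ad_token` is a hypothetical helper and the key values are placeholders:

```ruby
# Per-client Azure AD configuration. Passing `access_token: nil` overrides any
# globally configured key, mirroring the spec below.
client = OpenAI::Client.new(
  access_token: nil,
  azure_token_provider: -> { fetch_azure_ad_token }, # hypothetical helper
  api_type: :azure,
  api_version: "2024-02-01",
  uri_base: "https://custom-domain.openai.azure.com/openai/deployments/gpt-35-turbo"
)

# Configuring both credentials raises OpenAI::ConfigurationError
# ("Only one of OpenAI access token or Azure token provider can be set!").
OpenAI::Client.new(access_token: "key", azure_token_provider: -> { "token" })
```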
15 changes: 12 additions & 3 deletions lib/openai/http_headers.rb
@@ -24,9 +24,18 @@ def openai_headers

def azure_headers
{
-   "Content-Type" => "application/json",
-   "api-key" => @access_token
- }
+   "Content-Type" => "application/json"
+ }.merge(azure_auth_headers)
end

def azure_auth_headers
if @access_token
{ "api-key" => @access_token }
elsif @azure_token_provider
{ "Authorization" => "Bearer #{@azure_token_provider.call}" }
else
raise ConfigurationError, "access_token or azure_token_provider must be set."
end
end

def extra_headers
125 changes: 125 additions & 0 deletions spec/fixtures/cassettes/http_json_post_with_azure_token_provider.yml

Some generated files are not rendered by default.

1 change: 0 additions & 1 deletion spec/openai/client/client_spec.rb
@@ -30,7 +30,6 @@
let!(:c2) do
OpenAI::Client.new(
access_token: "access_token2",
- organization_id: nil,
request_timeout: 1,
uri_base: "https://example.com/"
)
72 changes: 72 additions & 0 deletions spec/openai/client/http_spec.rb
@@ -120,6 +120,56 @@
end
end

describe ".json_post" do
context "with azure_token_provider" do
let(:token_provider) do
counter = 0
lambda do
counter += 1
"some dynamic token #{counter}"
end
end

let(:client) do
OpenAI::Client.new(
access_token: nil,
azure_token_provider: token_provider,
api_type: :azure,
uri_base: "https://custom-domain.openai.azure.com/openai/deployments/gpt-35-turbo",
api_version: "2024-02-01"
)
end

let(:cassette) { "http json post with azure token provider" }

it "calls the token provider on every request" do
expect(token_provider).to receive(:call).twice.and_call_original
VCR.use_cassette(cassette, record: :none) do
client.chat(
parameters: {
messages: [
{
"role" => "user",
"content" => "Hello world!"
}
]
}
)
client.chat(
parameters: {
messages: [
{
"role" => "user",
"content" => "Who were the founders of Microsoft?"
}
]
}
)
end
end
end
end

describe ".to_json_stream" do
context "with a proc" do
let(:user_proc) { proc { |x| x } }
@@ -269,6 +319,28 @@
expect(headers).to eq({ "Content-Type" => "application/json",
"api-key" => OpenAI.configuration.access_token })
}

context "with azure_token_provider" do
let(:token) { "some dynamic token" }
let(:token_provider) { -> { token } }

around do |ex|
old_access_token = OpenAI.configuration.access_token
OpenAI.configuration.access_token = nil
OpenAI.configuration.azure_token_provider = token_provider

ex.call
ensure
OpenAI.configuration.azure_token_provider = nil
OpenAI.configuration.access_token = old_access_token
end

it {
expect(token_provider).to receive(:call).once.and_call_original
expect(headers).to eq({ "Content-Type" => "application/json",
"Authorization" => "Bearer #{token}" })
}
end
end
end
end