Skip to content

Commit

Permalink
partners/openai + community: Async Azure AD token provider support fo…
Browse files Browse the repository at this point in the history
…r Azure OpenAI (#27488)

This PR introduces a new `azure_ad_async_token_provider` attribute to
the `AzureOpenAI` and `AzureChatOpenAI` classes in `partners/openai` and
`community` packages, given it's currently supported on `openai` package
as
[AsyncAzureADTokenProvider](https://github.com/openai/openai-python/blob/main/src/openai/lib/azure.py#L33)
type.

The reason for creating a new attribute is to avoid breaking changes.
Let's say you have an existing code that uses a `AzureOpenAI` or
`AzureChatOpenAI` instance to perform both sync and async operations.
The `azure_ad_token_provider` will work exactly as it is today, while
`azure_ad_async_token_provider` will override it for async requests.


If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17.
  • Loading branch information
fedeoliv authored Oct 22, 2024
1 parent 3468442 commit ab205e7
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 11 deletions.
16 changes: 14 additions & 2 deletions libs/community/langchain_community/chat_models/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import os
import warnings
from typing import Any, Callable, Dict, List, Union
from typing import Any, Awaitable, Callable, Dict, List, Union

from langchain_core._api.deprecation import deprecated
from langchain_core.outputs import ChatResult
Expand Down Expand Up @@ -90,7 +90,13 @@ class AzureChatOpenAI(ChatOpenAI):
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""
model_version: str = ""
"""Legacy, for openai<1.0.0 support."""
Expand Down Expand Up @@ -208,6 +214,12 @@ def validate_environment(cls, values: Dict) -> Dict:
"http_client": values["http_client"],
}
values["client"] = openai.AzureOpenAI(**client_params).chat.completions

azure_ad_async_token_provider = values["azure_ad_async_token_provider"]

if azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider

values["async_client"] = openai.AsyncAzureOpenAI(
**client_params
).chat.completions
Expand Down
16 changes: 14 additions & 2 deletions libs/community/langchain_community/embeddings/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
import warnings
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Union

from langchain_core._api.deprecation import deprecated
from langchain_core.utils import get_from_dict_or_env
Expand Down Expand Up @@ -49,7 +49,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""
openai_api_version: Optional[str] = Field(default=None, alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
Expand Down Expand Up @@ -162,6 +168,12 @@ def post_init_validator(self) -> Self:
"http_client": self.http_client,
}
self.client = openai.AzureOpenAI(**client_params).embeddings

if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)

self.async_client = openai.AsyncAzureOpenAI(**client_params).embeddings
else:
self.client = openai.Embedding
Expand Down
15 changes: 14 additions & 1 deletion libs/community/langchain_community/llms/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AbstractSet,
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Dict,
Expand Down Expand Up @@ -804,7 +805,13 @@ class AzureOpenAI(BaseOpenAI):
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""
openai_api_type: str = ""
"""Legacy, for openai<1.0.0 support."""
Expand Down Expand Up @@ -922,6 +929,12 @@ def validate_environment(cls, values: Dict) -> Dict:
"http_client": values["http_client"],
}
values["client"] = openai.AzureOpenAI(**client_params).completions

azure_ad_async_token_provider = values["azure_ad_async_token_provider"]

if azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = azure_ad_async_token_provider

values["async_client"] = openai.AsyncAzureOpenAI(
**client_params
).completions
Expand Down
28 changes: 26 additions & 2 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,18 @@

import logging
import os
from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, TypeVar, Union
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
TypedDict,
TypeVar,
Union,
)

import openai
from langchain_core.language_models.chat_models import LangSmithParams
Expand Down Expand Up @@ -494,7 +505,14 @@ class Joke(BaseModel):
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""

azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""

model_version: str = ""
Expand Down Expand Up @@ -633,6 +651,12 @@ def validate_environment(self) -> Self:
self.client = self.root_client.chat.completions
if not self.async_client:
async_specific = {"http_client": self.http_async_client}

if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)

self.root_async_client = openai.AsyncAzureOpenAI(
**client_params,
**async_specific, # type: ignore[arg-type]
Expand Down
16 changes: 14 additions & 2 deletions libs/partners/openai/langchain_openai/embeddings/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Callable, Optional, Union
from typing import Awaitable, Callable, Optional, Union

import openai
from langchain_core.utils import from_env, secret_from_env
Expand Down Expand Up @@ -146,7 +146,13 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""
openai_api_type: Optional[str] = Field(
default_factory=from_env("OPENAI_API_TYPE", default="azure")
Expand Down Expand Up @@ -203,6 +209,12 @@ def validate_environment(self) -> Self:
).embeddings
if not self.async_client:
async_specific: dict = {"http_client": self.http_async_client}

if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)

self.async_client = openai.AsyncAzureOpenAI(
**client_params, # type: ignore[arg-type]
**async_specific,
Expand Down
16 changes: 14 additions & 2 deletions libs/partners/openai/langchain_openai/llms/azure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union

import openai
from langchain_core.language_models import LangSmithParams
Expand Down Expand Up @@ -73,7 +73,13 @@ class AzureOpenAI(BaseOpenAI):
azure_ad_token_provider: Union[Callable[[], str], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every request.
Will be invoked on every sync request. For async requests,
will be invoked if `azure_ad_async_token_provider` is not provided.
"""
azure_ad_async_token_provider: Union[Callable[[], Awaitable[str]], None] = None
"""A function that returns an Azure Active Directory token.
Will be invoked on every async request.
"""
openai_api_type: Optional[str] = Field(
default_factory=from_env("OPENAI_API_TYPE", default="azure")
Expand Down Expand Up @@ -158,6 +164,12 @@ def validate_environment(self) -> Self:
).completions
if not self.async_client:
async_specific = {"http_client": self.http_async_client}

if self.azure_ad_async_token_provider:
client_params["azure_ad_token_provider"] = (
self.azure_ad_async_token_provider
)

self.async_client = openai.AsyncAzureOpenAI(
**client_params,
**async_specific, # type: ignore[arg-type]
Expand Down

0 comments on commit ab205e7

Please sign in to comment.