From 0339ba1c0781433fefeb963835880eb96007a3f7 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 15 Aug 2024 21:50:34 -0400 Subject: [PATCH 01/12] update --- .../langchain_openai/embeddings/azure.py | 73 ++++++++----------- .../langchain_openai/embeddings/base.py | 62 +++++++--------- .../openai/langchain_openai/llms/base.py | 2 +- 3 files changed, 58 insertions(+), 79 deletions(-) diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index 79b90b95f282e..1d4d20b3d6156 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -2,12 +2,14 @@ from __future__ import annotations -import os from typing import Callable, Dict, Optional, Union import openai from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + from_env, + secret_from_env, +) from langchain_openai.embeddings.base import OpenAIEmbeddings @@ -100,7 +102,9 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): [-0.009100092574954033, 0.005071679595857859, -0.0029193938244134188] """ # noqa: E501 - azure_endpoint: Union[str, None] = None + azure_endpoint: Optional[str] = Field( + default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None) + ) """Your Azure endpoint, including the resource. Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided. @@ -113,9 +117,26 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): If given sets the base client URL to include `/deployments/{azure_deployment}`. Note: this means you won't be able to use non-deployment endpoints. """ - openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + # Check OPENAI_KEY for backwards compatibility. + # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using + # other forms of azure credentials. + openai_api_key: Optional[SecretStr] = Field( + alias="api_key", + default_factory=secret_from_env( + ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None + ), + ) """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" - azure_ad_token: Optional[SecretStr] = None + open_api_version = Field( + default_factory=from_env("OPENAI_API_VERSION", default="2023-05-15") + ) + """Automatically inferred from env var `OPENAI_API_VERSION` if not provided. + + Set to "2023-05-15" by default if env variable `OPENAI_API_VERSION` is not set. + """ + azure_ad_token: Optional[SecretStr] = Field( + default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None) + ) """Your Azure Active Directory token. Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. @@ -128,52 +149,18 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): Will be invoked on every request. """ + openai_api_type: Optional[str] = Field( + default_factory=from_env("OPENAI_API_TYPE", default="azure") + ) openai_api_version: Optional[str] = Field(default=None, alias="api_version") """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" validate_base_url: bool = True chunk_size: int = 2048 """Maximum number of texts to embed in each batch""" - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - # Check OPENAI_KEY for backwards compatibility. - # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using - # other forms of azure credentials. - openai_api_key = ( - values["openai_api_key"] - or os.getenv("AZURE_OPENAI_API_KEY") - or os.getenv("OPENAI_API_KEY") - ) - values["openai_api_key"] = ( - convert_to_secret_str(openai_api_key) if openai_api_key else None - ) - values["openai_api_base"] = ( - values["openai_api_base"] - if "openai_api_base" in values - else os.getenv("OPENAI_API_BASE") - ) - values["openai_api_version"] = values["openai_api_version"] or os.getenv( - "OPENAI_API_VERSION", default="2023-05-15" - ) - values["openai_api_type"] = get_from_dict_or_env( - values, "openai_api_type", "OPENAI_API_TYPE", default="azure" - ) - values["openai_organization"] = ( - values["openai_organization"] - or os.getenv("OPENAI_ORG_ID") - or os.getenv("OPENAI_ORGANIZATION") - ) - values["openai_proxy"] = get_from_dict_or_env( - values, "openai_proxy", "OPENAI_PROXY", default="" - ) - values["azure_endpoint"] = values["azure_endpoint"] or os.getenv( - "AZURE_OPENAI_ENDPOINT" - ) - azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN") - values["azure_ad_token"] = ( - convert_to_secret_str(azure_ad_token) if azure_ad_token else None - ) # For backwards compatibility. Before openai v1, no distinction was made # between azure_endpoint and base_url (openai_api_base). openai_api_base = values["openai_api_base"] diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 1e78302a9e42d..7e48b6325d295 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -23,9 +23,9 @@ from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.utils import ( - convert_to_secret_str, - get_from_dict_or_env, + from_env, get_pydantic_field_names, + secret_from_env, ) logger = logging.getLogger(__name__) @@ -188,18 +188,31 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai_api_version: Optional[str] = Field(default=None, alias="api_version") """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" # to support Azure OpenAI Service custom endpoints - openai_api_base: Optional[str] = Field(default=None, alias="base_url") + openai_api_base: Optional[str] = Field( + alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None) + ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" # to support Azure OpenAI Service custom endpoints - openai_api_type: Optional[str] = None + openai_api_type: Optional[str] = Field( + default_factory=from_env("OPENAI_API_TYPE", default=None) + ) # to support explicit proxy for OpenAI - openai_proxy: Optional[str] = None + openai_proxy: Optional[str] = Field( + default_factory=from_env("OPENAI_PROXY", default=None) + ) embedding_ctx_length: int = 8191 """The maximum number of tokens to embed at once.""" - openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + openai_api_key: Optional[SecretStr] = Field( + alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) + ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" - openai_organization: Optional[str] = Field(default=None, alias="organization") + openai_organization: Optional[str] = Field( + alias="organization", + default_factory=from_env( + ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None + ), + ) """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" allowed_special: Union[Literal["all"], Set[str], None] = None disallowed_special: Union[Literal["all"], Set[str], Sequence[str], None] = None @@ -284,24 +297,9 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["model_kwargs"] = extra return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - openai_api_key = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY" - ) - values["openai_api_key"] = ( - convert_to_secret_str(openai_api_key) if openai_api_key else None - ) - values["openai_api_base"] = values["openai_api_base"] or os.getenv( - "OPENAI_API_BASE" - ) - values["openai_api_type"] = get_from_dict_or_env( - values, "openai_api_type", "OPENAI_API_TYPE", default="" - ) - values["openai_proxy"] = get_from_dict_or_env( - values, "openai_proxy", "OPENAI_PROXY", default="" - ) if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): default_api_version = "2023-05-15" # Azure OpenAI embedding models allow a maximum of 16 texts @@ -310,18 +308,12 @@ def validate_environment(cls, values: Dict) -> Dict: values["chunk_size"] = min(values["chunk_size"], 16) else: default_api_version = "" - values["openai_api_version"] = get_from_dict_or_env( - values, - "openai_api_version", - "OPENAI_API_VERSION", - default=default_api_version, - ) - # Check OPENAI_ORGANIZATION for backwards compatibility. - values["openai_organization"] = ( - values["openai_organization"] - or os.getenv("OPENAI_ORG_ID") - or os.getenv("OPENAI_ORGANIZATION") - ) + + if values["openai_api_version"] is None: + values["openai_api_version"] = os.getenv( + "OPENAI_API_VERSION", default=default_api_version + ) + if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): raise ValueError( "If you are using Azure, " diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 0584b52a751fa..78e83cc0ad860 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -161,7 +161,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: From 2b286e879c9c27d822e76cec90331c6f4b775725 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 15 Aug 2024 22:02:06 -0400 Subject: [PATCH 02/12] update --- .../langchain_openai/chat_models/azure.py | 75 ++++++++----------- .../langchain_openai/embeddings/azure.py | 5 +- .../langchain_openai/embeddings/base.py | 6 +- 3 files changed, 34 insertions(+), 52 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index eaf31a56a33e6..127558f03286c 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -34,7 +34,10 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import ( + from_env, + secret_from_env, +) from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass @@ -474,10 +477,13 @@ class Joke(BaseModel): } """ # noqa: E501 - azure_endpoint: Union[str, None] = None + azure_endpoint: Optional[str] = Field( + default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None) + ) """Your Azure endpoint, including the resource. - + Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided. + Example: `https://example-resource.azure.openai.com/` """ deployment_name: Union[str, None] = Field(default=None, alias="azure_deployment") @@ -486,15 +492,29 @@ class Joke(BaseModel): If given sets the base client URL to include `/deployments/{azure_deployment}`. Note: this means you won't be able to use non-deployment endpoints. """ - openai_api_version: str = Field(default="", alias="api_version") + openai_api_version: Optional[str] = Field( + alias="api_version", + default_factory=from_env("OPENAI_API_VERSION", default=None), + ) """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" - openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + # Check OPENAI_KEY for backwards compatibility. + # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using + # other forms of azure credentials. + openai_api_key: Optional[SecretStr] = Field( + alias="api_key", + default_factory=secret_from_env( + ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None + ), + ) """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" - azure_ad_token: Optional[SecretStr] = None + azure_ad_token: Optional[SecretStr] = Field( + default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None) + ) """Your Azure Active Directory token. - + Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. - For more: + + For more: https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. """ azure_ad_token_provider: Union[Callable[[], str], None] = None @@ -516,7 +536,9 @@ class Joke(BaseModel): correct cost. """ - openai_api_type: str = "" + openai_api_type: Optional[str] = Field( + default_factory=from_env("OPENAI_API_TYPE", default="azure") + ) """Legacy, for openai<1.0.0 support.""" validate_base_url: bool = True """If legacy arg openai_api_base is passed in, try to infer if it is a base_url or @@ -546,7 +568,7 @@ def lc_secrets(self) -> Dict[str, str]: def is_lc_serializable(cls) -> bool: return True - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: @@ -554,45 +576,12 @@ def validate_environment(cls, values: Dict) -> Dict: if values["n"] > 1 and values["streaming"]: raise ValueError("n must be 1 when streaming.") - # Check OPENAI_KEY for backwards compatibility. - # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using - # other forms of azure credentials. - openai_api_key = ( - values["openai_api_key"] - or os.getenv("AZURE_OPENAI_API_KEY") - or os.getenv("OPENAI_API_KEY") - ) - values["openai_api_key"] = ( - convert_to_secret_str(openai_api_key) if openai_api_key else None - ) - values["openai_api_base"] = ( - values["openai_api_base"] - if "openai_api_base" in values - else os.getenv("OPENAI_API_BASE") - ) - values["openai_api_version"] = values["openai_api_version"] or os.getenv( - "OPENAI_API_VERSION" - ) # Check OPENAI_ORGANIZATION for backwards compatibility. values["openai_organization"] = ( values["openai_organization"] or os.getenv("OPENAI_ORG_ID") or os.getenv("OPENAI_ORGANIZATION") ) - values["azure_endpoint"] = values["azure_endpoint"] or os.getenv( - "AZURE_OPENAI_ENDPOINT" - ) - azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN") - values["azure_ad_token"] = ( - convert_to_secret_str(azure_ad_token) if azure_ad_token else None - ) - - values["openai_api_type"] = get_from_dict_or_env( - values, "openai_api_type", "OPENAI_API_TYPE", default="azure" - ) - values["openai_proxy"] = get_from_dict_or_env( - values, "openai_proxy", "OPENAI_PROXY", default="" - ) # For backwards compatibility. Before openai v1, no distinction was made # between azure_endpoint and base_url (openai_api_base). openai_api_base = values["openai_api_base"] diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index 1d4d20b3d6156..365c9dece26e8 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -6,10 +6,7 @@ import openai from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import ( - from_env, - secret_from_env, -) +from langchain_core.utils import from_env, secret_from_env from langchain_openai.embeddings.base import OpenAIEmbeddings diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index 7e48b6325d295..eed868c7abc29 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -22,11 +22,7 @@ import tiktoken from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator -from langchain_core.utils import ( - from_env, - get_pydantic_field_names, - secret_from_env, -) +from langchain_core.utils import from_env, get_pydantic_field_names, secret_from_env logger = logging.getLogger(__name__) From 06b7aa713967823a0cedaa5216e01bb51e4d8d13 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Thu, 15 Aug 2024 22:02:39 -0400 Subject: [PATCH 03/12] x --- libs/core/langchain_core/utils/utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index 0956ea1dd90b6..89b7f1d7a7f16 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -344,6 +344,12 @@ def secret_from_env(key: str, /) -> Callable[[], SecretStr]: ... def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: ... +@overload +def secret_from_env( + key: Sequence[str], /, *, default: None +) -> Callable[[], Optional[SecretStr]]: ... + + @overload def secret_from_env( key: str, /, *, default: None @@ -355,7 +361,7 @@ def secret_from_env(key: str, /, *, error_message: str) -> Callable[[], SecretSt def secret_from_env( - key: str, + key: Union[str, Sequence[str]], /, *, default: Union[str, _NoDefaultType, None] = _NoDefault, @@ -376,9 +382,14 @@ def secret_from_env( def get_secret_from_env() -> Optional[SecretStr]: """Get a value from an environment variable.""" - if key in os.environ: - return SecretStr(os.environ[key]) - elif isinstance(default, str): + if isinstance(key, (list, tuple)): + for k in key: + if k in os.environ: + return SecretStr(os.environ[k]) + if isinstance(key, str): + if key in os.environ: + return SecretStr(os.environ[key]) + if isinstance(default, str): return SecretStr(default) elif isinstance(default, type(None)): return None From 07334e3a1f80e0667aef5312f66c308e5a277396 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 16 Aug 2024 10:49:18 -0400 Subject: [PATCH 04/12] update --- .../langchain_openai/chat_models/azure.py | 5 +- .../openai/langchain_openai/llms/azure.py | 72 +++++++------------ .../openai/langchain_openai/llms/base.py | 42 +++++------ 3 files changed, 42 insertions(+), 77 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index 127558f03286c..c08c8d0159cf5 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -34,10 +34,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils import ( - from_env, - secret_from_env, -) +from langchain_core.utils import from_env, secret_from_env from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass diff --git a/libs/partners/openai/langchain_openai/llms/azure.py b/libs/partners/openai/langchain_openai/llms/azure.py index fc0f8e84b5f43..3d1a623fad929 100644 --- a/libs/partners/openai/langchain_openai/llms/azure.py +++ b/libs/partners/openai/langchain_openai/llms/azure.py @@ -1,12 +1,11 @@ from __future__ import annotations import logging -import os from typing import Any, Callable, Dict, List, Mapping, Optional, Union import openai from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env +from langchain_core.utils import from_env, secret_from_env from langchain_openai.llms.base import BaseOpenAI @@ -30,7 +29,9 @@ class AzureOpenAI(BaseOpenAI): openai = AzureOpenAI(model_name="gpt-3.5-turbo-instruct") """ - azure_endpoint: Union[str, None] = None + azure_endpoint: Optional[str] = Field( + default_factory=from_env("AZURE_OPENAI_ENDPOINT", default=None) + ) """Your Azure endpoint, including the resource. Automatically inferred from env var `AZURE_OPENAI_ENDPOINT` if not provided. @@ -43,16 +44,28 @@ class AzureOpenAI(BaseOpenAI): If given sets the base client URL to include `/deployments/{azure_deployment}`. Note: this means you won't be able to use non-deployment endpoints. """ - openai_api_version: str = Field(default="", alias="api_version") + openai_api_version: Optional[str] = Field( + alias="api_version", + default_factory=from_env("OPENAI_API_VERSION", default=None), + ) """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" - openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") - """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" - azure_ad_token: Optional[SecretStr] = None + # Check OPENAI_KEY for backwards compatibility. + # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using + # other forms of azure credentials. + openai_api_key: Optional[SecretStr] = Field( + alias="api_key", + default_factory=secret_from_env( + ["AZURE_OPENAI_API_KEY", "OPENAI_API_KEY"], default=None + ), + ) + azure_ad_token: Optional[SecretStr] = Field( + default_factory=secret_from_env("AZURE_OPENAI_AD_TOKEN", default=None) + ) """Your Azure Active Directory token. Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided. - For more: + For more: https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id. """ azure_ad_token_provider: Union[Callable[[], str], None] = None @@ -60,7 +73,9 @@ class AzureOpenAI(BaseOpenAI): Will be invoked on every request. """ - openai_api_type: str = "" + openai_api_type: Optional[str] = Field( + default_factory=from_env("OPENAI_API_TYPE", default="azure") + ) """Legacy, for openai<1.0.0 support.""" validate_base_url: bool = True """For backwards compatibility. If legacy val openai_api_base is passed in, try to @@ -84,7 +99,7 @@ def is_lc_serializable(cls) -> bool: """Return whether this model can be serialized by Langchain.""" return True - @root_validator() + @root_validator(pre=False, skip_on_failure=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: @@ -93,43 +108,6 @@ def validate_environment(cls, values: Dict) -> Dict: raise ValueError("Cannot stream results when n > 1.") if values["streaming"] and values["best_of"] > 1: raise ValueError("Cannot stream results when best_of > 1.") - - # Check OPENAI_KEY for backwards compatibility. - # TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using - # other forms of azure credentials. - openai_api_key = ( - values["openai_api_key"] - or os.getenv("AZURE_OPENAI_API_KEY") - or os.getenv("OPENAI_API_KEY") - ) - values["openai_api_key"] = ( - convert_to_secret_str(openai_api_key) if openai_api_key else None - ) - - values["azure_endpoint"] = values["azure_endpoint"] or os.getenv( - "AZURE_OPENAI_ENDPOINT" - ) - azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN") - values["azure_ad_token"] = ( - convert_to_secret_str(azure_ad_token) if azure_ad_token else None - ) - values["openai_api_base"] = values["openai_api_base"] or os.getenv( - "OPENAI_API_BASE" - ) - values["openai_proxy"] = get_from_dict_or_env( - values, "openai_proxy", "OPENAI_PROXY", default="" - ) - values["openai_organization"] = ( - values["openai_organization"] - or os.getenv("OPENAI_ORG_ID") - or os.getenv("OPENAI_ORGANIZATION") - ) - values["openai_api_version"] = values["openai_api_version"] or os.getenv( - "OPENAI_API_VERSION" - ) - values["openai_api_type"] = get_from_dict_or_env( - values, "openai_api_type", "OPENAI_API_TYPE", default="azure" - ) # For backwards compatibility. Before openai v1, no distinction was made # between azure_endpoint and base_url (openai_api_base). openai_api_base = values["openai_api_base"] diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 78e83cc0ad860..f6ab6c1fd730b 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import os import sys from typing import ( AbstractSet, @@ -29,11 +28,9 @@ from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import ( - convert_to_secret_str, - get_from_dict_or_env, get_pydantic_field_names, ) -from langchain_core.utils.utils import build_extra_kwargs +from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env logger = logging.getLogger(__name__) @@ -90,15 +87,26 @@ class BaseOpenAI(BaseLLM): """Generates best_of completions server-side and returns the "best".""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" - openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + openai_api_key: Optional[SecretStr] = Field( + alias="api_key", default_factory=secret_from_env("OPENAI_API_KEY", default=None) + ) """Automatically inferred from env var `OPENAI_API_KEY` if not provided.""" - openai_api_base: Optional[str] = Field(default=None, alias="base_url") + openai_api_base: Optional[str] = Field( + alias="base_url", default_factory=from_env("OPENAI_API_BASE", default=None) + ) """Base URL path for API requests, leave blank if not using a proxy or service emulator.""" - openai_organization: Optional[str] = Field(default=None, alias="organization") + openai_organization: Optional[str] = Field( + alias="organization", + default_factory=from_env( + ["OPENAI_ORG_ID", "OPENAI_ORGANIZATION"], default=None + ), + ) """Automatically inferred from env var `OPENAI_ORG_ID` if not provided.""" # to support explicit proxy for OpenAI - openai_proxy: Optional[str] = None + openai_proxy: Optional[str] = Field( + default_factory=from_env("OPENAI_PROXY", default=None) + ) batch_size: int = 20 """Batch size to use when passing multiple documents to generate.""" request_timeout: Union[float, Tuple[float, float], Any, None] = Field( @@ -171,24 +179,6 @@ def validate_environment(cls, values: Dict) -> Dict: if values["streaming"] and values["best_of"] > 1: raise ValueError("Cannot stream results when best_of > 1.") - openai_api_key = get_from_dict_or_env( - values, "openai_api_key", "OPENAI_API_KEY" - ) - values["openai_api_key"] = ( - convert_to_secret_str(openai_api_key) if openai_api_key else None - ) - values["openai_api_base"] = values["openai_api_base"] or os.getenv( - "OPENAI_API_BASE" - ) - values["openai_proxy"] = get_from_dict_or_env( - values, "openai_proxy", "OPENAI_PROXY", default="" - ) - values["openai_organization"] = ( - values["openai_organization"] - or os.getenv("OPENAI_ORG_ID") - or os.getenv("OPENAI_ORGANIZATION") - ) - client_params = { "api_key": ( values["openai_api_key"].get_secret_value() From 210b3c75f272d430ecba3bda05e04f5debd477a5 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 16 Aug 2024 11:01:13 -0400 Subject: [PATCH 05/12] fixes --- libs/partners/openai/langchain_openai/chat_models/base.py | 2 +- libs/partners/openai/langchain_openai/embeddings/azure.py | 4 +--- libs/partners/openai/langchain_openai/embeddings/base.py | 2 +- libs/partners/openai/langchain_openai/llms/azure.py | 2 +- libs/partners/openai/langchain_openai/llms/base.py | 6 ++---- 5 files changed, 6 insertions(+), 10 deletions(-) diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index d35ff75ecb65e..4f161f256b997 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -392,7 +392,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - @root_validator(pre=False, skip_on_failure=True) + @root_validator(pre=False, skip_on_failure=True, allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index 365c9dece26e8..3725a7662858f 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -124,7 +124,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): ), ) """Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided.""" - open_api_version = Field( + openai_api_version: Optional[str] = Field( default_factory=from_env("OPENAI_API_VERSION", default="2023-05-15") ) """Automatically inferred from env var `OPENAI_API_VERSION` if not provided. @@ -149,8 +149,6 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): openai_api_type: Optional[str] = Field( default_factory=from_env("OPENAI_API_TYPE", default="azure") ) - openai_api_version: Optional[str] = Field(default=None, alias="api_version") - """Automatically inferred from env var `OPENAI_API_VERSION` if not provided.""" validate_base_url: bool = True chunk_size: int = 2048 """Maximum number of texts to embed in each batch""" diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index eed868c7abc29..132d4577262c4 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -293,7 +293,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["model_kwargs"] = extra return values - @root_validator(pre=False, skip_on_failure=True) + @root_validator(pre=False, skip_on_failure=True, allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): diff --git a/libs/partners/openai/langchain_openai/llms/azure.py b/libs/partners/openai/langchain_openai/llms/azure.py index 3d1a623fad929..1c2da6676f090 100644 --- a/libs/partners/openai/langchain_openai/llms/azure.py +++ b/libs/partners/openai/langchain_openai/llms/azure.py @@ -99,7 +99,7 @@ def is_lc_serializable(cls) -> bool: """Return whether this model can be serialized by Langchain.""" return True - @root_validator(pre=False, skip_on_failure=True) + @root_validator(pre=False, skip_on_failure=True, allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index f6ab6c1fd730b..464b40e2ba919 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -27,9 +27,7 @@ from langchain_core.language_models.llms import BaseLLM from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import Field, SecretStr, root_validator -from langchain_core.utils import ( - get_pydantic_field_names, -) +from langchain_core.utils import get_pydantic_field_names from langchain_core.utils.utils import build_extra_kwargs, from_env, secret_from_env logger = logging.getLogger(__name__) @@ -169,7 +167,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: ) return values - @root_validator(pre=False, skip_on_failure=True) + @root_validator(pre=False, skip_on_failure=True, allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" if values["n"] < 1: From ddcef20593fb11f0b7c231040d1be26f39d85427 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 16 Aug 2024 11:13:52 -0400 Subject: [PATCH 06/12] x --- .../unit_tests/chat_models/test_azure.py | 41 ++++++++++--------- .../embeddings/test_azure_embeddings.py | 23 ++++++----- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py b/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py index 601fb69cd5c9f..c00b1984f5c7b 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py @@ -1,6 +1,7 @@ """Test Azure OpenAI Chat API wrapper.""" import os +from unittest import mock from langchain_openai import AzureChatOpenAI @@ -39,22 +40,24 @@ def test_initialize_more() -> None: def test_initialize_azure_openai_with_openai_api_base_set() -> None: - os.environ["OPENAI_API_BASE"] = "https://api.openai.com" - llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg] - api_key="xyz", # type: ignore[arg-type] - azure_endpoint="my-base-url", - azure_deployment="35-turbo-dev", - openai_api_version="2023-05-15", - temperature=0, - openai_api_base=None, - ) - assert llm.openai_api_key is not None - assert llm.openai_api_key.get_secret_value() == "xyz" - assert llm.azure_endpoint == "my-base-url" - assert llm.deployment_name == "35-turbo-dev" - assert llm.openai_api_version == "2023-05-15" - assert llm.temperature == 0 - - ls_params = llm._get_ls_params() - assert ls_params["ls_provider"] == "azure" - assert ls_params["ls_model_name"] == "35-turbo-dev" + with mock.patch.dict( + os.environ, {"OPENAI_API_BASE": "https://api.openai.com"} + ): + llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg] + api_key="xyz", # type: ignore[arg-type] + azure_endpoint="my-base-url", + azure_deployment="35-turbo-dev", + openai_api_version="2023-05-15", + temperature=0, + openai_api_base=None, + ) + assert llm.openai_api_key is not None + assert llm.openai_api_key.get_secret_value() == "xyz" + assert llm.azure_endpoint == "my-base-url" + assert llm.deployment_name == "35-turbo-dev" + assert llm.openai_api_version == "2023-05-15" + assert llm.temperature == 0 + + ls_params = llm._get_ls_params() + assert ls_params["ls_provider"] == "azure" + assert ls_params["ls_model_name"] == "35-turbo-dev" diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py index 869b45176cb48..7550817f290a3 100644 --- a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py @@ -1,4 +1,5 @@ import os +from unittest import mock from langchain_openai import AzureOpenAIEmbeddings @@ -15,13 +16,15 @@ def test_initialize_azure_openai() -> None: def test_intialize_azure_openai_with_base_set() -> None: - os.environ["OPENAI_API_BASE"] = "https://api.openai.com" - embeddings = AzureOpenAIEmbeddings( # type: ignore[call-arg, call-arg] - model="text-embedding-large", - api_key="xyz", # type: ignore[arg-type] - azure_endpoint="my-base-url", - azure_deployment="35-turbo-dev", - openai_api_version="2023-05-15", - openai_api_base=None, - ) - assert embeddings.model == "text-embedding-large" + with mock.patch.dict( + os.environ, {"OPENAI_API_BASE": "https://api.openai.com"} + ): + embeddings = AzureOpenAIEmbeddings( # type: ignore[call-arg, call-arg] + model="text-embedding-large", + api_key="xyz", # type: ignore[arg-type] + azure_endpoint="my-base-url", + azure_deployment="35-turbo-dev", + openai_api_version="2023-05-15", + openai_api_base=None, + ) + assert embeddings.model == "text-embedding-large" From 6c8b040ed895d711c4f99f0fb1862d4ce6b92f5a Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 16 Aug 2024 12:16:50 -0400 Subject: [PATCH 07/12] x --- libs/core/langchain_core/utils/utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index dd92d0a947ba7..f60fea5efcbe0 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -359,13 +359,7 @@ def secret_from_env(key: str, /, *, default: str) -> Callable[[], SecretStr]: .. @overload def secret_from_env( - key: Sequence[str], /, *, default: None -) -> Callable[[], Optional[SecretStr]]: ... - - -@overload -def secret_from_env( - key: str, /, *, default: None + key: Union[str, Sequence[str]], /, *, default: None ) -> Callable[[], Optional[SecretStr]]: ... From 1645340680270cc8cdc6f8ffa3e1c7d76c87b51a Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 16 Aug 2024 12:48:54 -0400 Subject: [PATCH 08/12] x --- libs/core/langchain_core/utils/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/utils/utils.py b/libs/core/langchain_core/utils/utils.py index f60fea5efcbe0..57e834e78ba56 100644 --- a/libs/core/langchain_core/utils/utils.py +++ b/libs/core/langchain_core/utils/utils.py @@ -301,7 +301,9 @@ def from_env( @overload -def from_env(key: str, /, *, default: None) -> Callable[[], Optional[str]]: ... +def from_env( + key: Union[str, Sequence[str]], /, *, default: None +) -> Callable[[], Optional[str]]: ... def from_env( From ac2dfc2e269d6d0eec52624f270e3cef62448302 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Fri, 16 Aug 2024 14:56:57 -0400 Subject: [PATCH 09/12] xx --- .../openai/tests/unit_tests/chat_models/test_azure.py | 4 +--- .../tests/unit_tests/embeddings/test_azure_embeddings.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py b/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py index c00b1984f5c7b..b57768c8ba003 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_azure.py @@ -40,9 +40,7 @@ def test_initialize_more() -> None: def test_initialize_azure_openai_with_openai_api_base_set() -> None: - with mock.patch.dict( - os.environ, {"OPENAI_API_BASE": "https://api.openai.com"} - ): + with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}): llm = AzureChatOpenAI( # type: ignore[call-arg, call-arg] api_key="xyz", # type: ignore[arg-type] azure_endpoint="my-base-url", diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py index 7550817f290a3..8ce50eaaf6081 100644 --- a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_embeddings.py @@ -16,9 +16,7 @@ def test_initialize_azure_openai() -> None: def test_intialize_azure_openai_with_base_set() -> None: - with mock.patch.dict( - os.environ, {"OPENAI_API_BASE": "https://api.openai.com"} - ): + with mock.patch.dict(os.environ, {"OPENAI_API_BASE": "https://api.openai.com"}): embeddings = AzureOpenAIEmbeddings( # type: ignore[call-arg, call-arg] model="text-embedding-large", api_key="xyz", # type: ignore[arg-type] From c0365d3e9ec1e8024a007e7686c35a072a00c142 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 3 Sep 2024 14:30:04 -0700 Subject: [PATCH 10/12] fmt --- .../chat_models/test_azure_standard.py | 24 ++++++++++++- .../chat_models/test_base_standard.py | 2 +- .../embeddings/test_azure_standard.py | 34 +++++++++++++++++++ .../embeddings/test_base_standard.py | 32 +++++++++++++++++ 4 files changed, 90 insertions(+), 2 deletions(-) create mode 100644 libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py create mode 100644 libs/partners/openai/tests/unit_tests/embeddings/test_base_standard.py diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py b/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py index 0465fcbc7a421..465c1dc0c22a5 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_azure_standard.py @@ -1,6 +1,6 @@ """Standard LangChain interface tests""" -from typing import Type +from typing import Tuple, Type import pytest from langchain_core.language_models import BaseChatModel @@ -25,3 +25,25 @@ def chat_model_params(self) -> dict: @pytest.mark.xfail(reason="AzureOpenAI does not support tool_choice='any'") def test_bind_tool_pydantic(self, model: BaseChatModel) -> None: super().test_bind_tool_pydantic(model) + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "AZURE_OPENAI_API_KEY": "api_key", + "AZURE_OPENAI_ENDPOINT": "https://endpoint.com", + "AZURE_OPENAI_AD_TOKEN": "token", + "OPENAI_ORG_ID": "org_id", + "OPENAI_API_VERSION": "yyyy-mm-dd", + "OPENAI_API_TYPE": "type", + }, + {}, + { + "openai_api_key": "api_key", + "azure_endpoint": "https://endpoint.com", + "azure_ad_token": "token", + "openai_organization": "org_id", + "openai_api_version": "yyyy-mm-dd", + "openai_api_type": "type", + }, + ) diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py index 03a1fc734afb4..8049da874cbf3 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base_standard.py @@ -18,7 +18,7 @@ def init_from_env_params(self) -> Tuple[dict, dict, dict]: return ( { "OPENAI_API_KEY": "api_key", - "OPENAI_ORGANIZATION": "org_id", + "OPENAI_ORG_ID": "org_id", "OPENAI_API_BASE": "api_base", "OPENAI_PROXY": "https://proxy.com", }, diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py new file mode 100644 index 0000000000000..c8f29a2962561 --- /dev/null +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py @@ -0,0 +1,34 @@ +from typing import Tuple, Type + +from langchain_core.embeddings import Embeddings +from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests + +from langchain_openai import AzureOpenAIEmbeddings + + +class TestAzureOpenAIStandard(EmbeddingsUnitTests): + @property + def embeddings_class(self) -> Type[Embeddings]: + return AzureOpenAIEmbeddings + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "AZURE_OPENAI_API_KEY": "api_key", + "AZURE_OPENAI_ENDPOINT": "https://endpoint.com", + "AZURE_OPENAI_AD_TOKEN": "token", + "OPENAI_ORG_ID": "org_id", + "OPENAI_API_VERSION": "yyyy-mm-dd", + "OPENAI_API_TYPE": "type", + }, + {}, + { + "openai_api_key": "api_key", + "azure_endpoint": "https://endpoint.com", + "azure_ad_token": "token", + "openai_organization": "org_id", + "openai_api_version": "yyyy-mm-dd", + "openai_api_type": "type", + }, + ) diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_base_standard.py b/libs/partners/openai/tests/unit_tests/embeddings/test_base_standard.py new file mode 100644 index 0000000000000..b265e5600eb35 --- /dev/null +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_base_standard.py @@ -0,0 +1,32 @@ +"""Standard LangChain interface tests""" + +from typing import Tuple, Type + +from langchain_core.embeddings import Embeddings +from langchain_standard_tests.unit_tests.embeddings import EmbeddingsUnitTests + +from langchain_openai import OpenAIEmbeddings + + +class TestOpenAIStandard(EmbeddingsUnitTests): + @property + def embeddings_class(self) -> Type[Embeddings]: + return OpenAIEmbeddings + + @property + def init_from_env_params(self) -> Tuple[dict, dict, dict]: + return ( + { + "OPENAI_API_KEY": "api_key", + "OPENAI_ORG_ID": "org_id", + "OPENAI_API_BASE": "api_base", + "OPENAI_PROXY": "https://proxy.com", + }, + {}, + { + "openai_api_key": "api_key", + "openai_organization": "org_id", + "openai_api_base": "api_base", + "openai_proxy": "https://proxy.com", + }, + ) From 75dad49bf4f4e5645fc7a8c2eb887555d7833b3d Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 3 Sep 2024 14:38:32 -0700 Subject: [PATCH 11/12] fmt --- .../openai/tests/unit_tests/embeddings/test_azure_standard.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py index c8f29a2962561..187cd9e6595fe 100644 --- a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py @@ -11,6 +11,10 @@ class TestAzureOpenAIStandard(EmbeddingsUnitTests): def embeddings_class(self) -> Type[Embeddings]: return AzureOpenAIEmbeddings + @property + def embedding_model_params(self) -> dict: + return {"azure_endpoint": "https://endpoint.com"} + @property def init_from_env_params(self) -> Tuple[dict, dict, dict]: return ( From 31cd1f36a26769eb34408972b10fb1432702d51b Mon Sep 17 00:00:00 2001 From: Bagatur Date: Tue, 3 Sep 2024 14:38:49 -0700 Subject: [PATCH 12/12] fmt --- .../openai/tests/unit_tests/embeddings/test_azure_standard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py index 187cd9e6595fe..b5f1591c476ca 100644 --- a/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py +++ b/libs/partners/openai/tests/unit_tests/embeddings/test_azure_standard.py @@ -13,7 +13,7 @@ def embeddings_class(self) -> Type[Embeddings]: @property def embedding_model_params(self) -> dict: - return {"azure_endpoint": "https://endpoint.com"} + return {"api_key": "api_key", "azure_endpoint": "https://endpoint.com"} @property def init_from_env_params(self) -> Tuple[dict, dict, dict]: