diff --git a/src/backend/bisheng/utils/embedding.py b/src/backend/bisheng/utils/embedding.py index 50457a89d..85105b3f0 100644 --- a/src/backend/bisheng/utils/embedding.py +++ b/src/backend/bisheng/utils/embedding.py @@ -1,10 +1,13 @@ import httpx -from bisheng.settings import settings -from bisheng_langchain.embeddings import CustomHostEmbedding, HostEmbeddings + +from langchain_openai import AzureOpenAIEmbeddings from langchain.embeddings.base import Embeddings from langchain_community.utils.openai import is_openai_v1 from langchain_openai.embeddings import OpenAIEmbeddings +from bisheng.settings import settings +from bisheng_langchain.embeddings import CustomHostEmbedding, HostEmbeddings + def decide_embeddings(model: str) -> Embeddings: """embed method""" @@ -15,7 +18,10 @@ def decide_embeddings(model: str) -> Embeddings: if is_openai_v1() and params.get('openai_proxy'): params['http_client'] = httpx.Client(proxies=params.get('openai_proxy')) params['http_async_client'] = httpx.AsyncClient(proxies=params.get('openai_proxy')) - return OpenAIEmbeddings(**params) + if params.get('openai_api_type') in ("azure", "azure_ad", "azuread"): + return AzureOpenAIEmbeddings(**params) + else: + return OpenAIEmbeddings(**params) elif component == 'custom': return CustomHostEmbedding(**params) else: