From 702d7b4f27b680c77fc5a1a9e7498361c864540d Mon Sep 17 00:00:00 2001 From: dusens <35623865+dusens@users.noreply.github.com> Date: Fri, 9 Aug 2024 16:19:17 +0800 Subject: [PATCH 1/2] Update .env.template Missing proxy_tongyi_proxy_api_key={your-api-key} --- .env.template | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.env.template b/.env.template index f90af90ee..8b0ae5394 100644 --- a/.env.template +++ b/.env.template @@ -96,6 +96,7 @@ KNOWLEDGE_SEARCH_REWRITE=False ## qwen embedding model, See dbgpt/model/parameter.py # EMBEDDING_MODEL=proxy_tongyi # proxy_tongyi_proxy_backend=text-embedding-v1 +# proxy_tongyi_proxy_api_key={your-api-key} ## Common HTTP embedding model # EMBEDDING_MODEL=proxy_http_openapi @@ -297,4 +298,4 @@ DBGPT_LOG_LEVEL=INFO #*******************************************************************# #** FINANCIAL CHAT Config **# #*******************************************************************# -# FIN_REPORT_MODEL=/app/models/bge-large-zh \ No newline at end of file +# FIN_REPORT_MODEL=/app/models/bge-large-zh From ed86e04889072ff0545e4c7a6b9e03650263eba1 Mon Sep 17 00:00:00 2001 From: dusensen <1025130869@qq.com> Date: Tue, 13 Aug 2024 15:19:25 +0800 Subject: [PATCH 2/2] feat(model): Support qianfan embedding --- .env.template | 7 ++ dbgpt/configs/model_config.py | 1 + dbgpt/model/adapter/embeddings_loader.py | 8 ++ dbgpt/model/parameter.py | 9 ++- dbgpt/rag/embedding/__init__.py | 2 + dbgpt/rag/embedding/embeddings.py | 97 ++++++++++++++++++++++++ setup.py | 1 + 7 files changed, 124 insertions(+), 1 deletion(-) diff --git a/.env.template b/.env.template index 8b0ae5394..c20cb012c 100644 --- a/.env.template +++ b/.env.template @@ -98,6 +98,13 @@ KNOWLEDGE_SEARCH_REWRITE=False # proxy_tongyi_proxy_backend=text-embedding-v1 # proxy_tongyi_proxy_api_key={your-api-key} +## qianfan embedding model, See dbgpt/model/parameter.py +#EMBEDDING_MODEL=proxy_qianfan +#proxy_qianfan_proxy_backend=bge-large-zh 
+#proxy_qianfan_proxy_api_key={your-api-key}
+#proxy_qianfan_proxy_api_secret={your-secret-key}
+
+
 ## Common HTTP embedding model
 # EMBEDDING_MODEL=proxy_http_openapi
 # proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index 4d02a2730..e4a05f7ae 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -289,6 +289,7 @@ def get_device() -> str:
     "proxy_http_openapi": "proxy_http_openapi",
     "proxy_ollama": "proxy_ollama",
     "proxy_tongyi": "proxy_tongyi",
+    "proxy_qianfan": "proxy_qianfan",
     # Rerank model, rerank mode is a special embedding model
     "bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
     "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
diff --git a/dbgpt/model/adapter/embeddings_loader.py b/dbgpt/model/adapter/embeddings_loader.py
index b10cf34af..cbc504fdf 100644
--- a/dbgpt/model/adapter/embeddings_loader.py
+++ b/dbgpt/model/adapter/embeddings_loader.py
@@ -58,6 +58,14 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddin
             if proxy_param.proxy_backend:
                 tongyi_param["model_name"] = proxy_param.proxy_backend
             return TongYiEmbeddings(**tongyi_param)
+        elif model_name in ["proxy_qianfan"]:
+            from dbgpt.rag.embedding import QianFanEmbeddings
+            proxy_param = cast(ProxyEmbeddingParameters, param)
+            qianfan_param = {"api_key": proxy_param.proxy_api_key}
+            if proxy_param.proxy_backend:
+                qianfan_param["model_name"] = proxy_param.proxy_backend
+            qianfan_param["api_secret"] = proxy_param.proxy_api_secret
+            return QianFanEmbeddings(**qianfan_param)
         elif model_name in ["proxy_ollama"]:
             from dbgpt.rag.embedding import OllamaEmbeddings
 
diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py
index f0a2974c8..8debe18fa 100644
--- a/dbgpt/model/parameter.py
+++ b/dbgpt/model/parameter.py
@@ -558,6 +558,13 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
             "help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
         },
     )
+    proxy_api_secret: Optional[str] = field(
+        default=None,
+        metadata={
+            "tags": "privacy",
+            "help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
+        },
+    )
     proxy_api_version: Optional[str] = field(
         default=None,
         metadata={
@@ -603,7 +610,7 @@ def is_rerank_model(self) -> bool:
 
 
 _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
 }
 
 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py
index fcd4590f9..c14987de8 100644
--- a/dbgpt/rag/embedding/__init__.py
+++ b/dbgpt/rag/embedding/__init__.py
@@ -15,6 +15,7 @@
     OllamaEmbeddings,
     OpenAPIEmbeddings,
+    QianFanEmbeddings,
     TongYiEmbeddings,
 )
 from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings  # noqa: F401
 
@@ -33,4 +34,5 @@
     "TongYiEmbeddings",
     "CrossEncoderRerankEmbeddings",
     "OpenAPIRerankEmbeddings",
+    "QianFanEmbeddings",
 ]
diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py
index 7d14c0fb5..773fb9aa0 100644
--- a/dbgpt/rag/embedding/embeddings.py
+++ b/dbgpt/rag/embedding/embeddings.py
@@ -922,3 +922,102 @@ def embed_query(self, text: str) -> List[float]:
             Embeddings for the text.
         """
         return self.embed_documents([text])[0]
+
+
+class QianFanEmbeddings(BaseModel, Embeddings):
+    """Baidu Qianfan Embeddings embedding models.
+
+    Embed:
+        .. code-block:: python
+
+            # embed the documents
+            vectors = embeddings.embed_documents([text1, text2, ...])
+
+            # embed the query
+            vectors = embeddings.embed_query(text)
+
+    """  # noqa: E501
+
+    client: Any
+    chunk_size: int = 16
+    endpoint: str = ""
+    """Endpoint of the Qianfan Embedding, required if custom model used."""
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+    api_key: Optional[str] = Field(
+        default=None, description="The API key for the embeddings API."
+    )
+    api_secret: Optional[str] = Field(
+        default=None, description="The Secret key for the embeddings API."
+    )
+    """Model name
+    you could get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
+
+    for now, we support Embedding-V1 and
+    - Embedding-V1 (default model)
+    - bge-large-en
+    - bge-large-zh
+
+    preset models are mapping to an endpoint.
+    `model` will be ignored if `endpoint` is set
+    """
+    model_name: str = Field(
+        default="Embedding-V1", description="The name of the model to use."
+    )
+    init_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """init kwargs for qianfan client init, such as `query_per_second` which is
+    associated with qianfan resource object to limit QPS"""
+
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """extra params for model invoke using with `do`."""
+
+    def __init__(self, **kwargs):
+        """Initialize the QianFanEmbeddings."""
+        try:
+            import qianfan
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import python package: qianfan. "
+                "Please install qianfan by running `pip install qianfan`."
+            ) from exc
+
+        qianfan_ak = kwargs.get("api_key")
+        qianfan_sk = kwargs.get("api_secret")
+        model_name = kwargs.get("model_name")
+
+        if not qianfan_ak or not qianfan_sk or not model_name:
+            raise ValueError("API key, API secret, and model name are required to initialize QianFanEmbeddings.")
+
+        params = {
+            "model": model_name,
+            "ak": qianfan_ak,
+            "sk": qianfan_sk,
+        }
+
+        # Initialize the qianfan.Embedding client
+        kwargs["client"] = qianfan.Embedding(**params)
+        super().__init__(**kwargs)
+
+    def embed_query(self, text: str) -> List[float]:
+        resp = self.embed_documents([text])
+        return resp[0]
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """
+        Embed a list of text documents with the Qianfan embedding API.
+
+        Args:
+            texts (List[str]): A list of text documents to embed.
+
+        Returns:
+            List[List[float]]: A list of embeddings for each document in the input list.
+            Each embedding is represented as a list of float values.
+        """
+        text_in_chunks = [
+            texts[i: i + self.chunk_size]
+            for i in range(0, len(texts), self.chunk_size)
+        ]
+        lst = []
+        for chunk in text_in_chunks:
+            resp = self.client.do(texts=chunk, **self.model_kwargs)
+            lst.extend([res["embedding"] for res in resp["data"]])
+        return lst
diff --git a/setup.py b/setup.py
index cbe5592ce..4b3ede1b9 100644
--- a/setup.py
+++ b/setup.py
@@ -685,6 +685,7 @@ def default_requires():
         "chardet",
         "sentencepiece",
         "ollama",
+        "qianfan",
     ]
     setup_spec.extras["default"] += setup_spec.extras["framework"]
     setup_spec.extras["default"] += setup_spec.extras["rag"]