feat(model): Support qianfan embedding
dusens committed Aug 13, 2024
1 parent 8661d19 commit ed86e04
Showing 7 changed files with 124 additions and 1 deletion.
7 changes: 7 additions & 0 deletions .env.template
@@ -98,6 +98,13 @@ KNOWLEDGE_SEARCH_REWRITE=False
# proxy_tongyi_proxy_backend=text-embedding-v1
# proxy_tongyi_proxy_api_key={your-api-key}

## Qianfan embedding model, see dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_qianfan
# proxy_qianfan_proxy_backend=bge-large-zh
# proxy_qianfan_proxy_api_key={your-api-key}
# proxy_qianfan_proxy_api_secret={your-secret-key}


## Common HTTP embedding model
# EMBEDDING_MODEL=proxy_http_openapi
# proxy_http_openapi_proxy_server_url=http://localhost:8100/api/v1/embeddings
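With the new block uncommented in `.env`, the Qianfan backend is selected. A minimal sketch of the enabled configuration, with the placeholders standing in for real Baidu Qianfan credentials:

```
EMBEDDING_MODEL=proxy_qianfan
proxy_qianfan_proxy_backend=bge-large-zh
proxy_qianfan_proxy_api_key={your-api-key}
proxy_qianfan_proxy_api_secret={your-secret-key}
```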
1 change: 1 addition & 0 deletions dbgpt/configs/model_config.py
@@ -289,6 +289,7 @@ def get_device() -> str:
"proxy_http_openapi": "proxy_http_openapi",
"proxy_ollama": "proxy_ollama",
"proxy_tongyi": "proxy_tongyi",
"proxy_qianfan": "proxy_qianfan",
# Rerank model, rerank mode is a special embedding model
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
8 changes: 8 additions & 0 deletions dbgpt/model/adapter/embeddings_loader.py
@@ -58,6 +58,14 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddin
            if proxy_param.proxy_backend:
                tongyi_param["model_name"] = proxy_param.proxy_backend
            return TongYiEmbeddings(**tongyi_param)
        elif model_name in ["proxy_qianfan"]:
            from dbgpt.rag.embedding import QianFanEmbeddings
            proxy_param = cast(ProxyEmbeddingParameters, param)
            qianfan_param = {"api_key": proxy_param.proxy_api_key}
            if proxy_param.proxy_backend:
                qianfan_param["model_name"] = proxy_param.proxy_backend
            qianfan_param["api_secret"] = proxy_param.proxy_api_secret
            return QianFanEmbeddings(**qianfan_param)
        elif model_name in ["proxy_ollama"]:
            from dbgpt.rag.embedding import OllamaEmbeddings

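The branch above translates the proxy embedding parameters into `QianFanEmbeddings` keyword arguments. A minimal sketch of that mapping, assuming the field names shown in this diff and using placeholder credentials (the `qianfan` package must be installed):

```python
from dbgpt.rag.embedding import QianFanEmbeddings

# Values normally come from ProxyEmbeddingParameters (proxy_api_key,
# proxy_api_secret, proxy_backend); placeholders are used here.
qianfan_param = {
    "api_key": "your-qianfan-api-key",        # proxy_qianfan_proxy_api_key
    "api_secret": "your-qianfan-secret-key",  # proxy_qianfan_proxy_api_secret
    "model_name": "bge-large-zh",             # proxy_qianfan_proxy_backend (optional)
}
embeddings = QianFanEmbeddings(**qianfan_param)
```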
9 changes: 8 additions & 1 deletion dbgpt/model/parameter.py
@@ -558,6 +558,13 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
},
)
    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={
            "tags": "privacy",
            "help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
        },
    )
    proxy_api_version: Optional[str] = field(
        default=None,
        metadata={
@@ -603,7 +610,7 @@ def is_rerank_model(self) -> bool:


_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
2 changes: 2 additions & 0 deletions dbgpt/rag/embedding/__init__.py
@@ -15,6 +15,7 @@
    OllamaEmbeddings,
    OpenAPIEmbeddings,
    QianFanEmbeddings,
    TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401

@@ -33,4 +34,5 @@
"TongYiEmbeddings",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
"QianFanEmbeddings"
]
97 changes: 97 additions & 0 deletions dbgpt/rag/embedding/embeddings.py
@@ -922,3 +922,100 @@ def embed_query(self, text: str) -> List[float]:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]


class QianFanEmbeddings(BaseModel, Embeddings):
    """Baidu Qianfan Embeddings embedding models.

    Embed:
        .. code-block:: python

            # embed the documents
            vectors = embeddings.embed_documents([text1, text2, ...])

            # embed the query
            vectors = embeddings.embed_query(text)
    """

    client: Any
    chunk_size: int = 16
    endpoint: str = ""
    """Endpoint of the Qianfan Embedding, required if a custom model is used."""

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    api_secret: Optional[str] = Field(
        default=None, description="The Secret key for the embeddings API."
    )
    model_name: str = Field(
        default="Embedding-V1", description="The name of the model to use."
    )
    """Model name. See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu
    for the full list. Currently supported models are:
    - Embedding-V1 (default model)
    - bge-large-en
    - bge-large-zh
    Preset models are mapped to an endpoint; `model` is ignored if `endpoint` is set.
    """
    init_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Init kwargs for the qianfan client, such as `query_per_second`, which is
    associated with the qianfan resource object to limit QPS."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra params for model invocation via `do`."""

    def __init__(self, **kwargs):
        """Initialize the QianFanEmbeddings."""
        try:
            import qianfan
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: qianfan. "
                "Please install qianfan by running `pip install qianfan`."
            ) from exc

        qianfan_ak = kwargs.get("api_key")
        qianfan_sk = kwargs.get("api_secret")
        model_name = kwargs.get("model_name")

        if not qianfan_ak or not qianfan_sk or not model_name:
            raise ValueError(
                "API key, API secret, and model name are required to initialize "
                "QianFanEmbeddings."
            )

        params = {
            "model": model_name,
            "ak": qianfan_ak,
            "sk": qianfan_sk,
        }

        # Initialize the qianfan.Embedding client
        kwargs["client"] = qianfan.Embedding(**params)
        super().__init__(**kwargs)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        resp = self.embed_documents([text])
        return resp[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of text documents using the Qianfan Embedding API.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings for each document in the input
                list. Each embedding is represented as a list of float values.
        """
        # The Qianfan API embeds at most `chunk_size` texts per request.
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            resp = self.client.do(texts=chunk, **self.model_kwargs)
            lst.extend([res["embedding"] for res in resp["data"]])
        return lst
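A short usage sketch for the new class, assuming the `qianfan` package is installed (`pip install qianfan`) and real Baidu Qianfan credentials are supplied in place of the placeholder values:

```python
from dbgpt.rag.embedding import QianFanEmbeddings

embeddings = QianFanEmbeddings(
    api_key="your-qianfan-api-key",        # placeholder
    api_secret="your-qianfan-secret-key",  # placeholder
    model_name="bge-large-zh",
)

# Documents are embedded in batches of `chunk_size` (16) texts per API call.
doc_vectors = embeddings.embed_documents(
    ["DB-GPT supports Qianfan embeddings.", "你好"]
)
query_vector = embeddings.embed_query("Which embedding backends are supported?")
print(len(doc_vectors), len(query_vector))
```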
1 change: 1 addition & 0 deletions setup.py
@@ -685,6 +685,7 @@ def default_requires():
"chardet",
"sentencepiece",
"ollama",
"qianfan"
]
setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["rag"]
