Update .env.template #1805

Closed · wants to merge 3 commits
10 changes: 9 additions & 1 deletion .env.template
@@ -96,6 +96,14 @@ KNOWLEDGE_SEARCH_REWRITE=False
## qwen embedding model, See dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_tongyi
# proxy_tongyi_proxy_backend=text-embedding-v1
# proxy_tongyi_proxy_api_key={your-api-key}

## qianfan embedding model, See dbgpt/model/parameter.py
# EMBEDDING_MODEL=proxy_qianfan
# proxy_qianfan_proxy_backend=bge-large-zh
# proxy_qianfan_proxy_api_key={your-api-key}
# proxy_qianfan_proxy_api_secret={your-secret-key}


## Common HTTP embedding model
# EMBEDDING_MODEL=proxy_http_openapi
@@ -297,4 +305,4 @@ DBGPT_LOG_LEVEL=INFO
#*******************************************************************#
#** FINANCIAL CHAT Config **#
#*******************************************************************#
# FIN_REPORT_MODEL=/app/models/bge-large-zh
# FIN_REPORT_MODEL=/app/models/bge-large-zh
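For reference, a deployment that wants to use the new Qianfan embedding proxy would uncomment these keys in its local .env; a minimal sketch, with placeholder credentials rather than real values:

EMBEDDING_MODEL=proxy_qianfan
proxy_qianfan_proxy_backend=bge-large-zh
proxy_qianfan_proxy_api_key={your-api-key}
proxy_qianfan_proxy_api_secret={your-secret-key}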
1 change: 1 addition & 0 deletions dbgpt/configs/model_config.py
@@ -289,6 +289,7 @@ def get_device() -> str:
"proxy_http_openapi": "proxy_http_openapi",
"proxy_ollama": "proxy_ollama",
"proxy_tongyi": "proxy_tongyi",
"proxy_qianfan": "proxy_qianfan",
# Rerank model, rerank mode is a special embedding model
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
8 changes: 8 additions & 0 deletions dbgpt/model/adapter/embeddings_loader.py
@@ -58,6 +58,14 @@ def load(self, model_name: str, param: BaseEmbeddingModelParameters) -> Embeddings:
            if proxy_param.proxy_backend:
                tongyi_param["model_name"] = proxy_param.proxy_backend
            return TongYiEmbeddings(**tongyi_param)
        elif model_name in ["proxy_qianfan"]:
            from dbgpt.rag.embedding import QianFanEmbeddings

            proxy_param = cast(ProxyEmbeddingParameters, param)
            qianfan_param = {"api_key": proxy_param.proxy_api_key}
            if proxy_param.proxy_backend:
                qianfan_param["model_name"] = proxy_param.proxy_backend
            qianfan_param["api_secret"] = proxy_param.proxy_api_secret
            return QianFanEmbeddings(**qianfan_param)
        elif model_name in ["proxy_ollama"]:
            from dbgpt.rag.embedding import OllamaEmbeddings

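For context, the new proxy_qianfan branch maps the proxy embedding parameters onto the QianFanEmbeddings constructor added in dbgpt/rag/embedding/embeddings.py below. A minimal sketch of the equivalent direct call, assuming the qianfan package is installed; the key and secret values are placeholders, and bge-large-zh stands in for whatever proxy_backend is configured:

from dbgpt.rag.embedding import QianFanEmbeddings

# api_key    <- proxy_api_key
# api_secret <- proxy_api_secret
# model_name <- proxy_backend (only set when a backend is configured)
embeddings = QianFanEmbeddings(
    api_key="{your-api-key}",
    api_secret="{your-secret-key}",
    model_name="bge-large-zh",
)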
9 changes: 8 additions & 1 deletion dbgpt/model/parameter.py
@@ -558,6 +558,13 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
},
)
    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={
            "tags": "privacy",
            "help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
        },
    )
    proxy_api_version: Optional[str] = field(
        default=None,
        metadata={
@@ -603,7 +610,7 @@ def is_rerank_model(self) -> bool:


_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,rerank_proxy_http_openapi",
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
2 changes: 2 additions & 0 deletions dbgpt/rag/embedding/__init__.py
@@ -15,6 +15,7 @@
    OllamaEmbeddings,
    OpenAPIEmbeddings,
    QianFanEmbeddings,
    TongYiEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401

@@ -33,4 +34,5 @@
"TongYiEmbeddings",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
"QianFanEmbeddings"
]
97 changes: 97 additions & 0 deletions dbgpt/rag/embedding/embeddings.py
@@ -922,3 +922,100 @@ def embed_query(self, text: str) -> List[float]:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]


class QianFanEmbeddings(BaseModel, Embeddings):
    """Baidu Qianfan embedding models.

    Embed:
        .. code-block:: python

            # embed the documents
            vectors = embeddings.embed_documents([text1, text2, ...])

            # embed the query
            vectors = embeddings.embed_query(text)

    """  # noqa: E501

    client: Any
    chunk_size: int = 16
    endpoint: str = ""
    """Endpoint of the Qianfan Embedding, required if a custom model is used."""

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    api_secret: Optional[str] = Field(
        default=None, description="The secret key for the embeddings API."
    )
    model_name: str = Field(
        default="Embedding-V1", description="The name of the model to use."
    )
    """Model name
    you can get from https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    For now, we support:
    - Embedding-V1 (default model)
    - bge-large-en
    - bge-large-zh

    Preset models are mapped to an endpoint.
    `model` will be ignored if `endpoint` is set.
    """
    init_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """init kwargs for qianfan client init, such as `query_per_second` which is
    associated with qianfan resource object to limit QPS"""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """extra params for model invoke using with `do`."""

    def __init__(self, **kwargs):
        """Initialize the QianFanEmbeddings."""
        try:
            import qianfan
        except ImportError as exc:
            raise ValueError(
                "Could not import python package: qianfan. "
                "Please install qianfan by running `pip install qianfan`."
            ) from exc

        qianfan_ak = kwargs.get("api_key")
        qianfan_sk = kwargs.get("api_secret")
        model_name = kwargs.get("model_name")

        if not qianfan_ak or not qianfan_sk or not model_name:
            raise ValueError(
                "API key, API secret, and model name are required to initialize "
                "QianFanEmbeddings."
            )

        params = {
            "model": model_name,
            "ak": qianfan_ak,
            "sk": qianfan_sk,
        }

        # Initialize the qianfan.Embedding client
        kwargs["client"] = qianfan.Embedding(**params)
        super().__init__(**kwargs)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query text."""
        resp = self.embed_documents([text])
        return resp[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of text documents with the Qianfan embedding API.

        Args:
            texts (List[str]): A list of text documents to embed.

        Returns:
            List[List[float]]: A list of embeddings for each document in the input
                list. Each embedding is represented as a list of float values.
        """
        text_in_chunks = [
            texts[i : i + self.chunk_size]
            for i in range(0, len(texts), self.chunk_size)
        ]
        lst = []
        for chunk in text_in_chunks:
            # Each request embeds at most `chunk_size` texts.
            resp = self.client.do(texts=chunk, **self.model_kwargs)
            lst.extend([res["embedding"] for res in resp["data"]])
        return lst
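A minimal usage sketch for the new class, assuming the qianfan package is installed and that the key and secret below are replaced with real Baidu Qianfan credentials:

from dbgpt.rag.embedding import QianFanEmbeddings

embeddings = QianFanEmbeddings(
    api_key="{your-api-key}",        # placeholder Qianfan AK
    api_secret="{your-secret-key}",  # placeholder Qianfan SK
    model_name="Embedding-V1",
)

# Documents are embedded in batches of `chunk_size` (16) texts per API call.
doc_vectors = embeddings.embed_documents(["DB-GPT supports Qianfan embeddings.", "hello world"])
# A query is embedded as a single-document batch; the first vector is returned.
query_vector = embeddings.embed_query("hello world")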
1 change: 1 addition & 0 deletions setup.py
@@ -685,6 +685,7 @@ def default_requires():
"chardet",
"sentencepiece",
"ollama",
"qianfan"
]
setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["rag"]