Commit
Merge pull request chatchat-space#2435 from chatchat-space/reranker
New feature: use a Reranker model to re-rank retrieved passages
hzg0601 authored Dec 21, 2023
2 parents 60510ff + 129c765 commit d77f778
Showing 3 changed files with 151 additions and 19 deletions.
15 changes: 13 additions & 2 deletions configs/model_config.py.example
@@ -11,6 +11,11 @@ EMBEDDING_MODEL = "bge-large-zh"
# Device on which the Embedding model runs. "auto" detects it automatically; it can also be set manually to one of "cuda", "mps", or "cpu".
EMBEDDING_DEVICE = "auto"

# Reranker model to use
RERANKER_MODEL = "bge-reranker-large"
# Whether to enable the reranker model
USE_RERANKER = False
RERANKER_MAX_LENGTH = 1024
# Set this if you need to add custom keywords to EMBEDDING_MODEL
EMBEDDING_KEYWORD_FILE = "keywords.txt"
EMBEDDING_MODEL_OUTPUT_PATH = "output"
@@ -19,8 +24,9 @@ EMBEDDING_MODEL_OUTPUT_PATH = "output"
# The first model in the list is used as the default model for the API and WEBUI.
# Here we use two currently mainstream offline models; chatglm3-6b is the model loaded by default.
# If you are short on GPU memory, you can use Qwen-1_8B-Chat, which needs only 3.8 GB of VRAM in FP16.

# For chatglm3-6b emitting the role tag <|user|> and asking/answering its own questions, see the project wiki -> FAQ -> Q20.

LLM_MODELS = ["chatglm3-6b", "zhipu-api", "openai-api"] # "Qwen-1_8B-Chat",

# Name of the AgentLM model (optional; if set, it pins the model used by the Chain after entering the Agent; if unset, LLM_MODELS[0] is used)
@@ -236,6 +242,11 @@ MODEL_PATH = {

"Yi-34B-Chat": "https://huggingface.co/01-ai/Yi-34B-Chat",
},
"reranker":{
"bge-reranker-large":"BAAI/bge-reranker-large",
"bge-reranker-base":"BAAI/bge-reranker-base",
#TODO 增加在线reranker,如cohere
}
}


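For reference, a minimal sketch of how these settings are consumed, mirroring the lookup that server/chat/knowledge_base_chat.py performs later in this commit (the fallback hub id matches the default entry above):

# Sketch: resolve the configured reranker name to a loadable model path.
from configs import MODEL_PATH, RERANKER_MODEL, USE_RERANKER

if USE_RERANKER:
    reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,
                                                     "BAAI/bge-reranker-large")
    print(reranker_model_path)  # e.g. "BAAI/bge-reranker-large"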
30 changes: 27 additions & 3 deletions server/chat/knowledge_base_chat.py
@@ -1,7 +1,14 @@
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
                     VECTOR_SEARCH_TOP_K,
                     SCORE_THRESHOLD,
                     TEMPERATURE,
                     USE_RERANKER,
                     RERANKER_MODEL,
                     RERANKER_MAX_LENGTH,
                     MODEL_PATH)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
@@ -14,8 +21,8 @@
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device


async def knowledge_base_chat(query: str = Body(..., description="User input", examples=["你好"]),
                              knowledge_base_name: str = Body(..., description="Knowledge base name", examples=["samples"]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="Number of matched vectors"),
@@ -76,7 +83,24 @@ async def knowledge_base_chat_iterator(
                                          knowledge_base_name=knowledge_base_name,
                                          top_k=top_k,
                                          score_threshold=score_threshold)

        # Apply the reranker to re-rank the retrieved documents
        if USE_RERANKER:
            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
            print("-----------------model path------------------")
            print(reranker_model_path)
            reranker_model = LangchainReranker(top_n=top_k,
                                               device=embedding_device(),
                                               max_length=RERANKER_MAX_LENGTH,
                                               model_name_or_path=reranker_model_path)
            print(docs)
            docs = reranker_model.compress_documents(documents=docs, query=query)
            print("---------after rerank------------------")
            print(docs)
        context = "\n".join([doc.page_content for doc in docs])

        if len(docs) == 0:  # if no relevant documents were found, use the "empty" prompt template
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
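Taken in isolation, the new rerank step amounts to the following usage. This is a minimal sketch: the query and documents are invented, and it assumes the bge-reranker-large weights can be fetched from the Hugging Face hub.

# Usage sketch for LangchainReranker; the sample data is invented.
from langchain_core.documents import Document
from server.reranker.reranker import LangchainReranker

docs = [Document(page_content="The giant panda is a bear species endemic to China."),
        Document(page_content="Paris is the capital of France.")]
reranker = LangchainReranker(top_n=1,
                             device="cpu",
                             max_length=1024,
                             model_name_or_path="BAAI/bge-reranker-large")
reranked = reranker.compress_documents(documents=docs, query="What is a panda?")
for doc in reranked:
    print(doc.metadata["relevance_score"], doc.page_content)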
125 changes: 111 additions & 14 deletions server/reranker/reranker.py
@@ -1,19 +1,116 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from typing import Any, Optional, Sequence

from sentence_transformers import CrossEncoder
from langchain_core.documents import Document
from langchain.callbacks.manager import Callbacks
from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from llama_index.bridge.pydantic import Field, PrivateAttr


class LangchainReranker(BaseDocumentCompressor):
    """Document compressor that reranks retrieved documents with a local `CrossEncoder` model."""
    model_name_or_path: str = Field()
    _model: Any = PrivateAttr()
    top_n: int = Field()
    device: str = Field()
    max_length: int = Field()
    batch_size: int = Field()
    # show_progress_bar: bool = None
    num_workers: int = Field()
    # activation_fct = None
    # apply_softmax = False

    def __init__(self,
                 model_name_or_path: str,
                 top_n: int = 3,
                 device: str = "cuda",
                 max_length: int = 1024,
                 batch_size: int = 32,
                 # show_progress_bar: bool = None,
                 num_workers: int = 0,
                 # activation_fct=None,
                 # apply_softmax=False,
                 ):
        # Load the local CrossEncoder, then let the pydantic base class
        # set the declared fields from the same arguments.
        self._model = CrossEncoder(model_name=model_name_or_path,
                                   max_length=max_length,
                                   device=device)
        super().__init__(
            top_n=top_n,
            model_name_or_path=model_name_or_path,
            device=device,
            max_length=max_length,
            batch_size=batch_size,
            # show_progress_bar=show_progress_bar,
            num_workers=num_workers,
            # activation_fct=activation_fct,
            # apply_softmax=apply_softmax
        )

    def compress_documents(
            self,
            documents: Sequence[Document],
            query: str,
            callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Rerank documents against the query with the local CrossEncoder.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        if len(documents) == 0:  # avoid an empty model call
            return []
        doc_list = list(documents)
        _docs = [d.page_content for d in doc_list]
        sentence_pairs = [[query, _doc] for _doc in _docs]
        results = self._model.predict(sentences=sentence_pairs,
                                      batch_size=self.batch_size,
                                      # show_progress_bar=self.show_progress_bar,
                                      num_workers=self.num_workers,
                                      # activation_fct=self.activation_fct,
                                      # apply_softmax=self.apply_softmax,
                                      convert_to_tensor=True)
        top_k = self.top_n if self.top_n < len(results) else len(results)

        values, indices = results.topk(top_k)
        final_results = []
        for value, index in zip(values, indices):
            doc = doc_list[index]
            doc.metadata["relevance_score"] = value
            final_results.append(doc)
        return final_results

if __name__ == "__main__":
from configs import (LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH)
from server.utils import embedding_device
if USE_RERANKER:
reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
print("-----------------model path------------------")
print(reranker_model_path)
reranker_model = LangchainReranker(top_n=3,
device=embedding_device(),
max_length=RERANKER_MAX_LENGTH,
model_name_or_path=reranker_model_path
)
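The scoring that compress_documents builds on can also be exercised with sentence_transformers directly. A minimal sketch; the query and passages are invented, and it assumes the model weights are locally available or downloadable:

# Sketch: raw CrossEncoder scoring, the core of compress_documents above.
from sentence_transformers import CrossEncoder

model = CrossEncoder("BAAI/bge-reranker-large", max_length=1024, device="cpu")
pairs = [["What is a panda?", "The giant panda is a bear species endemic to China."],
         ["What is a panda?", "Paris is the capital of France."]]
print(model.predict(pairs))  # higher score = more relevant pair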
