Skip to content

Commit

Permalink
Merge branch 'feat/0.3.1.4' of https://github.com/dataelement/bisheng
Browse files Browse the repository at this point in the history
…into feat/0.3.1.4
  • Loading branch information
QwQ-wuwuwu committed Jun 26, 2024
2 parents 990a92e + c6d56c0 commit cfbad0d
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 281 deletions.
48 changes: 47 additions & 1 deletion src/backend/bisheng/chat/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
from bisheng.chat.utils import judge_source, process_graph, process_source_document
from bisheng.database.base import session_getter
from bisheng.database.models.report import Report
from bisheng.interface.importing.utils import import_by_type
from bisheng.interface.initialize.loading import instantiate_llm
from bisheng.settings import settings
from bisheng.utils.docx_temp import test_replace_string
from bisheng.utils.logger import logger
from bisheng.utils.minio_client import MinioClient
from bisheng.utils.threadpool import thread_pool
from bisheng.utils.util import get_cache_key
from bisheng_langchain.chains.autogen.auto_gen import AutoGenChain
from langchain.chains.llm import LLMChain
from langchain_core.prompts.prompt import PromptTemplate
from sqlmodel import select


Expand Down Expand Up @@ -100,6 +105,30 @@ async def process_report(self,
close_resp = ChatResponse(type='close', category='system', user_id=user_id)
await session.send_json(client_id, chat_id, close_resp)

def recommend_question(self, langchain_obj, chat_history: list):
    """Ask an LLM for follow-up questions the user is likely to ask next.

    The LLM attached to ``langchain_obj`` is used when the object exposes an
    ``llm`` attribute; otherwise the default LLM described by
    ``settings.get_default_llm()`` is instantiated. When neither is
    available, no request is made.

    Args:
        langchain_obj: The built langchain object of the current flow.
        chat_history: Previous chat messages; rendered into the prompt's
            ``{history}`` placeholder.

    Returns:
        A list of suggested questions, one per non-empty line of the model
        output, or an empty list when no LLM could be resolved.
    """
    prompt = """给定以下历史聊天消息:
{history}
总结提炼用户可能接下来会提问的3个问题,请直接输出问题,使用换行符分割问题,不要添加任何修饰文字或前后缀。
"""
    # Bind up front so the "no LLM available" path below is reachable
    # instead of raising UnboundLocalError.
    llm_chain = None
    if hasattr(langchain_obj, 'llm'):
        llm_chain = LLMChain(llm=langchain_obj.llm,
                             prompt=PromptTemplate.from_template(prompt))
    else:
        # Work on a copy: ``pop`` below must not strip 'type' from the
        # configuration object returned by settings (it may be shared
        # between calls -- NOTE(review): confirm against get_default_llm).
        keyword_conf = dict(settings.get_default_llm() or {})
        if keyword_conf:
            # Default node type keeps older configs without 'type' working.
            node_type = keyword_conf.pop('type', 'HostQwenChat')
            class_object = import_by_type(_type='llms', name=node_type)
            llm = instantiate_llm(node_type, class_object, keyword_conf)
            llm_chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template(prompt))

    if llm_chain is None:
        logger.info('llm_chain is None recommend_over')
        return []

    questions = llm_chain.predict(history=chat_history)
    # Drop blank lines so callers never forward an empty suggestion.
    return [line.strip() for line in questions.split('\n') if line.strip()]

async def process_message(self,
session: ChatManager,
client_id: str,
Expand Down Expand Up @@ -142,6 +171,12 @@ async def process_message(self,
flow_id=client_id,
chat_id=chat_id,
)

questions = []
if is_begin and langchain_object.memory and langchain_object.memory.buffer:
questions = self.recommend_question(langchain_object,
langchain_object.memory.buffer)

except Exception as e:
# Log stack trace
logger.exception(e)
Expand Down Expand Up @@ -188,8 +223,19 @@ async def process_message(self,
user_id=user_id,
source=int(source))
await session.send_json(client_id, chat_id, response)
if questions:
# 提示问题
for question in questions:
question_resp = ChatResponse(type='start',
message='',
category='autoQuestion',
user_id=user_id)
await session.send_json(client_id, chat_id, question_resp, add=False)
question_resp.type = 'end'
question_resp.message = question
await session.send_json(client_id, chat_id, question_resp, add=False)

# 循环结束
# 循环结束
if is_begin:
close_resp = ChatResponse(type='close', user_id=user_id)
await session.send_json(client_id, chat_id, close_resp)
Expand Down
19 changes: 12 additions & 7 deletions src/backend/bisheng/interface/chains/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from typing import Any, ClassVar, Dict, List, Optional, Type

from langchain import chains
from langchain_experimental import sql

from bisheng_langchain import chains as bisheng_chains
from bisheng_langchain.rag.bisheng_rag_chain import BishengRetrievalQA

from bisheng.custom.customs import get_custom_nodes
from bisheng.interface.base import LangChainTypeCreator
from bisheng.interface.importing.utils import import_class
from bisheng.settings import settings
from bisheng.template.frontend_node.chains import ChainFrontendNode
from bisheng.utils.logger import logger
from bisheng.utils.util import build_template_from_class, build_template_from_method

from bisheng_langchain import chains as bisheng_chains
from bisheng_langchain import sql as bisheng_sql
from bisheng_langchain.rag.bisheng_rag_chain import BishengRetrievalQA
from langchain import chains
from langchain_experimental import sql

# Assuming necessary imports for Field, Template, and FrontendNode classes

Expand Down Expand Up @@ -61,6 +59,13 @@ def type_to_loader_dict(self) -> Dict:
}
self.type_dict.update(community)

# sql community
bisheng_sql_add = {
chain_name: import_class(f'bisheng_langchain.sql.{chain_name}')
for chain_name in bisheng_sql.__all__
}
self.type_dict.update(bisheng_sql_add)

from bisheng.interface.chains.custom import CUSTOM_CHAINS

self.type_dict.update(CUSTOM_CHAINS)
Expand Down
10 changes: 6 additions & 4 deletions src/backend/bisheng/utils/threadpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,15 @@ async def as_completed(self,
# self.async_task_result.append(future)

def cancel_task(self, key_list: List[str]):
res = [False] * len(key_list)
with self.lock:
res = []
for key in key_list:
for index, key in enumerate(key_list):
if self.async_task.get(key):
logger.info('clean_pending_task key={}', key)
for task in self.async_task.get(key):
res.append(task.result().cancel())
cancel_res = task.result().cancel()
logger.info('clean_pending_task key={} task={} res={}', key, task,
cancel_res)
res[index] = cancel_res
if self.future_dict.get(key):
for task in self.future_dict.get(key):
res.append(task.cancel())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# import base64
import copy
import base64
from typing import Optional

import requests
from typing import Any, Iterator, List, Mapping, Optional, Union


class ELLMClient(object):
def __init__(self,
api_base_url: Optional[str] = None):

def __init__(self, api_base_url: Optional[str] = None):
self.ep = api_base_url
self.client = requests.Session()
self.timeout = 10000
Expand All @@ -26,8 +26,8 @@ def __init__(self,
'ellm': 'ELLM'
},
'form': {
'det': 'mrcnn-v5.1',
'recog': 'transformer-v2.8-gamma-faster',
'det': 'general_text_det_v2.0',
'recog': 'general_text_reg_nb_v1.0_faster',
'ellm': 'ELLM'
},
'hand': {
Expand All @@ -48,9 +48,7 @@ def predict(self, inp):
req_data = {'data': [b64_image], 'param': params}

try:
r = self.client.post(url=self.ep,
json=req_data,
timeout=self.timeout)
r = self.client.post(url=self.ep, json=req_data, timeout=self.timeout)
return r.json()
except Exception as e:
return {'status_code': 400, 'status_message': str(e)}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def transpdf2png(pdf_file):
class UniversalKVLoader(BaseLoader):
"""Extract key-value from pdf or image.
"""

def __init__(self,
file_path: str,
ellm_model_url: str = None,
Expand Down Expand Up @@ -83,7 +84,7 @@ def load(self) -> List[Document]:

kv_results = defaultdict(list)
for key, value in key_values.items():
kv_results[key] = value['text']
kv_results[key].extend([v['text'] for v in value])

content = json.dumps(kv_results, indent=2, ensure_ascii=False)
file_name = os.path.basename(self.file_path)
Expand All @@ -95,7 +96,7 @@ def load(self) -> List[Document]:
pdf_images = transpdf2png(self.file_path)

kv_results = defaultdict(list)
for pdf_name in pdf_images:
for index, pdf_name in enumerate(pdf_images):
page = int(pdf_name.split('page_')[-1])
if page > self.max_pages:
continue
Expand All @@ -110,7 +111,7 @@ def load(self) -> List[Document]:
raise ValueError(f'universal kv load failed: {resp}')

for key, value in key_values.items():
kv_results[key].extend(value['text'])
kv_results[key].extend([v['text'] for v in value])

content = json.dumps(kv_results, indent=2, ensure_ascii=False)
file_name = os.path.basename(self.file_path)
Expand Down
3 changes: 3 additions & 0 deletions src/bisheng-langchain/bisheng_langchain/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from bisheng_langchain.sql.base import SQLDatabaseChain

__all__ = ['SQLDatabaseChain']
Loading

0 comments on commit cfbad0d

Please sign in to comment.