Skip to content

Commit

Permalink
refactor: Workflow execution logic (#1913)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaohuzhang1 authored Dec 26, 2024
1 parent bb58ac6 commit efa73c8
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 70 deletions.
17 changes: 17 additions & 0 deletions apps/application/flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,20 @@ def __init__(self, content, view_type, runtime_node_id, chat_record_id, child_no
def to_dict(self):
    """Serialize this answer fragment into a plain dict for the response payload."""
    return {
        'view_type': self.view_type,
        'content': self.content,
        'runtime_node_id': self.runtime_node_id,
        'chat_record_id': self.chat_record_id,
        'child_node': self.child_node,
    }


class NodeChunk:
    """Buffer for the streamed output chunks emitted by one workflow node.

    ``status`` is 0 while the node is still producing output and 200 once
    ``end()`` has been called, marking the stream as finished.
    """

    def __init__(self):
        self.status = 0        # 0 = streaming, 200 = finished
        self.chunk_list = []   # chunks in arrival order

    def add_chunk(self, chunk):
        """Append one streamed chunk to the buffer."""
        self.chunk_list.append(chunk)

    def end(self, chunk=None):
        """Mark the stream finished, optionally recording a final chunk first."""
        if chunk is not None:
            self.chunk_list.append(chunk)
        self.status = 200

    def is_end(self):
        """Return True once the owning node has finished streaming."""
        return self.status == 200
4 changes: 3 additions & 1 deletion apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail

from application.flow.common import Answer
from application.flow.common import Answer, NodeChunk
from application.models import ChatRecord
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
Expand Down Expand Up @@ -175,6 +175,7 @@ def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None,
if up_node_id_list is None:
up_node_id_list = []
self.up_node_id_list = up_node_id_list
self.node_chunk = NodeChunk()
self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
"".join([*sorted(up_node_id_list),
node.id]))),
Expand Down Expand Up @@ -214,6 +215,7 @@ def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]:

def get_write_error_context(self, e):
    """Record a node failure: mark this node errored and stamp its elapsed run time.

    Stores the stringified exception as both the answer text and the error
    message, and writes 'run_time' into the node context based on the
    previously recorded 'start_time'.
    """
    message = str(e)
    self.status = 500
    self.answer_text = message
    self.err_message = message
    self.context['run_time'] = time.time() - self.context['start_time']

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import List, Dict

from django.db.models import QuerySet

from django.db import connection
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import VectorStore
Expand Down Expand Up @@ -77,6 +77,8 @@ def execute(self, dataset_id_list, dataset_setting, question,
embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
# 手动关闭数据库连接
connection.close()
if embedding_list is None:
return get_none_result(question)
paragraph_list = self.list_paragraph(embedding_list, vector)
Expand Down
120 changes: 56 additions & 64 deletions apps/application/flow/workflow_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
@date:2024/1/9 17:40
@desc:
"""
import concurrent
import json
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from typing import List, Dict

from django.db import close_old_connections
from django.db.models import QuerySet
from langchain_core.prompts import PromptTemplate
from rest_framework import status
Expand Down Expand Up @@ -223,23 +225,6 @@ def pop(self):
return None


class NodeChunk:
    """Holds the stream chunks produced by a single workflow node.

    The node is considered finished when ``status`` reaches 200; until
    then it is assumed to still be emitting chunks.
    """

    def __init__(self):
        # 0 while streaming; set to 200 by end() when the node completes.
        self.status = 0
        self.chunk_list = []

    def add_chunk(self, chunk):
        """Record one more streamed chunk."""
        self.chunk_list.append(chunk)

    def end(self, chunk=None):
        """Flag the node as done; an optional trailing chunk is appended first."""
        if chunk is not None:
            self.add_chunk(chunk)
        self.status = 200

    def is_end(self):
        """Whether end() has been called on this chunk buffer."""
        return self.status == 200


class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None,
Expand Down Expand Up @@ -273,8 +258,9 @@ def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandl
self.status = 200
self.base_to_response = base_to_response
self.chat_record = chat_record
self.await_future_map = {}
self.child_node = child_node
self.future_list = []
self.lock = threading.Lock()
if start_node_id is not None:
self.load_node(chat_record, start_node_id, start_node_data)
else:
Expand Down Expand Up @@ -319,6 +305,7 @@ def get_node_params(n):
self.node_context.append(node)

def run(self):
    """Entry point for workflow execution.

    Discards stale database connections first (worker threads may have
    left expired ones behind), then dispatches to streaming or blocking
    execution based on the 'stream' request parameter.
    """
    close_old_connections()
    if not self.params.get('stream'):
        return self.run_block()
    return self.run_stream(self.start_node, None)
Expand All @@ -328,8 +315,9 @@ def run_block(self):
非流式响应
@return: 结果
"""
result = self.run_chain_async(None, None)
result.result()
self.run_chain_async(None, None)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
Expand All @@ -350,12 +338,22 @@ def run_stream(self, current_node, node_result_future):
流式响应
@return:
"""
result = self.run_chain_async(current_node, node_result_future)
return tools.to_stream_response_simple(self.await_result(result))
self.run_chain_async(current_node, node_result_future)
return tools.to_stream_response_simple(self.await_result())

def await_result(self, result):
def is_run(self, timeout=0.1):
    """Return True while any scheduled node future is still running.

    Waits up to ``timeout`` seconds for the tracked futures to settle;
    if any remain unfinished the workflow is considered still running.
    The lock guards ``self.future_list`` against concurrent appends from
    ``run_chain_async``.

    :param timeout: max seconds to wait for the futures to complete.
    :return: True if at least one future is not done (or inspection failed).
    """
    # `with` replaces manual acquire/try/finally-release and guarantees
    # the lock is dropped even if concurrent.futures.wait raises.
    with self.lock:
        try:
            waited = concurrent.futures.wait(self.future_list, timeout)
            return len(waited.not_done) > 0
        except Exception:
            # Best-effort: keep reporting "running" on failure, but log it
            # instead of swallowing the traceback silently.
            traceback.print_exc()
            return True

def await_result(self):
try:
while self.is_run():
while True:
chunk = self.node_chunk_manage.pop()
if chunk is not None:
Expand Down Expand Up @@ -383,42 +381,39 @@ def await_result(self, result):
'', True, message_tokens, answer_tokens, {})

def run_chain_async(self, current_node, node_result_future):
    """Submit run_chain_manage to the shared executor and track its future.

    As rendered, the superseded ``return executor.submit(...)`` line sat
    above the new body, leaving the tracking lines unreachable; only the
    new behavior is kept. The future is appended to ``self.future_list``
    so ``is_run()`` can poll overall workflow completion.
    """
    future = executor.submit(self.run_chain_manage, current_node, node_result_future)
    self.future_list.append(future)

def run_chain_manage(self, current_node, node_result_future):
if current_node is None:
start_node = self.get_start_node()
current_node = get_node(start_node.type)(start_node, self.params, self)
self.node_chunk_manage.add_node_chunk(current_node.node_chunk)
# 添加节点
self.append_node(current_node)
result = self.run_chain(current_node, node_result_future)
if result is None:
return
node_list = self.get_next_node_list(current_node, result)
if len(node_list) == 1:
self.run_chain_manage(node_list[0], None)
elif len(node_list) > 1:

sorted_node_run_list = sorted(node_list, key=lambda n: n.node.y)
# 获取到可执行的子节点
result_list = [{'node': node, 'future': executor.submit(self.run_chain_manage, node, None)} for node in
node_list]
self.set_await_map(result_list)
[r.get('future').result() for r in result_list]

def set_await_map(self, node_run_list):
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
for index in range(len(sorted_node_run_list)):
self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
sorted_node_run_list[i].get('future')
for i in range(index)]
sorted_node_run_list]
try:
self.lock.acquire()
for r in result_list:
self.future_list.append(r.get('future'))
finally:
self.lock.release()

def run_chain(self, current_node, node_result_future=None):
if node_result_future is None:
node_result_future = self.run_node_future(current_node)
try:
is_stream = self.params.get('stream', True)
# 处理节点响应
await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
if await_future_list is not None:
[f.result() for f in await_future_list]
result = self.hand_event_node_result(current_node,
node_result_future) if is_stream else self.hand_node_result(
current_node, node_result_future)
Expand All @@ -434,16 +429,14 @@ def hand_node_result(self, current_node, node_result_future):
if result is not None:
# 阻塞获取结果
list(result)
# 添加节点
self.node_context.append(current_node)
return current_result
except Exception as e:
# 添加节点
self.node_context.append(current_node)
traceback.print_exc()
self.status = 500
current_node.get_write_error_context(e)
self.answer += str(e)
finally:
current_node.node_chunk.end()

def append_node(self, current_node):
for index in range(len(self.node_context)):
Expand All @@ -454,15 +447,14 @@ def append_node(self, current_node):
self.node_context.append(current_node)

def hand_event_node_result(self, current_node, node_result_future):
node_chunk = NodeChunk()
real_node_id = current_node.runtime_node_id
child_node = {}
view_type = current_node.view_type
try:
current_result = node_result_future.result()
result = current_result.write_context(current_node, self)
if result is not None:
if self.is_result(current_node, current_result):
self.node_chunk_manage.add_node_chunk(node_chunk)
for r in result:
content = r
child_node = {}
Expand All @@ -487,26 +479,24 @@ def hand_event_node_result(self, current_node, node_result_future):
'child_node': child_node,
'node_is_end': node_is_end,
'real_node_id': real_node_id})
node_chunk.add_chunk(chunk)
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
current_node.up_node_id_list,
'', False, 0, 0, {'node_is_end': True,
'runtime_node_id': current_node.runtime_node_id,
'node_type': current_node.type,
'view_type': view_type,
'child_node': child_node,
'real_node_id': real_node_id})
node_chunk.end(chunk)
current_node.node_chunk.add_chunk(chunk)
chunk = (self.base_to_response
.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
current_node.up_node_id_list,
'', False, 0, 0, {'node_is_end': True,
'runtime_node_id': current_node.runtime_node_id,
'node_type': current_node.type,
'view_type': view_type,
'child_node': child_node,
'real_node_id': real_node_id}))
current_node.node_chunk.add_chunk(chunk)
else:
list(result)
# 添加节点
self.append_node(current_node)
return current_result
except Exception as e:
# 添加节点
self.append_node(current_node)
traceback.print_exc()
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
Expand All @@ -519,12 +509,12 @@ def hand_event_node_result(self, current_node, node_result_future):
'view_type': current_node.view_type,
'child_node': {},
'real_node_id': real_node_id})
if not self.node_chunk_manage.contains(node_chunk):
self.node_chunk_manage.add_node_chunk(node_chunk)
node_chunk.end(chunk)
current_node.node_chunk.add_chunk(chunk)
current_node.get_write_error_context(e)
self.status = 500
return None
finally:
current_node.node_chunk.end()

def run_node_async(self, node):
future = executor.submit(self.run_node, node)
Expand Down Expand Up @@ -636,6 +626,8 @@ def get_next_node(self):

@staticmethod
def dependent_node(up_node_id, node):
if not node.node_chunk.is_end():
return False
if node.id == up_node_id:
if node.type == 'form-node':
if node.context.get('form_data', None) is not None:
Expand Down
3 changes: 3 additions & 0 deletions apps/setting/models_provider/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@date:2024/7/22 11:18
@desc:
"""
from django.db import connection
from django.db.models import QuerySet

from common.config.embedding_config import ModelManage
Expand All @@ -15,6 +16,8 @@

def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
# 手动关闭数据库连接
connection.close()
if model is None:
raise Exception("模型不存在")
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
Expand Down
3 changes: 3 additions & 0 deletions apps/setting/serializers/model_apply_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
@date:2024/8/20 20:39
@desc:
"""
from django.db import connection
from django.db.models import QuerySet
from langchain_core.documents import Document
from rest_framework import serializers
Expand All @@ -18,6 +19,8 @@

def get_embedding_model(model_id):
    """Look up the Model row for ``model_id`` and return its cached embedding model."""
    model = QuerySet(Model).filter(id=model_id).first()
    # Manually close the database connection after the query.
    connection.close()
    return ModelManage.get_model(model_id,
                                 lambda _id: get_model(model, use_local=True))
Expand Down
8 changes: 6 additions & 2 deletions apps/smartdoc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class Config(dict):
"DB_PORT": 5432,
"DB_USER": "root",
"DB_PASSWORD": "Password123@postgres",
"DB_ENGINE": "django.db.backends.postgresql_psycopg2",
"DB_ENGINE": "dj_db_conn_pool.backends.postgresql",
# 向量模型
"EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese",
"EMBEDDING_DEVICE": "cpu",
Expand Down Expand Up @@ -108,7 +108,11 @@ def get_db_setting(self) -> dict:
"PORT": self.get('DB_PORT'),
"USER": self.get('DB_USER'),
"PASSWORD": self.get('DB_PASSWORD'),
"ENGINE": self.get('DB_ENGINE')
"ENGINE": self.get('DB_ENGINE'),
"POOL_OPTIONS": {
"POOL_SIZE": 20,
"MAX_OVERFLOW": 5
}
}

def __init__(self, *args):
Expand Down
2 changes: 1 addition & 1 deletion installer/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ DB_HOST: 127.0.0.1
DB_PORT: 5432
DB_USER: root
DB_PASSWORD: Password123@postgres
DB_ENGINE: django.db.backends.postgresql_psycopg2
DB_ENGINE: dj_db_conn_pool.backends.postgresql
EMBEDDING_MODEL_PATH: /opt/maxkb/model/embedding
EMBEDDING_MODEL_NAME: /opt/maxkb/model/embedding/shibing624_text2vec-base-chinese

Expand Down
Loading

0 comments on commit efa73c8

Please sign in to comment.