diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 48c373307c1..3d30758a27c 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -7,6 +7,7 @@ @desc: """ import time +import uuid from abc import abstractmethod from typing import Type, Dict, List @@ -31,7 +32,7 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable: answer = step_variable['answer'] yield answer - workflow.answer += answer + workflow.append_answer(answer) if global_variable is not None: for key in global_variable: workflow.context[key] = global_variable[key] @@ -54,15 +55,27 @@ def handler(self, chat_id, 'message_tokens' in row and row.get('message_tokens') is not None]) answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row and row.get('answer_tokens') is not None]) - chat_record = ChatRecord(id=chat_record_id, - chat_id=chat_id, - problem_text=question, - answer_text=answer, - details=details, - message_tokens=message_tokens, - answer_tokens=answer_tokens, - run_time=time.time() - workflow.context['start_time'], - index=0) + answer_text_list = workflow.get_answer_text_list() + answer_text = '\n\n'.join(answer_text_list) + if workflow.chat_record is not None: + chat_record = workflow.chat_record + chat_record.answer_text = answer_text + chat_record.details = details + chat_record.message_tokens = message_tokens + chat_record.answer_tokens = answer_tokens + chat_record.answer_text_list = answer_text_list + chat_record.run_time = time.time() - workflow.context['start_time'] + else: + chat_record = ChatRecord(id=chat_record_id, + chat_id=chat_id, + problem_text=question, + answer_text=answer_text, + details=details, + message_tokens=message_tokens, + answer_tokens=answer_tokens, + answer_text_list=answer_text_list, + run_time=time.time() - workflow.context['start_time'], + index=0) self.chat_info.append_chat_record(chat_record, self.client_id) # 重新设置缓存 chat_cache.set(chat_id, @@ -118,7 +131,15 @@ class FlowParamsSerializer(serializers.Serializer): class INode: - def __init__(self, node, workflow_params, workflow_manage): + + @abstractmethod + def save_context(self, details, workflow_manage): + pass + + def get_answer_text(self): + return self.answer_text + + def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None): # 当前步骤上下文,用于存储当前步骤信息 self.status = 200 self.err_message = '' @@ -129,7 +150,12 @@ def __init__(self, node, workflow_params, workflow_manage): self.node_params_serializer = None self.flow_params_serializer = None self.context = {} + self.answer_text = None self.id = node.id + if runtime_node_id is None: + self.runtime_node_id = str(uuid.uuid1()) + else: + self.runtime_node_id = runtime_node_id def valid_args(self, node_params, flow_params): flow_params_serializer_class = self.get_flow_params_serializer_class() diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index c1ebd222437..cd8b08a974a 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -9,19 +9,23 @@ from .ai_chat_step_node import * from .application_node import BaseApplicationNode from .condition_node import * -from .question_node import * -from .search_dataset_node import * -from .start_node import * from .direct_reply_node import * +from .form_node import * from .function_lib_node import * from .function_node import * +from .question_node import * from .reranker_node import * + from .document_extract_node import * from .image_understand_step_node import * +from .search_dataset_node import * +from .start_node import * + node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, - BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, - BaseImageUnderstandNode] + BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, + BaseDocumentExtractNode, + BaseImageUnderstandNode, BaseFormNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index daa7b452fbc..d4835f6190e 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -32,7 +32,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): - workflow.answer += answer + node.answer_text = answer def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -73,6 +73,11 @@ def get_default_model_params_setting(model_id): class BaseChatNode(IChatNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.answer_text = details.get('answer') + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index 7f4644a5815..f1abca40d8f 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -21,7 +21,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): - workflow.answer += answer + node.answer_text = answer def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -64,6 +64,12 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseApplicationNode(IApplicationNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.context['type'] = details.get('type') + self.answer_text = details.get('answer') + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, **kwargs) -> NodeResult: from application.serializers.chat_message_serializers import ChatMessageSerializer diff --git a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py index 3164bb9feb7..2c5260f8952 100644 --- a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py +++ b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py @@ -14,6 +14,10 @@ class BaseConditionNode(IConditionNode): + def save_context(self, details, workflow_manage): + self.context['branch_id'] = details.get('branch_id') + self.context['branch_name'] = details.get('branch_name') + def execute(self, **kwargs) -> NodeResult: branch_list = self.node_params_serializer.data['branch'] branch = self._execute(branch_list) diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index de79279d932..6a51edd6bae 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -13,6 +13,9 @@ class BaseReplyNode(IReplyNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.answer_text = details.get('answer') def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: if reply_type == 'referencing': result = self.get_reference_content(fields) diff --git a/apps/application/flow/step_node/form_node/__init__.py b/apps/application/flow/step_node/form_node/__init__.py new file mode 100644 index 00000000000..ce04b64aea8 --- /dev/null +++ b/apps/application/flow/step_node/form_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/11/4 14:48 + @desc: +""" +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py new file mode 100644 index 00000000000..18ff91fda8b --- /dev/null +++ b/apps/application/flow/step_node/form_node/i_form_node.py @@ -0,0 +1,32 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_form_node.py + @date:2024/11/4 14:48 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class FormNodeParamsSerializer(serializers.Serializer): + form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置")) + form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容')) + + +class IFormNode(INode): + type = 'form-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return FormNodeParamsSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/form_node/impl/__init__.py b/apps/application/flow/step_node/form_node/impl/__init__.py new file mode 100644 index 00000000000..4cea85e1d9e --- /dev/null +++ b/apps/application/flow/step_node/form_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/11/4 14:49 + @desc: +""" +from .base_form_node import BaseFormNode diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py b/apps/application/flow/step_node/form_node/impl/base_form_node.py new file mode 100644 index 00000000000..1d51705772a --- /dev/null +++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py @@ -0,0 +1,84 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_form_node.py + @date:2024/11/4 14:52 + @desc: +""" +import json +import time +from typing import Dict + +from langchain_core.prompts import PromptTemplate + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.form_node.i_form_node import IFormNode + + +def write_context(step_variable: Dict, global_variable: Dict, node, workflow): + if step_variable is not None: + for key in step_variable: + node.context[key] = step_variable[key] + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable: + result = step_variable['result'] + yield result + node.answer_text = result + node.context['run_time'] = time.time() - node.context['start_time'] + + +class BaseFormNode(IFormNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.context['form_content_format'] = details.get('form_content_format') + self.context['form_field_list'] = details.get('form_field_list') + self.context['run_time'] = details.get('run_time') + self.context['start_time'] = details.get('start_time') + self.answer_text = details.get('result') + + def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: + form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, + "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), + "is_submit": self.context.get("is_submit", False)} + form = f'{json.dumps(form_setting)}' + prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') + value = prompt_template.format(form=form) + return NodeResult( + {'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {}, + _write_context=write_context) + + def get_answer_text(self): + form_content_format = self.context.get('form_content_format') + form_field_list = self.context.get('form_field_list') + form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, + "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), + 'form_data': self.context.get('form_data', {}), + "is_submit": self.context.get("is_submit", False)} + form = f'{json.dumps(form_setting)}' + prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') + value = prompt_template.format(form=form) + return value + + def get_details(self, index: int, **kwargs): + form_content_format = self.context.get('form_content_format') + form_field_list = self.context.get('form_field_list') + form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id, + "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"), + 'form_data': self.context.get('form_data', {}), + "is_submit": self.context.get("is_submit", False)} + form = f'{json.dumps(form_setting)}' + prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2') + value = prompt_template.format(form=form) + return { + 'name': self.node.properties.get('stepName'), + "index": index, + "result": value, + "form_content_format": self.context.get('form_content_format'), + "form_field_list": self.context.get('form_field_list'), + 'form_data': self.context.get('form_data'), + 'start_time': self.context.get('start_time'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py index 64e1c556eb9..273b84d9786 100644 --- a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py +++ b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py @@ -91,6 +91,9 @@ def convert_value(name: str, value, _type, is_required, source, node): class BaseFunctionLibNodeNode(IFunctionLibNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.answer_text = details.get('result') def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult: function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first() if not function_lib.is_active: diff --git a/apps/application/flow/step_node/function_node/impl/base_function_node.py b/apps/application/flow/step_node/function_node/impl/base_function_node.py index f2aead83fa8..3336b308ac5 100644 --- a/apps/application/flow/step_node/function_node/impl/base_function_node.py +++ b/apps/application/flow/step_node/function_node/impl/base_function_node.py @@ -78,6 +78,10 @@ def convert_value(name: str, value, _type, is_required, source, node): class BaseFunctionNodeNode(IFunctionNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.answer_text = details.get('result') + def execute(self, input_field_list, code, **kwargs) -> NodeResult: params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'), field.get('is_required'), field.get('source'), self) diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 9731fbd80cc..e6a88a9f157 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -58,6 +58,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseImageUnderstandNode(IImageUnderstandNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.answer_text = details.get('answer') + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, image, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 8e43a9b8e4d..f69cb912cb7 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -32,7 +32,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): - workflow.answer += answer + node.answer_text = answer def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -73,6 +73,14 @@ def get_default_model_params_setting(model_id): class BaseQuestionNode(IQuestionNode): + def save_context(self, details, workflow_manage): + self.context['run_time'] = details.get('run_time') + self.context['question'] = details.get('question') + self.context['answer'] = details.get('answer') + self.context['message_tokens'] = details.get('message_tokens') + self.context['answer_tokens'] = details.get('answer_tokens') + self.answer_text = details.get('answer') + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py index d1eef33d40b..59dba2ff953 100644 --- a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py +++ b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py @@ -44,6 +44,13 @@ def filter_result(document_list: List[Document], max_paragraph_char_number, top_ class BaseRerankerNode(IRerankerNode): + def save_context(self, details, workflow_manage): + self.context['document_list'] = details.get('document_list', []) + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.context['result_list'] = details.get('result_list') + self.context['result'] = details.get('result') + def execute(self, question, reranker_setting, reranker_list, reranker_model_id, **kwargs) -> NodeResult: documents = merge_reranker_list(reranker_list) diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 693495a6a78..84af258bedf 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -45,6 +45,21 @@ def reset_title(title): class BaseSearchDatasetNode(ISearchDatasetStepNode): + def save_context(self, details, workflow_manage): + result = details.get('paragraph_list', []) + dataset_setting = self.node_params_serializer.data.get('dataset_setting') + directly_return = '\n'.join( + [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in result if + paragraph.get('is_hit_handling_method')]) + self.context['paragraph_list'] = result + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.context['is_hit_handling_method_list'] = [row for row in result if row.get('is_hit_handling_method')] + self.context['data'] = '\n'.join( + [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in + result])[0:dataset_setting.get('max_paragraph_char_number', 5000)] + self.context['directly_return'] = directly_return + def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/start_node/i_start_node.py b/apps/application/flow/step_node/start_node/i_start_node.py index bb23ad3f53e..41d73f21811 100644 --- a/apps/application/flow/step_node/start_node/i_start_node.py +++ b/apps/application/flow/step_node/start_node/i_start_node.py @@ -6,9 +6,6 @@ @date:2024/6/3 16:54 @desc: """ -from typing import Type - -from rest_framework import serializers from application.flow.i_step_node import INode, NodeResult diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index 186623123d3..9ac7e2aefed 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -31,6 +31,17 @@ def get_global_variable(node): class BaseStartStepNode(IStarNode): + def save_context(self, details, workflow_manage): + base_node = self.workflow_manage.get_base_node() + default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) + workflow_variable = {**default_global_variable, **get_global_variable(self)} + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.status = details.get('status') + self.err_message = details.get('err_message') + for key, value in workflow_variable.items(): + workflow_manage.context[key] = value + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: pass diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 5de8e0ddc17..39c42aa2449 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -9,7 +9,6 @@ import json import threading import traceback -import uuid from concurrent.futures import ThreadPoolExecutor from functools import reduce from typing import List, Dict @@ -212,12 +211,12 @@ def pop(self): except IndexError as e: if self.current_node_chunk.is_end(): self.current_node_chunk = None - if len(self.work_flow.answer) > 0: + if self.work_flow.answer_is_not_empty(): chunk = self.work_flow.base_to_response.to_stream_chunk_response( self.work_flow.params['chat_id'], self.work_flow.params['chat_record_id'], '\n\n', False, 0, 0) - self.work_flow.answer += '\n\n' + self.work_flow.append_answer('\n\n') return chunk return self.pop() return None @@ -240,29 +239,65 @@ def is_end(self): class WorkflowManage: def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, - base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None): + base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, + start_node_id=None, + start_node_data=None, chat_record=None): if form_data is None: form_data = {} if image_list is None: image_list = [] + self.start_node = None + self.start_node_result_future = None self.form_data = form_data self.image_list = image_list self.params = params self.flow = flow self.lock = threading.Lock() self.context = {} - self.node_context = [] self.node_chunk_manage = NodeChunkManage(self) self.work_flow_post_handler = work_flow_post_handler self.current_node = None self.current_result = None self.answer = "" + self.answer_list = [''] self.status = 0 self.base_to_response = base_to_response + self.chat_record = chat_record + if start_node_id is not None: + self.load_node(chat_record, start_node_id, start_node_data) + else: + self.node_context = [] + + def append_answer(self, content): + self.answer += content + self.answer_list[-1] += content + + def answer_is_not_empty(self): + return len(self.answer_list[-1]) > 0 + + def load_node(self, chat_record, start_node_id, start_node_data): + self.node_context = [] + self.answer = chat_record.answer_text + self.answer_list = chat_record.answer_text_list + self.answer_list.append('') + for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')): + node_id = node_details.get('node_id') + if node_details.get('runtime_node_id') == start_node_id: + self.start_node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id')) + self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params) + self.start_node.save_context(node_details, self) + node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {}) + self.start_node_result_future = NodeResultFuture(node_result, None) + return + node_id = node_details.get('node_id') + node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id')) + node.valid_args(node.node_params, node.workflow_params) + node.save_context(node_details, self) + self.node_context.append(node) def run(self): if self.params.get('stream'): - return self.run_stream() + return self.run_stream(self.start_node, self.start_node_result_future) return self.run_block() def run_block(self): @@ -270,7 +305,7 @@ def run_block(self): 非流式响应 @return: 结果 """ - result = self.run_chain_async(None) + result = self.run_chain_async(None, None) result.result() details = self.get_runtime_details() message_tokens = sum([row.get('message_tokens') for row in details.values() if @@ -285,12 +320,12 @@ def run_block(self): , message_tokens, answer_tokens, _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR) - def run_stream(self): + def run_stream(self, current_node, node_result_future): """ 流式响应 @return: """ - result = self.run_chain_async(None) + result = self.run_chain_async(current_node, node_result_future) return tools.to_stream_response_simple(self.await_result(result)) def await_result(self, result): @@ -307,21 +342,23 @@ def await_result(self, result): if chunk is None: break yield chunk + yield self.get_chunk_content('', True) finally: self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], self.answer, self) yield self.get_chunk_content('', True) - def run_chain_async(self, current_node): - future = executor.submit(self.run_chain, current_node) + def run_chain_async(self, current_node, node_result_future): + future = executor.submit(self.run_chain, current_node, node_result_future) return future - def run_chain(self, current_node): + def run_chain(self, current_node, node_result_future=None): if current_node is None: start_node = self.get_start_node() current_node = get_node(start_node.type)(start_node, self.params, self) - node_result_future = self.run_node_future(current_node) + if node_result_future is None: + node_result_future = self.run_node_future(current_node) try: is_stream = self.params.get('stream', True) # 处理节点响应 @@ -335,7 +372,7 @@ def run_chain(self, current_node): # 获取到可执行的子节点 result_list = [] for node in node_list: - result = self.run_chain_async(node) + result = self.run_chain_async(node, None) result_list.append(result) [r.result() for r in result_list] if self.status == 0: @@ -445,10 +482,41 @@ def get_runtime_details(self): details_result = {} for index in range(len(self.node_context)): node = self.node_context[index] + if self.chat_record is not None and self.chat_record.details is not None: + details = self.chat_record.details.get(node.runtime_node_id) + if details is not None and self.start_node.runtime_node_id != node.runtime_node_id: + details_result[node.runtime_node_id] = details + continue details = node.get_details(index) - details_result[str(uuid.uuid1())] = details + details['node_id'] = node.id + details['runtime_node_id'] = node.runtime_node_id + details_result[node.runtime_node_id] = details return details_result + def get_answer_text_list(self): + answer_text_list = [] + for index in range(len(self.node_context)): + node = self.node_context[index] + answer_text = node.get_answer_text() + if answer_text is not None: + if self.chat_record is not None and self.chat_record.details is not None: + details = self.chat_record.details.get(node.runtime_node_id) + if details is not None and self.start_node.runtime_node_id != node.runtime_node_id: + continue + answer_text_list.append( + {'content': answer_text, 'type': 'form' if node.type == 'form-node' else 'md'}) + result = [] + for index in range(len(answer_text_list)): + answer = answer_text_list[index] + if index == 0: + result.append(answer.get('content')) + continue + if answer.get('type') != answer_text_list[index - 1]: + result.append(answer.get('content')) + else: + result[-1] += answer.get('content') + return result + def get_next_node(self): """ 获取下一个可运行的所有节点 @@ -485,6 +553,8 @@ def get_next_node_list(self, current_node, current_node_result): @param current_node_result: 当前可执行节点结果 @return: 可执行节点列表 """ + if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable: + return [] node_list = [] if current_node_result is not None and current_node_result.is_assertion_result(): for edge in self.flow.edges: @@ -537,7 +607,6 @@ def generate_prompt(self, prompt: str): prompt = prompt.replace(globeLabel, globeValue) context[node.id] = node.context prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2') - value = prompt_template.format(context=context) return value @@ -557,11 +626,11 @@ def get_base_node(self): base_node_list = [node for node in self.flow.nodes if node.type == 'base-node'] return base_node_list[0] - def get_node_cls_by_id(self, node_id): + def get_node_cls_by_id(self, node_id, runtime_node_id=None): for node in self.flow.nodes: if node.id == node_id: node_instance = get_node(node.type)(node, - self.params, self) + self.params, self, runtime_node_id) return node_instance return None diff --git a/apps/application/migrations/0019_application_file_upload_enable_and_more.py b/apps/application/migrations/0019_application_file_upload_enable_and_more.py index 0f934d6bf5f..f59a4990c2f 100644 --- a/apps/application/migrations/0019_application_file_upload_enable_and_more.py +++ b/apps/application/migrations/0019_application_file_upload_enable_and_more.py @@ -1,10 +1,15 @@ -# Generated by Django 4.2.15 on 2024-11-07 11:22 +# Generated by Django 4.2.15 on 2024-11-13 10:13 +import django.contrib.postgres.fields from django.db import migrations, models +sql = """ +UPDATE "public".application_chat_record +SET "answer_text_list" = ARRAY[answer_text]; +""" -class Migration(migrations.Migration): +class Migration(migrations.Migration): dependencies = [ ('application', '0018_workflowversion_name'), ] @@ -20,4 +25,11 @@ class Migration(migrations.Migration): name='file_upload_setting', field=models.JSONField(default={}, verbose_name='文件上传相关设置'), ), + migrations.AddField( + model_name='chatrecord', + name='answer_text_list', + field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=40960), default=list, + size=None, verbose_name='改进标注列表'), + ), + migrations.RunSQL(sql) ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 8df9ac19032..5df928c4dab 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -147,6 +147,9 @@ class ChatRecord(AppModelMixin): default=VoteChoices.UN_VOTE) problem_text = models.CharField(max_length=10240, verbose_name="问题") answer_text = models.CharField(max_length=40960, verbose_name="答案") + answer_text_list = ArrayField(verbose_name="改进标注列表", + base_field=models.CharField(max_length=40960) + , default=list) message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) const = models.IntegerField(verbose_name="总费用", default=0) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index e0576b247b9..b14e9f33e3f 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -214,10 +214,15 @@ def chat(self, instance: Dict, with_valid=True): class ChatMessageSerializer(serializers.Serializer): - chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id")) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题")) stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答")) + chat_record_id = serializers.UUIDField(required=False, allow_null=True, + error_messages=ErrMessage.uuid("对话记录id")) + runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("节点id")) + node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) @@ -293,6 +298,19 @@ def chat_simple(self, chat_info: ChatInfo, base_to_response): pipeline_message.run(params) return pipeline_message.context['chat_result'] + @staticmethod + def get_chat_record(chat_info, chat_record_id): + if chat_info is not None: + chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if + str(chat_record.id) == str(chat_record_id)] + if chat_record_list is not None and len(chat_record_list): + return chat_record_list[-1] + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first() + if chat_record is None: + raise ChatException(500, "对话纪要不存在") + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first() + return chat_record + def chat_work_flow(self, chat_info: ChatInfo, base_to_response): message = self.data.get('message') re_chat = self.data.get('re_chat') @@ -302,15 +320,21 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): form_data = self.data.get('form_data') image_list = self.data.get('image_list') user_id = chat_info.application.user_id + chat_record_id = self.data.get('chat_record_id') + chat_record = None + if chat_record_id is not None: + chat_record = self.get_chat_record(chat_info, chat_record_id) work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), {'history_chat_record': chat_info.chat_record_list, 'question': message, - 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), + 'chat_id': chat_info.chat_id, 'chat_record_id': str( + uuid.uuid1()) if chat_record is None else chat_record.id, 'stream': stream, 're_chat': re_chat, 'client_id': client_id, 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), - base_to_response, form_data, image_list) + base_to_response, form_data, image_list, self.data.get('runtime_node_id'), + self.data.get('node_data'), chat_record) r = work_flow_manage.run() return r diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 40a16af2fae..ad25ac7d915 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -397,7 +397,7 @@ class ChatRecordSerializerModel(serializers.ModelSerializer): class Meta: model = ChatRecord fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text', - 'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index', + 'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index','answer_text_list', 'create_time', 'update_time'] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 4b60e4bcb28..586787b203a 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -129,9 +129,14 @@ def post(self, request: Request, chat_id: str): 'client_id': request.auth.client_id, 'form_data': (request.data.get( 'form_data') if 'form_data' in request.data else {}), + 'image_list': request.data.get( 'image_list') if 'image_list' in request.data else [], - 'client_type': request.auth.client_type}).chat() + 'client_type': request.auth.client_type, + 'runtime_node_id': request.data.get('runtime_node_id', None), + 'node_data': request.data.get('node_data', {}), + 'chat_record_id': request.data.get('chat_record_id')} + ).chat() @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取对话列表", diff --git a/apps/dataset/migrations/0010_file_meta.py b/apps/dataset/migrations/0010_file_meta.py index f227554ba26..6e28e3eecc3 100644 --- a/apps/dataset/migrations/0010_file_meta.py +++ b/apps/dataset/migrations/0010_file_meta.py @@ -13,6 +13,6 @@ class Migration(migrations.Migration): migrations.AddField( model_name='file', name='meta', - field=models.JSONField(default={}, verbose_name='文件关联数据'), + field=models.JSONField(default=dict, verbose_name='文件关联数据'), ), ] diff --git a/ui/env.d.ts b/ui/env.d.ts index 52f54527078..be1f8389d30 100644 --- a/ui/env.d.ts +++ b/ui/env.d.ts @@ -8,7 +8,10 @@ declare module 'markdown-it-sub' declare module 'markdown-it-sup' declare module 'markdown-it-toc-done-right' declare module 'katex' +interface Window { + sendMessage: ?((message: string, other_params_data: any) => void) +} interface ImportMeta { readonly env: ImportMetaEnv } -declare type Recordable = Record; +declare type Recordable = Record diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 00ad179fe9e..5d4971f7fa4 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -27,6 +27,7 @@ interface chatType { problem_text: string answer_text: string buffer: Array + answer_text_list: Array /** * 是否写入结束 */ @@ -36,6 +37,7 @@ interface chatType { */ is_stop?: boolean record_id: string + chat_id: string vote_status: string status?: number } @@ -56,18 +58,25 @@ export class ChatRecordManage { this.is_close = false this.write_ed = false } + append_answer(chunk_answer: String) { + this.chat.answer_text_list[this.chat.answer_text_list.length - 1] = + this.chat.answer_text_list[this.chat.answer_text_list.length - 1] + chunk_answer + this.chat.answer_text = this.chat.answer_text + chunk_answer + } write() { this.chat.is_stop = false this.is_stop = false + this.is_close = false + this.write_ed = false + this.chat.write_ed = false if (this.loading) { this.loading.value = true } this.id = setInterval(() => { if (this.chat.buffer.length > 20) { - this.chat.answer_text = - this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('') + this.append_answer(this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('')) } else if (this.is_close) { - this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('') + this.append_answer(this.chat.buffer.splice(0).join('')) this.chat.write_ed = true this.write_ed = true if (this.loading) { @@ -79,7 +88,7 @@ export class ChatRecordManage { } else { const s = this.chat.buffer.shift() if (s !== undefined) { - this.chat.answer_text = this.chat.answer_text + s + this.append_answer(s) } } }, this.ms) @@ -95,6 +104,10 @@ export class ChatRecordManage { close() { this.is_close = true } + open() { + this.is_close = false + this.is_stop = false + } append(answer_text_block: string) { for (let index = 0; index < answer_text_block.length; index++) { this.chat.buffer.push(answer_text_block[index]) diff --git a/ui/src/assets/icon_form.svg b/ui/src/assets/icon_form.svg new file mode 100644 index 00000000000..22a10210da3 --- /dev/null +++ b/ui/src/assets/icon_form.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue new file mode 100644 index 00000000000..e8e761b55ee --- /dev/null +++ b/ui/src/components/ai-chat/component/answer-content/index.vue @@ -0,0 +1,94 @@ + + + diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue new file mode 100644 index 00000000000..abb0bfbe095 --- /dev/null +++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue @@ -0,0 +1,392 @@ + + + diff --git a/ui/src/components/ai-chat/OperationButton.vue b/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue similarity index 95% rename from ui/src/components/ai-chat/OperationButton.vue rename to ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue index 6a125429659..198ce3ed1e2 100644 --- a/ui/src/components/ai-chat/OperationButton.vue +++ b/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue @@ -19,7 +19,7 @@ - + @@ -90,26 +90,21 @@ const { params: { id } } = route as any -const props = defineProps({ - data: { - type: Object, - default: () => {} - }, - applicationId: { - type: String, - default: '' - }, - chatId: { - type: String, - default: '' - }, - chat_loading: { - type: Boolean - }, - log: Boolean, - tts: Boolean, - tts_type: String -}) +const props = withDefaults( + defineProps<{ + data: any + type: 'log' | 'ai-chat' | 'debug-ai-chat' + chatId: string + chat_loading: boolean + applicationId: string + tts: boolean + tts_type: string + }>(), + { + data: () => ({}), + type: 'ai-chat' + } +) const emit = defineEmits(['update:data', 'regeneration']) diff --git a/ui/src/components/ai-chat/LogOperationButton.vue b/ui/src/components/ai-chat/component/operation-button/LogOperationButton.vue similarity index 68% rename from ui/src/components/ai-chat/LogOperationButton.vue rename to ui/src/components/ai-chat/component/operation-button/LogOperationButton.vue index 10dd47165df..96ccaf756ae 100644 --- a/ui/src/components/ai-chat/LogOperationButton.vue +++ b/ui/src/components/ai-chat/component/operation-button/LogOperationButton.vue @@ -1,59 +1,61 @@ + diff --git a/ui/src/components/ai-chat/component/prologue-content/index.vue b/ui/src/components/ai-chat/component/prologue-content/index.vue new file mode 100644 index 00000000000..e8df17898a5 --- /dev/null +++ b/ui/src/components/ai-chat/component/prologue-content/index.vue @@ -0,0 +1,35 @@ + + + diff --git a/ui/src/components/ai-chat/component/question-content/index.vue b/ui/src/components/ai-chat/component/question-content/index.vue new file mode 100644 index 00000000000..a9d202e4f27 --- /dev/null +++ b/ui/src/components/ai-chat/component/question-content/index.vue @@ -0,0 +1,81 @@ + + + diff --git a/ui/src/components/ai-chat/component/user-form/index.vue b/ui/src/components/ai-chat/component/user-form/index.vue new file mode 100644 index 00000000000..f64d08b7376 --- /dev/null +++ b/ui/src/components/ai-chat/component/user-form/index.vue @@ -0,0 +1,304 @@ + + + diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 38bda779f7f..0ea8dd18986 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -1,368 +1,100 @@ + diff --git a/ui/src/components/markdown/MdRenderer.vue b/ui/src/components/markdown/MdRenderer.vue index 517a3bc5683..f8faf3b2df1 100644 --- a/ui/src/components/markdown/MdRenderer.vue +++ b/ui/src/components/markdown/MdRenderer.vue @@ -2,9 +2,9 @@