diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 48c373307c1..3d30758a27c 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -7,6 +7,7 @@ @desc: """ import time +import uuid from abc import abstractmethod from typing import Type, Dict, List @@ -31,7 +32,7 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow): if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable: answer = step_variable['answer'] yield answer - workflow.answer += answer + workflow.append_answer(answer) if global_variable is not None: for key in global_variable: workflow.context[key] = global_variable[key] @@ -54,15 +55,27 @@ def handler(self, chat_id, 'message_tokens' in row and row.get('message_tokens') is not None]) answer_tokens = sum([row.get('answer_tokens') for row in details.values() if 'answer_tokens' in row and row.get('answer_tokens') is not None]) - chat_record = ChatRecord(id=chat_record_id, - chat_id=chat_id, - problem_text=question, - answer_text=answer, - details=details, - message_tokens=message_tokens, - answer_tokens=answer_tokens, - run_time=time.time() - workflow.context['start_time'], - index=0) + answer_text_list = workflow.get_answer_text_list() + answer_text = '\n\n'.join(answer_text_list) + if workflow.chat_record is not None: + chat_record = workflow.chat_record + chat_record.answer_text = answer_text + chat_record.details = details + chat_record.message_tokens = message_tokens + chat_record.answer_tokens = answer_tokens + chat_record.answer_text_list = answer_text_list + chat_record.run_time = time.time() - workflow.context['start_time'] + else: + chat_record = ChatRecord(id=chat_record_id, + chat_id=chat_id, + problem_text=question, + answer_text=answer_text, + details=details, + message_tokens=message_tokens, + answer_tokens=answer_tokens, + answer_text_list=answer_text_list, + run_time=time.time() - workflow.context['start_time'], + index=0) self.chat_info.append_chat_record(chat_record, self.client_id) # 重新设置缓存 chat_cache.set(chat_id, @@ -118,7 +131,15 @@ class FlowParamsSerializer(serializers.Serializer): class INode: - def __init__(self, node, workflow_params, workflow_manage): + + @abstractmethod + def save_context(self, details, workflow_manage): + pass + + def get_answer_text(self): + return self.answer_text + + def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None): # 当前步骤上下文,用于存储当前步骤信息 self.status = 200 self.err_message = '' @@ -129,7 +150,12 @@ def __init__(self, node, workflow_params, workflow_manage): self.node_params_serializer = None self.flow_params_serializer = None self.context = {} + self.answer_text = None self.id = node.id + if runtime_node_id is None: + self.runtime_node_id = str(uuid.uuid1()) + else: + self.runtime_node_id = runtime_node_id def valid_args(self, node_params, flow_params): flow_params_serializer_class = self.get_flow_params_serializer_class() diff --git a/apps/application/flow/step_node/__init__.py b/apps/application/flow/step_node/__init__.py index c1ebd222437..cd8b08a974a 100644 --- a/apps/application/flow/step_node/__init__.py +++ b/apps/application/flow/step_node/__init__.py @@ -9,19 +9,23 @@ from .ai_chat_step_node import * from .application_node import BaseApplicationNode from .condition_node import * -from .question_node import * -from .search_dataset_node import * -from .start_node import * from .direct_reply_node import * +from .form_node import * from 
.function_lib_node import * from .function_node import * +from .question_node import * from .reranker_node import * + from .document_extract_node import * from .image_understand_step_node import * +from .search_dataset_node import * +from .start_node import * + node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode, - BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, BaseDocumentExtractNode, - BaseImageUnderstandNode] + BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode, BaseApplicationNode, + BaseDocumentExtractNode, + BaseImageUnderstandNode, BaseFormNode] def get_node(node_type): diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py index daa7b452fbc..d4835f6190e 100644 --- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py +++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py @@ -32,7 +32,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): - workflow.answer += answer + node.answer_text = answer def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -73,6 +73,11 @@ def get_default_model_params_setting(model_id): class BaseChatNode(IChatNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.answer_text = details.get('answer') + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/application_node/impl/base_application_node.py b/apps/application/flow/step_node/application_node/impl/base_application_node.py index 7f4644a5815..f1abca40d8f 100644 --- a/apps/application/flow/step_node/application_node/impl/base_application_node.py +++ b/apps/application/flow/step_node/application_node/impl/base_application_node.py @@ -21,7 +21,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): - workflow.answer += answer + node.answer_text = answer def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -64,6 +64,12 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseApplicationNode(IApplicationNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.context['type'] = details.get('type') + self.answer_text = details.get('answer') + def execute(self, application_id, message, chat_id, chat_record_id, stream, re_chat, client_id, client_type, **kwargs) -> NodeResult: from application.serializers.chat_message_serializers import ChatMessageSerializer diff --git a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py 
b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py index 3164bb9feb7..2c5260f8952 100644 --- a/apps/application/flow/step_node/condition_node/impl/base_condition_node.py +++ b/apps/application/flow/step_node/condition_node/impl/base_condition_node.py @@ -14,6 +14,10 @@ class BaseConditionNode(IConditionNode): + def save_context(self, details, workflow_manage): + self.context['branch_id'] = details.get('branch_id') + self.context['branch_name'] = details.get('branch_name') + def execute(self, **kwargs) -> NodeResult: branch_list = self.node_params_serializer.data['branch'] branch = self._execute(branch_list) diff --git a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py index de79279d932..6a51edd6bae 100644 --- a/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py +++ b/apps/application/flow/step_node/direct_reply_node/impl/base_reply_node.py @@ -13,6 +13,9 @@ class BaseReplyNode(IReplyNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.answer_text = details.get('answer') def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult: if reply_type == 'referencing': result = self.get_reference_content(fields) diff --git a/apps/application/flow/step_node/form_node/__init__.py b/apps/application/flow/step_node/form_node/__init__.py new file mode 100644 index 00000000000..ce04b64aea8 --- /dev/null +++ b/apps/application/flow/step_node/form_node/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/11/4 14:48 + @desc: +""" +from .impl import * \ No newline at end of file diff --git a/apps/application/flow/step_node/form_node/i_form_node.py b/apps/application/flow/step_node/form_node/i_form_node.py new file mode 100644 index 00000000000..18ff91fda8b --- /dev/null +++ b/apps/application/flow/step_node/form_node/i_form_node.py @@ -0,0 +1,32 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_form_node.py + @date:2024/11/4 14:48 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class FormNodeParamsSerializer(serializers.Serializer): + form_field_list = serializers.ListField(required=True, error_messages=ErrMessage.list("表单配置")) + form_content_format = serializers.CharField(required=True, error_messages=ErrMessage.char('表单输出内容')) + + +class IFormNode(INode): + type = 'form-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return FormNodeParamsSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: + pass diff --git a/apps/application/flow/step_node/form_node/impl/__init__.py b/apps/application/flow/step_node/form_node/impl/__init__.py new file mode 100644 index 00000000000..4cea85e1d9e --- /dev/null +++ b/apps/application/flow/step_node/form_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/11/4 14:49 + @desc: +""" +from .base_form_node import BaseFormNode diff --git a/apps/application/flow/step_node/form_node/impl/base_form_node.py 
b/apps/application/flow/step_node/form_node/impl/base_form_node.py
new file mode 100644
index 00000000000..1d51705772a
--- /dev/null
+++ b/apps/application/flow/step_node/form_node/impl/base_form_node.py
@@ -0,0 +1,84 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: base_form_node.py
+    @date:2024/11/4 14:52
+    @desc:
+"""
+import json
+import time
+from typing import Dict
+
+from langchain_core.prompts import PromptTemplate
+
+from application.flow.i_step_node import NodeResult
+from application.flow.step_node.form_node.i_form_node import IFormNode
+
+
+def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
+    if step_variable is not None:
+        for key in step_variable:
+            node.context[key] = step_variable[key]
+        if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
+            result = step_variable['result']
+            yield result
+            node.answer_text = result
+    node.context['run_time'] = time.time() - node.context['start_time']
+
+
+class BaseFormNode(IFormNode):
+    def save_context(self, details, workflow_manage):
+        self.context['result'] = details.get('result')
+        self.context['form_content_format'] = details.get('form_content_format')
+        self.context['form_field_list'] = details.get('form_field_list')
+        self.context['run_time'] = details.get('run_time')
+        self.context['start_time'] = details.get('start_time')
+        self.answer_text = details.get('result')
+
+    def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
+        form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
+                        "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
+                        "is_submit": self.context.get("is_submit", False)}
+        form = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
+        prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
+        value = prompt_template.format(form=form)
+        return NodeResult(
+            {'result': value, 'form_field_list': form_field_list, 'form_content_format': form_content_format}, {},
+            _write_context=write_context)
+
+    def get_answer_text(self):
+        form_content_format = self.context.get('form_content_format')
+        form_field_list = self.context.get('form_field_list')
+        form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
+                        "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
+                        'form_data': self.context.get('form_data', {}),
+                        "is_submit": self.context.get("is_submit", False)}
+        form = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
+        prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
+        value = prompt_template.format(form=form)
+        return value
+
+    def get_details(self, index: int, **kwargs):
+        form_content_format = self.context.get('form_content_format')
+        form_field_list = self.context.get('form_field_list')
+        form_setting = {"form_field_list": form_field_list, "runtime_node_id": self.runtime_node_id,
+                        "chat_record_id": self.flow_params_serializer.data.get("chat_record_id"),
+                        'form_data': self.context.get('form_data', {}),
+                        "is_submit": self.context.get("is_submit", False)}
+        form = f'<form_rander>{json.dumps(form_setting)}</form_rander>'
+        prompt_template = PromptTemplate.from_template(form_content_format, template_format='jinja2')
+        value = prompt_template.format(form=form)
+        return {
+            'name': self.node.properties.get('stepName'),
+            "index": index,
+            "result": value,
+            "form_content_format": self.context.get('form_content_format'),
+            "form_field_list": self.context.get('form_field_list'),
'form_data': self.context.get('form_data'), + 'start_time': self.context.get('start_time'), + 'run_time': self.context.get('run_time'), + 'type': self.node.type, + 'status': self.status, + 'err_message': self.err_message + } diff --git a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py index 64e1c556eb9..273b84d9786 100644 --- a/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py +++ b/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py @@ -91,6 +91,9 @@ def convert_value(name: str, value, _type, is_required, source, node): class BaseFunctionLibNodeNode(IFunctionLibNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.answer_text = details.get('result') def execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult: function_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first() if not function_lib.is_active: diff --git a/apps/application/flow/step_node/function_node/impl/base_function_node.py b/apps/application/flow/step_node/function_node/impl/base_function_node.py index f2aead83fa8..3336b308ac5 100644 --- a/apps/application/flow/step_node/function_node/impl/base_function_node.py +++ b/apps/application/flow/step_node/function_node/impl/base_function_node.py @@ -78,6 +78,10 @@ def convert_value(name: str, value, _type, is_required, source, node): class BaseFunctionNodeNode(IFunctionNode): + def save_context(self, details, workflow_manage): + self.context['result'] = details.get('result') + self.answer_text = details.get('result') + def execute(self, input_field_list, code, **kwargs) -> NodeResult: params = {field.get('name'): convert_value(field.get('name'), field.get('value'), field.get('type'), field.get('is_required'), field.get('source'), self) diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py index 9731fbd80cc..e6a88a9f157 100644 --- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py +++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py @@ -58,6 +58,11 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor class BaseImageUnderstandNode(IImageUnderstandNode): + def save_context(self, details, workflow_manage): + self.context['answer'] = details.get('answer') + self.context['question'] = details.get('question') + self.answer_text = details.get('answer') + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, image, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/question_node/impl/base_question_node.py b/apps/application/flow/step_node/question_node/impl/base_question_node.py index 8e43a9b8e4d..f69cb912cb7 100644 --- a/apps/application/flow/step_node/question_node/impl/base_question_node.py +++ b/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -32,7 +32,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo node.context['question'] = node_variable['question'] node.context['run_time'] = time.time() - node.context['start_time'] if workflow.is_result(node, NodeResult(node_variable, workflow_variable)): - workflow.answer += answer 
+ node.answer_text = answer def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): @@ -73,6 +73,14 @@ def get_default_model_params_setting(model_id): class BaseQuestionNode(IQuestionNode): + def save_context(self, details, workflow_manage): + self.context['run_time'] = details.get('run_time') + self.context['question'] = details.get('question') + self.context['answer'] = details.get('answer') + self.context['message_tokens'] = details.get('message_tokens') + self.context['answer_tokens'] = details.get('answer_tokens') + self.answer_text = details.get('answer') + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, model_params_setting=None, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py index d1eef33d40b..59dba2ff953 100644 --- a/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py +++ b/apps/application/flow/step_node/reranker_node/impl/base_reranker_node.py @@ -44,6 +44,13 @@ def filter_result(document_list: List[Document], max_paragraph_char_number, top_ class BaseRerankerNode(IRerankerNode): + def save_context(self, details, workflow_manage): + self.context['document_list'] = details.get('document_list', []) + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.context['result_list'] = details.get('result_list') + self.context['result'] = details.get('result') + def execute(self, question, reranker_setting, reranker_list, reranker_model_id, **kwargs) -> NodeResult: documents = merge_reranker_list(reranker_list) diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 693495a6a78..84af258bedf 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -45,6 +45,21 @@ def reset_title(title): class BaseSearchDatasetNode(ISearchDatasetStepNode): + def save_context(self, details, workflow_manage): + result = details.get('paragraph_list', []) + dataset_setting = self.node_params_serializer.data.get('dataset_setting') + directly_return = '\n'.join( + [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in result if + paragraph.get('is_hit_handling_method')]) + self.context['paragraph_list'] = result + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.context['is_hit_handling_method_list'] = [row for row in result if row.get('is_hit_handling_method')] + self.context['data'] = '\n'.join( + [f"{paragraph.get('title', '')}:{paragraph.get('content')}" for paragraph in + result])[0:dataset_setting.get('max_paragraph_char_number', 5000)] + self.context['directly_return'] = directly_return + def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, **kwargs) -> NodeResult: diff --git a/apps/application/flow/step_node/start_node/i_start_node.py b/apps/application/flow/step_node/start_node/i_start_node.py index bb23ad3f53e..41d73f21811 100644 --- a/apps/application/flow/step_node/start_node/i_start_node.py +++ b/apps/application/flow/step_node/start_node/i_start_node.py @@ -6,9 +6,6 @@ @date:2024/6/3 16:54 
@desc: """ -from typing import Type - -from rest_framework import serializers from application.flow.i_step_node import INode, NodeResult diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index 186623123d3..9ac7e2aefed 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -31,6 +31,17 @@ def get_global_variable(node): class BaseStartStepNode(IStarNode): + def save_context(self, details, workflow_manage): + base_node = self.workflow_manage.get_base_node() + default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list', [])) + workflow_variable = {**default_global_variable, **get_global_variable(self)} + self.context['question'] = details.get('question') + self.context['run_time'] = details.get('run_time') + self.status = details.get('status') + self.err_message = details.get('err_message') + for key, value in workflow_variable.items(): + workflow_manage.context[key] = value + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: pass diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 5de8e0ddc17..39c42aa2449 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -9,7 +9,6 @@ import json import threading import traceback -import uuid from concurrent.futures import ThreadPoolExecutor from functools import reduce from typing import List, Dict @@ -212,12 +211,12 @@ def pop(self): except IndexError as e: if self.current_node_chunk.is_end(): self.current_node_chunk = None - if len(self.work_flow.answer) > 0: + if self.work_flow.answer_is_not_empty(): chunk = self.work_flow.base_to_response.to_stream_chunk_response( self.work_flow.params['chat_id'], self.work_flow.params['chat_record_id'], '\n\n', False, 0, 0) - self.work_flow.answer += '\n\n' + self.work_flow.append_answer('\n\n') return chunk return self.pop() return None @@ -240,29 +239,65 @@ def is_end(self): class WorkflowManage: def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler, - base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None): + base_to_response: BaseToResponse = SystemToResponse(), form_data=None, image_list=None, + start_node_id=None, + start_node_data=None, chat_record=None): if form_data is None: form_data = {} if image_list is None: image_list = [] + self.start_node = None + self.start_node_result_future = None self.form_data = form_data self.image_list = image_list self.params = params self.flow = flow self.lock = threading.Lock() self.context = {} - self.node_context = [] self.node_chunk_manage = NodeChunkManage(self) self.work_flow_post_handler = work_flow_post_handler self.current_node = None self.current_result = None self.answer = "" + self.answer_list = [''] self.status = 0 self.base_to_response = base_to_response + self.chat_record = chat_record + if start_node_id is not None: + self.load_node(chat_record, start_node_id, start_node_data) + else: + self.node_context = [] + + def append_answer(self, content): + self.answer += content + self.answer_list[-1] += content + + def answer_is_not_empty(self): + return len(self.answer_list[-1]) > 0 + + def load_node(self, chat_record, start_node_id, start_node_data): + self.node_context = [] + self.answer = chat_record.answer_text + self.answer_list = 
chat_record.answer_text_list + self.answer_list.append('') + for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')): + node_id = node_details.get('node_id') + if node_details.get('runtime_node_id') == start_node_id: + self.start_node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id')) + self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params) + self.start_node.save_context(node_details, self) + node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {}) + self.start_node_result_future = NodeResultFuture(node_result, None) + return + node_id = node_details.get('node_id') + node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id')) + node.valid_args(node.node_params, node.workflow_params) + node.save_context(node_details, self) + self.node_context.append(node) def run(self): if self.params.get('stream'): - return self.run_stream() + return self.run_stream(self.start_node, self.start_node_result_future) return self.run_block() def run_block(self): @@ -270,7 +305,7 @@ def run_block(self): 非流式响应 @return: 结果 """ - result = self.run_chain_async(None) + result = self.run_chain_async(None, None) result.result() details = self.get_runtime_details() message_tokens = sum([row.get('message_tokens') for row in details.values() if @@ -285,12 +320,12 @@ def run_block(self): , message_tokens, answer_tokens, _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR) - def run_stream(self): + def run_stream(self, current_node, node_result_future): """ 流式响应 @return: """ - result = self.run_chain_async(None) + result = self.run_chain_async(current_node, node_result_future) return tools.to_stream_response_simple(self.await_result(result)) def await_result(self, result): @@ -307,21 +342,23 @@ def await_result(self, result): if chunk is None: break yield chunk + yield self.get_chunk_content('', True) finally: self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], self.answer, self) yield self.get_chunk_content('', True) - def run_chain_async(self, current_node): - future = executor.submit(self.run_chain, current_node) + def run_chain_async(self, current_node, node_result_future): + future = executor.submit(self.run_chain, current_node, node_result_future) return future - def run_chain(self, current_node): + def run_chain(self, current_node, node_result_future=None): if current_node is None: start_node = self.get_start_node() current_node = get_node(start_node.type)(start_node, self.params, self) - node_result_future = self.run_node_future(current_node) + if node_result_future is None: + node_result_future = self.run_node_future(current_node) try: is_stream = self.params.get('stream', True) # 处理节点响应 @@ -335,7 +372,7 @@ def run_chain(self, current_node): # 获取到可执行的子节点 result_list = [] for node in node_list: - result = self.run_chain_async(node) + result = self.run_chain_async(node, None) result_list.append(result) [r.result() for r in result_list] if self.status == 0: @@ -445,10 +482,41 @@ def get_runtime_details(self): details_result = {} for index in range(len(self.node_context)): node = self.node_context[index] + if self.chat_record is not None and self.chat_record.details is not None: + details = self.chat_record.details.get(node.runtime_node_id) + if details is not None and self.start_node.runtime_node_id != node.runtime_node_id: + details_result[node.runtime_node_id] = details + continue details = 
node.get_details(index)
-            details_result[str(uuid.uuid1())] = details
+            details['node_id'] = node.id
+            details['runtime_node_id'] = node.runtime_node_id
+            details_result[node.runtime_node_id] = details
         return details_result
 
+    def get_answer_text_list(self):
+        answer_text_list = []
+        for index in range(len(self.node_context)):
+            node = self.node_context[index]
+            answer_text = node.get_answer_text()
+            if answer_text is not None:
+                if self.chat_record is not None and self.chat_record.details is not None:
+                    details = self.chat_record.details.get(node.runtime_node_id)
+                    if details is not None and self.start_node.runtime_node_id != node.runtime_node_id:
+                        continue
+                answer_text_list.append(
+                    {'content': answer_text, 'type': 'form' if node.type == 'form-node' else 'md'})
+        result = []
+        for index in range(len(answer_text_list)):
+            answer = answer_text_list[index]
+            if index == 0:
+                result.append(answer.get('content'))
+                continue
+            if answer.get('type') != answer_text_list[index - 1].get('type'):
+                result.append(answer.get('content'))
+            else:
+                result[-1] += answer.get('content')
+        return result
+
     def get_next_node(self):
         """
         获取下一个可运行的所有节点
@@ -485,6 +553,8 @@ def get_next_node_list(self, current_node, current_node_result):
         @param current_node_result: 当前可执行节点结果
         @return: 可执行节点列表
         """
+        if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable:
+            return []
         node_list = []
         if current_node_result is not None and current_node_result.is_assertion_result():
             for edge in self.flow.edges:
@@ -537,7 +607,6 @@ def generate_prompt(self, prompt: str):
             prompt = prompt.replace(globeLabel, globeValue)
             context[node.id] = node.context
         prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
-
         value = prompt_template.format(context=context)
         return value
 
@@ -557,11 +626,11 @@ def get_base_node(self):
         base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
         return base_node_list[0]
 
-    def get_node_cls_by_id(self, node_id):
+    def get_node_cls_by_id(self, node_id, runtime_node_id=None):
         for node in self.flow.nodes:
             if node.id == node_id:
                 node_instance = get_node(node.type)(node,
-                                                    self.params, self)
+                                                    self.params, self, runtime_node_id)
                 return node_instance
         return None
diff --git a/apps/application/migrations/0019_application_file_upload_enable_and_more.py b/apps/application/migrations/0019_application_file_upload_enable_and_more.py
index 0f934d6bf5f..f59a4990c2f 100644
--- a/apps/application/migrations/0019_application_file_upload_enable_and_more.py
+++ b/apps/application/migrations/0019_application_file_upload_enable_and_more.py
@@ -1,10 +1,15 @@
-# Generated by Django 4.2.15 on 2024-11-07 11:22
+# Generated by Django 4.2.15 on 2024-11-13 10:13
 
+import django.contrib.postgres.fields
 from django.db import migrations, models
 
+sql = """
+UPDATE "public".application_chat_record
+SET "answer_text_list" = ARRAY[answer_text];
+"""
 
-class Migration(migrations.Migration):
+class Migration(migrations.Migration):
 
     dependencies = [
         ('application', '0018_workflowversion_name'),
     ]
@@ -20,4 +25,11 @@ class Migration(migrations.Migration):
             name='file_upload_setting',
             field=models.JSONField(default={}, verbose_name='文件上传相关设置'),
         ),
+        migrations.AddField(
+            model_name='chatrecord',
+            name='answer_text_list',
+            field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=40960), default=list,
+                                                             size=None, verbose_name='改进标注列表'),
+        ),
+        migrations.RunSQL(sql)
     ]
diff --git a/apps/application/models/application.py
b/apps/application/models/application.py index 8df9ac19032..5df928c4dab 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -147,6 +147,9 @@ class ChatRecord(AppModelMixin): default=VoteChoices.UN_VOTE) problem_text = models.CharField(max_length=10240, verbose_name="问题") answer_text = models.CharField(max_length=40960, verbose_name="答案") + answer_text_list = ArrayField(verbose_name="改进标注列表", + base_field=models.CharField(max_length=40960) + , default=list) message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) const = models.IntegerField(verbose_name="总费用", default=0) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index e0576b247b9..b14e9f33e3f 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -214,10 +214,15 @@ def chat(self, instance: Dict, with_valid=True): class ChatMessageSerializer(serializers.Serializer): - chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id")) + chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题")) stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答")) + chat_record_id = serializers.UUIDField(required=False, allow_null=True, + error_messages=ErrMessage.uuid("对话记录id")) + runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("节点id")) + node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型")) @@ -293,6 +298,19 @@ def chat_simple(self, chat_info: ChatInfo, base_to_response): pipeline_message.run(params) return pipeline_message.context['chat_result'] + @staticmethod + def get_chat_record(chat_info, chat_record_id): + if chat_info is not None: + chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if + str(chat_record.id) == str(chat_record_id)] + if chat_record_list is not None and len(chat_record_list): + return chat_record_list[-1] + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first() + if chat_record is None: + raise ChatException(500, "对话纪要不存在") + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first() + return chat_record + def chat_work_flow(self, chat_info: ChatInfo, base_to_response): message = self.data.get('message') re_chat = self.data.get('re_chat') @@ -302,15 +320,21 @@ def chat_work_flow(self, chat_info: ChatInfo, base_to_response): form_data = self.data.get('form_data') image_list = self.data.get('image_list') user_id = chat_info.application.user_id + chat_record_id = self.data.get('chat_record_id') + chat_record = None + if chat_record_id is not None: + chat_record = self.get_chat_record(chat_info, chat_record_id) work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), 
{'history_chat_record': chat_info.chat_record_list, 'question': message, - 'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()), + 'chat_id': chat_info.chat_id, 'chat_record_id': str( + uuid.uuid1()) if chat_record is None else chat_record.id, 'stream': stream, 're_chat': re_chat, 'client_id': client_id, 'client_type': client_type, 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), - base_to_response, form_data, image_list) + base_to_response, form_data, image_list, self.data.get('runtime_node_id'), + self.data.get('node_data'), chat_record) r = work_flow_manage.run() return r diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 40a16af2fae..ad25ac7d915 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -397,7 +397,7 @@ class ChatRecordSerializerModel(serializers.ModelSerializer): class Meta: model = ChatRecord fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text', - 'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index', + 'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index','answer_text_list', 'create_time', 'update_time'] diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 4b60e4bcb28..586787b203a 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -129,9 +129,14 @@ def post(self, request: Request, chat_id: str): 'client_id': request.auth.client_id, 'form_data': (request.data.get( 'form_data') if 'form_data' in request.data else {}), + 'image_list': request.data.get( 'image_list') if 'image_list' in request.data else [], - 'client_type': request.auth.client_type}).chat() + 'client_type': request.auth.client_type, + 'runtime_node_id': request.data.get('runtime_node_id', None), + 'node_data': request.data.get('node_data', {}), + 'chat_record_id': request.data.get('chat_record_id')} + ).chat() @action(methods=['GET'], detail=False) @swagger_auto_schema(operation_summary="获取对话列表", diff --git a/apps/dataset/migrations/0010_file_meta.py b/apps/dataset/migrations/0010_file_meta.py index f227554ba26..6e28e3eecc3 100644 --- a/apps/dataset/migrations/0010_file_meta.py +++ b/apps/dataset/migrations/0010_file_meta.py @@ -13,6 +13,6 @@ class Migration(migrations.Migration): migrations.AddField( model_name='file', name='meta', - field=models.JSONField(default={}, verbose_name='文件关联数据'), + field=models.JSONField(default=dict, verbose_name='文件关联数据'), ), ] diff --git a/ui/env.d.ts b/ui/env.d.ts index 52f54527078..be1f8389d30 100644 --- a/ui/env.d.ts +++ b/ui/env.d.ts @@ -8,7 +8,10 @@ declare module 'markdown-it-sub' declare module 'markdown-it-sup' declare module 'markdown-it-toc-done-right' declare module 'katex' +interface Window { + sendMessage: ?((message: string, other_params_data: any) => void) +} interface ImportMeta { readonly env: ImportMetaEnv } -declare type Recordable = Record; +declare type Recordable = Record diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 00ad179fe9e..5d4971f7fa4 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -27,6 +27,7 @@ interface chatType { problem_text: string answer_text: string buffer: Array + answer_text_list: Array /** * 是否写入结束 */ @@ -36,6 +37,7 @@ interface chatType { */ is_stop?: boolean record_id: string + chat_id: string vote_status: 
string status?: number } @@ -56,18 +58,25 @@ export class ChatRecordManage { this.is_close = false this.write_ed = false } + append_answer(chunk_answer: String) { + this.chat.answer_text_list[this.chat.answer_text_list.length - 1] = + this.chat.answer_text_list[this.chat.answer_text_list.length - 1] + chunk_answer + this.chat.answer_text = this.chat.answer_text + chunk_answer + } write() { this.chat.is_stop = false this.is_stop = false + this.is_close = false + this.write_ed = false + this.chat.write_ed = false if (this.loading) { this.loading.value = true } this.id = setInterval(() => { if (this.chat.buffer.length > 20) { - this.chat.answer_text = - this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('') + this.append_answer(this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('')) } else if (this.is_close) { - this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('') + this.append_answer(this.chat.buffer.splice(0).join('')) this.chat.write_ed = true this.write_ed = true if (this.loading) { @@ -79,7 +88,7 @@ export class ChatRecordManage { } else { const s = this.chat.buffer.shift() if (s !== undefined) { - this.chat.answer_text = this.chat.answer_text + s + this.append_answer(s) } } }, this.ms) @@ -95,6 +104,10 @@ export class ChatRecordManage { close() { this.is_close = true } + open() { + this.is_close = false + this.is_stop = false + } append(answer_text_block: string) { for (let index = 0; index < answer_text_block.length; index++) { this.chat.buffer.push(answer_text_block[index]) diff --git a/ui/src/assets/icon_form.svg b/ui/src/assets/icon_form.svg new file mode 100644 index 00000000000..22a10210da3 --- /dev/null +++ b/ui/src/assets/icon_form.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/ui/src/components/ai-chat/component/answer-content/index.vue b/ui/src/components/ai-chat/component/answer-content/index.vue new file mode 100644 index 00000000000..e8e761b55ee --- /dev/null +++ b/ui/src/components/ai-chat/component/answer-content/index.vue @@ -0,0 +1,94 @@ + + + + + + + + + + + + 已停止回答 + + + 回答中 + + + + + + + + + + + + + diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue new file mode 100644 index 00000000000..abb0bfbe095 --- /dev/null +++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue @@ -0,0 +1,392 @@ + + + + + + + + + + + + + + + + + + + + + + + + 00:{{ recorderTime < 10 ? 
`0${recorderTime}` : recorderTime }} + + + + + + + + + + + + + + + + + {{ applicationDetails.disclaimer_value }} + + + + + + + diff --git a/ui/src/components/ai-chat/OperationButton.vue b/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue similarity index 95% rename from ui/src/components/ai-chat/OperationButton.vue rename to ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue index 6a125429659..198ce3ed1e2 100644 --- a/ui/src/components/ai-chat/OperationButton.vue +++ b/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue @@ -19,7 +19,7 @@ - + @@ -90,26 +90,21 @@ const { params: { id } } = route as any -const props = defineProps({ - data: { - type: Object, - default: () => {} - }, - applicationId: { - type: String, - default: '' - }, - chatId: { - type: String, - default: '' - }, - chat_loading: { - type: Boolean - }, - log: Boolean, - tts: Boolean, - tts_type: String -}) +const props = withDefaults( + defineProps<{ + data: any + type: 'log' | 'ai-chat' | 'debug-ai-chat' + chatId: string + chat_loading: boolean + applicationId: string + tts: boolean + tts_type: string + }>(), + { + data: () => ({}), + type: 'ai-chat' + } +) const emit = defineEmits(['update:data', 'regeneration']) diff --git a/ui/src/components/ai-chat/LogOperationButton.vue b/ui/src/components/ai-chat/component/operation-button/LogOperationButton.vue similarity index 68% rename from ui/src/components/ai-chat/LogOperationButton.vue rename to ui/src/components/ai-chat/component/operation-button/LogOperationButton.vue index 10dd47165df..96ccaf756ae 100644 --- a/ui/src/components/ai-chat/LogOperationButton.vue +++ b/ui/src/components/ai-chat/component/operation-button/LogOperationButton.vue @@ -1,59 +1,61 @@ - - - {{ datetimeFormat(data.create_time) }} - - - - - - - - + + + + {{ datetimeFormat(data.create_time) }} + + + + + + + + + + + + + + + + + + + + - - - + + + + - - - - - - - - - - - + + + + + + + + + + - - - - + + - - - - - - - - - - - - - - + + + + + + diff --git a/ui/src/components/ai-chat/component/prologue-content/index.vue b/ui/src/components/ai-chat/component/prologue-content/index.vue new file mode 100644 index 00000000000..e8df17898a5 --- /dev/null +++ b/ui/src/components/ai-chat/component/prologue-content/index.vue @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + diff --git a/ui/src/components/ai-chat/component/question-content/index.vue b/ui/src/components/ai-chat/component/question-content/index.vue new file mode 100644 index 00000000000..a9d202e4f27 --- /dev/null +++ b/ui/src/components/ai-chat/component/question-content/index.vue @@ -0,0 +1,81 @@ + + + + + + + + + + + + {{ chatRecord.problem_text }} + + + + + + diff --git a/ui/src/components/ai-chat/component/user-form/index.vue b/ui/src/components/ai-chat/component/user-form/index.vue new file mode 100644 index 00000000000..f64d08b7376 --- /dev/null +++ b/ui/src/components/ai-chat/component/user-form/index.vue @@ -0,0 +1,304 @@ + + + + + + 用户输入 + + + + + + + + + + + + diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 38bda779f7f..0ea8dd18986 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -1,368 +1,100 @@ - - - - - - 用户输入 - - - - - - - - - + + - - - - - - - - - - - - - - {{ item.str }} - - - - - - + - - - - - - - - - - {{ item.problem_text }} - - - + - - - - - - - - - - - - - - - - - 已停止回答 - - - 回答中 - - - - - - - - - - - - - - - - - 继续 - - 停止回答 - - - - - - - - + - - - - - - - - - - - - - - - - 
- - - - - - 00:{{ recorderTime < 10 ? `0${recorderTime}` : recorderTime }} - - - - - - - - - - - - - - - - {{ data.disclaimer_value }} - - - - + + diff --git a/ui/src/components/markdown/MdRenderer.vue b/ui/src/components/markdown/MdRenderer.vue index 517a3bc5683..f8faf3b2df1 100644 --- a/ui/src/components/markdown/MdRenderer.vue +++ b/ui/src/components/markdown/MdRenderer.vue @@ -2,9 +2,9 @@ {}" + @click="sendMessage ? sendMessage(item.content, 'new') : (content: string) => {}" class="problem-button ellipsis-2 mb-8" - :class="quickProblemHandle ? 'cursor' : 'disabled'" + :class="sendMessage ? 'cursor' : 'disabled'" > {{ item.content }} @@ -14,6 +14,11 @@ v-else-if="item.type === 'echarts_rander'" :option="item.content" > + { @@ -54,7 +60,7 @@ const props = withDefaults( defineProps<{ source?: string inner_suffix?: boolean - quickProblemHandle?: (q: string) => void + sendMessage?: (question: string, type: 'old' | 'new', other_params_data?: any) => void }>(), { source: '' @@ -63,7 +69,9 @@ const props = withDefaults( const editorRef = ref() const md_view_list = computed(() => { const temp_source = props.source - return split_echarts_rander(split_html_rander(split_quick_question([temp_source]))) + return split_form_rander( + split_echarts_rander(split_html_rander(split_quick_question([temp_source]))) + ) }) const split_quick_question = (result: Array) => { @@ -168,6 +176,41 @@ const split_echarts_rander_ = (source: string, type: string) => { }) return result } + +const split_form_rander = (result: Array) => { + return result + .map((item) => split_form_rander_(item.content, item.type)) + .reduce((x: any, y: any) => { + return [...x, ...y] + }, []) +} + +const split_form_rander_ = (source: string, type: string) => { + const temp_md_quick_question_list = source.match(/[\d\D]*?<\/form_rander>/g) + const md_quick_question_list = temp_md_quick_question_list + ? 
temp_md_quick_question_list.filter((i) => i) + : [] + const split_quick_question_value = source + .split(/[\d\D]*?<\/form_rander>/g) + .filter((item) => item !== undefined) + .filter((item) => !md_quick_question_list?.includes(item)) + const result = Array.from( + { length: md_quick_question_list.length + split_quick_question_value.length }, + (v, i) => i + ).map((index) => { + if (index % 2 == 0) { + return { type: type, content: split_quick_question_value[Math.floor(index / 2)] } + } else { + return { + type: 'form_rander', + content: md_quick_question_list[Math.floor(index / 2)] + .replace('', '') + .replace('', '') + } + } + }) + return result +} diff --git a/ui/src/workflow/common/EditFormCollect.vue b/ui/src/workflow/common/EditFormCollect.vue new file mode 100644 index 00000000000..3413a565105 --- /dev/null +++ b/ui/src/workflow/common/EditFormCollect.vue @@ -0,0 +1,55 @@ + + + + + + 取消 + 修改 + + + + + + diff --git a/ui/src/workflow/common/data.ts b/ui/src/workflow/common/data.ts index 97495cbd188..29a202a114a 100644 --- a/ui/src/workflow/common/data.ts +++ b/ui/src/workflow/common/data.ts @@ -168,6 +168,31 @@ export const rerankerNode = { } } } +export const formNode = { + type: WorkflowType.FormNode, + text: '在问答过程中用于收集用户信息,可以根据收集到表单数据执行后续流程', + label: '表单收集', + height: 252, + properties: { + width: 600, + stepName: '表单收集', + node_data: { + is_result: true, + form_field_list: [], + form_content_format: `你好,请先填写下面表单内容: +{{form}} +填写后请点击【提交】按钮进行提交。` + }, + config: { + fields: [ + { + label: '表单全部内容', + value: 'form_data' + } + ] + } + } +} export const documentExtractNode = { type: WorkflowType.DocumentExtractNode, text: '提取文档中的内容', @@ -180,7 +205,7 @@ export const documentExtractNode = { { label: '文件内容', value: 'content' - }, + } ] } } @@ -197,7 +222,7 @@ export const imageUnderstandNode = { { label: 'AI 回答内容', value: 'content' - }, + } ] } } @@ -208,9 +233,7 @@ export const menuNodes = [ questionNode, conditionNode, replyNode, - rerankerNode, - documentExtractNode, - imageUnderstandNode + rerankerNode ] /** @@ -297,9 +320,10 @@ export const nodeDict: any = { [WorkflowType.FunctionLib]: functionLibNode, [WorkflowType.FunctionLibCustom]: functionNode, [WorkflowType.RrerankerNode]: rerankerNode, + [WorkflowType.FormNode]: formNode, [WorkflowType.Application]: applicationNode, [WorkflowType.DocumentExtractNode]: documentExtractNode, - [WorkflowType.ImageUnderstandNode]: imageUnderstandNode, + [WorkflowType.ImageUnderstandNode]: imageUnderstandNode } export function isWorkFlow(type: string | undefined) { return type === 'WORK_FLOW' diff --git a/ui/src/workflow/icons/form-node-icon.vue b/ui/src/workflow/icons/form-node-icon.vue new file mode 100644 index 00000000000..40f6ea77fac --- /dev/null +++ b/ui/src/workflow/icons/form-node-icon.vue @@ -0,0 +1,6 @@ + + + + + + diff --git a/ui/src/workflow/nodes/form-node/index.ts b/ui/src/workflow/nodes/form-node/index.ts new file mode 100644 index 00000000000..b7c2dbec9fa --- /dev/null +++ b/ui/src/workflow/nodes/form-node/index.ts @@ -0,0 +1,12 @@ +import FormNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class FormNode extends AppNode { + constructor(props: any) { + super(props, FormNodeVue) + } +} +export default { + type: 'form-node', + model: AppNodeModel, + view: FormNode +} diff --git a/ui/src/workflow/nodes/form-node/index.vue b/ui/src/workflow/nodes/form-node/index.vue new file mode 100644 index 00000000000..0b52ccd854d --- /dev/null +++ b/ui/src/workflow/nodes/form-node/index.vue @@ -0,0 
+1,192 @@ + + + + + + + + 表单输出内容* + + + + `设置执行该节点输出的内容,{{ '{ from }' }}为表单的占位符。` + + + + + + + + + + {{ '接口传参' }} + + + + + 添加 + + + + + + + {{ + row.label.label + }} + {{ row.label }} + + + + + {{ + input_type_list.find((item) => item.value === row.input_type)?.label + }} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +