diff --git a/README.md b/README.md index cda4a40bb..16138d0cf 100644 --- a/README.md +++ b/README.md @@ -56,8 +56,7 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode ## Demo Run on an RTX 4090 GPU. ##### Chat Excel -![cx_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/659259dc-c3ba-41c8-8bc3-179cd4385dbe) -![chatecl_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/2ebfdee2-2262-4d32-8933-4fb27f969180) +![excel](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/0474d220-2a9f-449f-a940-92c8a25af390) ##### Chat Plugin ![auto_plugin_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/7d95c347-f4b7-4fb6-8dd2-c1c02babaa56) ##### LLM Management diff --git a/README.zh.md b/README.zh.md index 301dd2e6d..3832fc30c 100644 --- a/README.zh.md +++ b/README.zh.md @@ -58,8 +58,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地 示例通过 RTX 4090 GPU 演示 ##### Chat Excel -![cx_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/659259dc-c3ba-41c8-8bc3-179cd4385dbe) -![chatecl_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/2ebfdee2-2262-4d32-8933-4fb27f969180) +![excel](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/0474d220-2a9f-449f-a940-92c8a25af390) ##### Chat Plugin ![auto_plugin_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/7d95c347-f4b7-4fb6-8dd2-c1c02babaa56) ##### LLM Management diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py index be9e02811..4fbaa6dc9 100644 --- a/pilot/base_modules/agent/commands/command_mange.py +++ b/pilot/base_modules/agent/commands/command_mange.py @@ -227,7 +227,7 @@ def __is_need_wait_plugin_call(self, api_call_context): i += 1 return False - def __check_last_plugin_call_ready(self, all_context): + def check_last_plugin_call_ready(self, all_context): start_agent_count = all_context.count(self.agent_prefix) end_agent_count = all_context.count(self.agent_end) @@ -359,7 +359,7 @@ def to_view_text(self, api_status: PluginStatus): def run(self, llm_text): if self.__is_need_wait_plugin_call(llm_text): # wait api call generate complete - if self.__check_last_plugin_call_ready(llm_text): + if self.check_last_plugin_call_ready(llm_text): self.update_from_context(llm_text) for key, value in self.plugin_status_map.items(): if value.status == Status.TODO.value: @@ -379,7 +379,7 @@ def run(self, llm_text): def run_display_sql(self, llm_text, sql_run_func): if self.__is_need_wait_plugin_call(llm_text): # wait api call generate complete - if self.__check_last_plugin_call_ready(llm_text): + if self.check_last_plugin_call_ready(llm_text): self.update_from_context(llm_text) for key, value in self.plugin_status_map.items(): if value.status == Status.TODO.value: diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py index e41998942..e7637e582 100644 --- a/pilot/openapi/api_v1/editor/api_editor_v1.py +++ b/pilot/openapi/api_v1/editor/api_editor_v1.py @@ -29,6 +29,8 @@ from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader from pilot.scene.chat_db.data_loader import DbDataLoader from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory +from pilot.base_modules.agent.commands.command_mange import ApiCall + router = APIRouter() CFG = Config() @@ -101,12 +103,19 @@ async def get_editor_sql(con_uid: str, round: int): logger.info( f'history ai json resp:{element["data"]["content"]}' ) - context = ( + api_call = ApiCall() + result = {} + result["thoughts"] = element["data"]["content"] + if api_call.check_last_plugin_call_ready( element["data"]["content"] - .replace("\\n", " ") - .replace("\n", " ") - ) - return Result.succ(json.loads(context)) + ): + api_call.update_from_context(element["data"]["content"]) + if len(api_call.plugin_status_map) > 0: + first_item = next( + iter(api_call.plugin_status_map.items()) + )[1] + result["sql"] = first_item.args["sql"] + return Result.succ(result) return Result.faild(msg="not have sql!") @@ -156,17 +165,22 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): ) )[0] if edit_round: + new_ai_text = "" for element in edit_round["messages"]: if element["type"] == "ai": - db_resp = json.loads(element["data"]["content"]) - db_resp["thoughts"] = sql_edit_context.new_speak - db_resp["sql"] = sql_edit_context.new_sql - element["data"]["content"] = json.dumps(db_resp) + new_ai_text = element["data"]["content"] + new_ai_text.replace( + sql_edit_context.old_sql, sql_edit_context.new_sql + ) + element["data"]["content"] = new_ai_text + + for element in edit_round["messages"]: if element["type"] == "view": - data_loader = DbDataLoader() - element["data"]["content"] = data_loader.get_table_view_by_conn( - conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak + api_call = ApiCall() + new_view_text = api_call.run_display_sql( + new_ai_text, conn.run_to_df ) + element["data"]["content"] = new_view_text history_mem.update(history_messages) return Result.succ(None) return Result.faild(msg="Edit Faild!") diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index f92df7a3a..b05480cdf 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -5,6 +5,7 @@ from pilot.common.sql_database import Database from pilot.configs.config import Config from pilot.scene.chat_db.auto_execute.prompt import prompt +from pilot.base_modules.agent.commands.command_mange import ApiCall CFG = Config() @@ -37,6 +38,7 @@ def __init__(self, chat_param: Dict): self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name) self.top_k: int = 200 + self.api_call = ApiCall(display_registry=CFG.command_disply) def generate_input_values(self): """ @@ -69,6 +71,12 @@ def generate_input_values(self): } return input_values - def do_action(self, prompt_response): - print(f"do_action:{prompt_response}") - return self.database.run(prompt_response.sql) + def stream_plugin_call(self, text): + text = text.replace("\n", " ") + print(f"stream_plugin_call:{text}") + return self.api_call.run_display_sql(text, self.database.run_to_df) + + # + # def do_action(self, prompt_response): + # print(f"do_action:{prompt_response}") + # return self.database.run(prompt_response.sql) diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py index abc889cec..d9b67af39 100644 --- a/pilot/scene/chat_db/auto_execute/prompt.py +++ b/pilot/scene/chat_db/auto_execute/prompt.py @@ -8,24 +8,51 @@ CFG = Config() -PROMPT_SCENE_DEFINE = "You are a SQL expert. " -_DEFAULT_TEMPLATE = """ +_PROMPT_SCENE_DEFINE_EN = "You are a database expert. " +_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. " + +_DEFAULT_TEMPLATE_EN = """ Given an input question, create a syntactically correct {dialect} sql. +Table structure information: + {table_info} +Constraint: +1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will. +2. Do not query columns that do not exist. Pay attention to which column is in which table. +3. Replace the corresponding sql into the sql field in the returned result +4. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results. +5. Please output the Sql content in the following format to execute the corresponding SQL to display the data:response_tableSQL Query to run +Please make sure to respond as following format: + thoughts summary to say to user.response_tableSQL Query to run + +Question: {input} +""" -Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. -Use as few tables as possible when querying. -Only use the following tables schema to generate sql: -{table_info} -Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. +_DEFAULT_TEMPLATE_ZH = """ +给定一个输入问题,创建一个语法正确的 {dialect} sql。 +已知表结构信息: + {table_info} -Question: {input} +约束: +1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。 +2. 不要查询不存在的列,注意哪一列位于哪张表中。 +3.将对应的sql替换到返回结果中的sql字段中 +4.除非用户在问题中指定了他希望获得的具体示例数量,否则始终将查询限制为最多 {top_k} 个结果。 -Respond in JSON format as following format: -{response} -Ensure the response is correct json and can be parsed by Python json.loads +请务必按照以下格式回复: + 对用户说的想法摘要。response_table要运行的 SQL + +问题:{input} """ +_DEFAULT_TEMPLATE = ( + _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH +) + +PROMPT_SCENE_DEFINE = ( + _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH +) + RESPONSE_FORMAT_SIMPLE = { "thoughts": "thoughts summary to say to user", "sql": "SQL Query to run", @@ -33,7 +60,7 @@ PROMPT_SEP = SeparatorStyle.SINGLE.value -PROMPT_NEED_NEED_STREAM_OUT = False +PROMPT_NEED_NEED_STREAM_OUT = True # Temperature is a configuration hyperparameter that controls the randomness of language model output. # A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output. @@ -43,7 +70,7 @@ prompt = PromptTemplate( template_scene=ChatScene.ChatWithDbExecute.value(), input_variables=["input", "table_info", "dialect", "top_k", "response"], - response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), + # response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4), template_define=PROMPT_SCENE_DEFINE, template=_DEFAULT_TEMPLATE, stream_out=PROMPT_NEED_NEED_STREAM_OUT,