From ba85901aa0611bb3fb9567ac9f25fea920639352 Mon Sep 17 00:00:00 2001
From: yhjun1026 <460342015@qq.com>
Date: Fri, 20 Oct 2023 16:21:49 +0800
Subject: [PATCH 1/3] feat(ChatData): ChatData Stream response
1.ChatData Stream response
---
.../agent/commands/command_mange.py | 6 +--
pilot/openapi/api_v1/editor/api_editor_v1.py | 34 +++++++-----
pilot/scene/chat_db/auto_execute/chat.py | 13 +++--
pilot/scene/chat_db/auto_execute/prompt.py | 53 ++++++++++++++-----
4 files changed, 73 insertions(+), 33 deletions(-)
diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py
index be9e02811..4fbaa6dc9 100644
--- a/pilot/base_modules/agent/commands/command_mange.py
+++ b/pilot/base_modules/agent/commands/command_mange.py
@@ -227,7 +227,7 @@ def __is_need_wait_plugin_call(self, api_call_context):
i += 1
return False
- def __check_last_plugin_call_ready(self, all_context):
+ def check_last_plugin_call_ready(self, all_context):
start_agent_count = all_context.count(self.agent_prefix)
end_agent_count = all_context.count(self.agent_end)
@@ -359,7 +359,7 @@ def to_view_text(self, api_status: PluginStatus):
def run(self, llm_text):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
- if self.__check_last_plugin_call_ready(llm_text):
+ if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
@@ -379,7 +379,7 @@ def run(self, llm_text):
def run_display_sql(self, llm_text, sql_run_func):
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
- if self.__check_last_plugin_call_ready(llm_text):
+ if self.check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
if value.status == Status.TODO.value:
diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py
index e41998942..3fb985e91 100644
--- a/pilot/openapi/api_v1/editor/api_editor_v1.py
+++ b/pilot/openapi/api_v1/editor/api_editor_v1.py
@@ -29,6 +29,8 @@
from pilot.scene.chat_dashboard.data_loader import DashboardDataLoader
from pilot.scene.chat_db.data_loader import DbDataLoader
from pilot.memory.chat_history.chat_hisotry_factory import ChatHistory
+from pilot.base_modules.agent.commands.command_mange import ApiCall
+
router = APIRouter()
CFG = Config()
@@ -101,12 +103,15 @@ async def get_editor_sql(con_uid: str, round: int):
logger.info(
f'history ai json resp:{element["data"]["content"]}'
)
- context = (
- element["data"]["content"]
- .replace("\\n", " ")
- .replace("\n", " ")
- )
- return Result.succ(json.loads(context))
+ api_call = ApiCall()
+ result = {}
+ result['thoughts'] = element["data"]["content"]
+ if api_call.check_last_plugin_call_ready(element["data"]["content"]):
+ api_call.update_from_context(element["data"]["content"])
+ if len(api_call.plugin_status_map) > 0:
+ first_item = next(iter(api_call.plugin_status_map.items()))[1]
+ result['sql'] = first_item.args["sql"]
+ return Result.succ(result)
return Result.faild(msg="not have sql!")
@@ -156,17 +161,18 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
)
)[0]
if edit_round:
+ new_ai_text = ""
for element in edit_round["messages"]:
if element["type"] == "ai":
- db_resp = json.loads(element["data"]["content"])
- db_resp["thoughts"] = sql_edit_context.new_speak
- db_resp["sql"] = sql_edit_context.new_sql
- element["data"]["content"] = json.dumps(db_resp)
+ new_ai_text = element["data"]["content"]
+ new_ai_text.replace(sql_edit_context.old_sql, sql_edit_context.new_sql)
+ element["data"]["content"] = new_ai_text
+
+ for element in edit_round["messages"]:
if element["type"] == "view":
- data_loader = DbDataLoader()
- element["data"]["content"] = data_loader.get_table_view_by_conn(
- conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
- )
+ api_call = ApiCall()
+ new_view_text = api_call.run_display_sql(new_ai_text, conn.run_to_df)
+ element["data"]["content"] = new_view_text
history_mem.update(history_messages)
return Result.succ(None)
return Result.faild(msg="Edit Faild!")
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index f92df7a3a..01e04be63 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -5,6 +5,7 @@
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.scene.chat_db.auto_execute.prompt import prompt
+from pilot.base_modules.agent.commands.command_mange import ApiCall
CFG = Config()
@@ -37,6 +38,7 @@ def __init__(self, chat_param: Dict):
self.database = CFG.LOCAL_DB_MANAGE.get_connect(self.db_name)
self.top_k: int = 200
+ self.api_call = ApiCall(display_registry=CFG.command_disply)
def generate_input_values(self):
"""
@@ -69,6 +71,11 @@ def generate_input_values(self):
}
return input_values
- def do_action(self, prompt_response):
- print(f"do_action:{prompt_response}")
- return self.database.run(prompt_response.sql)
+ def stream_plugin_call(self, text):
+ text = text.replace("\n", " ")
+ print(f"stream_plugin_call:{text}")
+ return self.api_call.run_display_sql(text, self.database.run_to_df)
+ #
+ # def do_action(self, prompt_response):
+ # print(f"do_action:{prompt_response}")
+ # return self.database.run(prompt_response.sql)
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index abc889cec..d9b67af39 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -8,24 +8,51 @@
CFG = Config()
-PROMPT_SCENE_DEFINE = "You are a SQL expert. "
-_DEFAULT_TEMPLATE = """
+_PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
+_PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
+
+_DEFAULT_TEMPLATE_EN = """
Given an input question, create a syntactically correct {dialect} sql.
+Table structure information:
+ {table_info}
+Constraint:
+1. You can only use the table provided in the table structure information to generate sql. If you cannot generate sql based on the provided table structure, please say: "The table structure information provided is not enough to generate sql query." It is prohibited to fabricate information at will.
+2. Do not query columns that do not exist. Pay attention to which column is in which table.
+3. Replace the corresponding sql into the sql field in the returned result
+4. Unless the user specifies in the question a specific number of examples he wishes to obtain, always limit the query to a maximum of {top_k} results.
+5. Please output the Sql content in the following format to execute the corresponding SQL to display the data:response_tableSQL Query to run
+Please make sure to respond as following format:
+ thoughts summary to say to user.response_tableSQL Query to run
+
+Question: {input}
+"""
-Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
-Use as few tables as possible when querying.
-Only use the following tables schema to generate sql:
-{table_info}
-Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
+_DEFAULT_TEMPLATE_ZH = """
+给定一个输入问题,创建一个语法正确的 {dialect} sql。
+已知表结构信息:
+ {table_info}
-Question: {input}
+约束:
+1. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。
+2. 不要查询不存在的列,注意哪一列位于哪张表中。
+3.将对应的sql替换到返回结果中的sql字段中
+4.除非用户在问题中指定了他希望获得的具体示例数量,否则始终将查询限制为最多 {top_k} 个结果。
-Respond in JSON format as following format:
-{response}
-Ensure the response is correct json and can be parsed by Python json.loads
+请务必按照以下格式回复:
+ 对用户说的想法摘要。response_table要运行的 SQL
+
+问题:{input}
"""
+_DEFAULT_TEMPLATE = (
+ _DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
+)
+
+PROMPT_SCENE_DEFINE = (
+ _PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
+)
+
RESPONSE_FORMAT_SIMPLE = {
"thoughts": "thoughts summary to say to user",
"sql": "SQL Query to run",
@@ -33,7 +60,7 @@
PROMPT_SEP = SeparatorStyle.SINGLE.value
-PROMPT_NEED_NEED_STREAM_OUT = False
+PROMPT_NEED_NEED_STREAM_OUT = True
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
@@ -43,7 +70,7 @@
prompt = PromptTemplate(
template_scene=ChatScene.ChatWithDbExecute.value(),
input_variables=["input", "table_info", "dialect", "top_k", "response"],
- response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
+ # response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
From 6cc194b92ba8efdee9a247e90658c863cf3892bf Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Fri, 20 Oct 2023 16:24:54 +0800
Subject: [PATCH 2/3] doc:update readme
---
README.md | 3 +--
README.zh.md | 3 +--
2 files changed, 2 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index cda4a40bb..16138d0cf 100644
--- a/README.md
+++ b/README.md
@@ -56,8 +56,7 @@ DB-GPT is an experimental open-source project that uses localized GPT large mode
## Demo
Run on an RTX 4090 GPU.
##### Chat Excel
-![cx_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/659259dc-c3ba-41c8-8bc3-179cd4385dbe)
-![chatecl_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/2ebfdee2-2262-4d32-8933-4fb27f969180)
+![excel](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/0474d220-2a9f-449f-a940-92c8a25af390)
##### Chat Plugin
![auto_plugin_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/7d95c347-f4b7-4fb6-8dd2-c1c02babaa56)
##### LLM Management
diff --git a/README.zh.md b/README.zh.md
index 301dd2e6d..3832fc30c 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -58,8 +58,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目,使用本地
示例通过 RTX 4090 GPU 演示
##### Chat Excel
-![cx_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/659259dc-c3ba-41c8-8bc3-179cd4385dbe)
-![chatecl_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/2ebfdee2-2262-4d32-8933-4fb27f969180)
+![excel](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/0474d220-2a9f-449f-a940-92c8a25af390)
##### Chat Plugin
![auto_plugin_new](https://github.com/eosphoros-ai/DB-GPT/assets/13723926/7d95c347-f4b7-4fb6-8dd2-c1c02babaa56)
##### LLM Management
From b0a9fb7810175e30963731628c8aab579fd8cb2a Mon Sep 17 00:00:00 2001
From: aries_ckt <916701291@qq.com>
Date: Fri, 20 Oct 2023 16:27:13 +0800
Subject: [PATCH 3/3] style:fmt
---
pilot/openapi/api_v1/editor/api_editor_v1.py | 20 ++++++++++++++------
pilot/scene/chat_db/auto_execute/chat.py | 1 +
2 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py
index 3fb985e91..e7637e582 100644
--- a/pilot/openapi/api_v1/editor/api_editor_v1.py
+++ b/pilot/openapi/api_v1/editor/api_editor_v1.py
@@ -105,12 +105,16 @@ async def get_editor_sql(con_uid: str, round: int):
)
api_call = ApiCall()
result = {}
- result['thoughts'] = element["data"]["content"]
- if api_call.check_last_plugin_call_ready(element["data"]["content"]):
+ result["thoughts"] = element["data"]["content"]
+ if api_call.check_last_plugin_call_ready(
+ element["data"]["content"]
+ ):
api_call.update_from_context(element["data"]["content"])
if len(api_call.plugin_status_map) > 0:
- first_item = next(iter(api_call.plugin_status_map.items()))[1]
- result['sql'] = first_item.args["sql"]
+ first_item = next(
+ iter(api_call.plugin_status_map.items())
+ )[1]
+ result["sql"] = first_item.args["sql"]
return Result.succ(result)
return Result.faild(msg="not have sql!")
@@ -165,13 +169,17 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
for element in edit_round["messages"]:
if element["type"] == "ai":
new_ai_text = element["data"]["content"]
- new_ai_text.replace(sql_edit_context.old_sql, sql_edit_context.new_sql)
+ new_ai_text = new_ai_text.replace(
+ sql_edit_context.old_sql, sql_edit_context.new_sql
+ )  # str.replace returns a new string; keep it so the edited SQL is stored
element["data"]["content"] = new_ai_text
for element in edit_round["messages"]:
if element["type"] == "view":
api_call = ApiCall()
- new_view_text = api_call.run_display_sql(new_ai_text, conn.run_to_df)
+ new_view_text = api_call.run_display_sql(
+ new_ai_text, conn.run_to_df
+ )
element["data"]["content"] = new_view_text
history_mem.update(history_messages)
return Result.succ(None)
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 01e04be63..b05480cdf 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -75,6 +75,7 @@ def stream_plugin_call(self, text):
text = text.replace("\n", " ")
print(f"stream_plugin_call:{text}")
return self.api_call.run_display_sql(text, self.database.run_to_df)
+
#
# def do_action(self, prompt_response):
# print(f"do_action:{prompt_response}")