
Commit eec8154
Add Code Interpreter demo
xusenlin committed Nov 7, 2023
1 parent a81b502 commit eec8154
Showing 10 changed files with 364 additions and 11 deletions.
7 changes: 4 additions & 3 deletions api/routes/chat.py
@@ -93,7 +93,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
         res, function_call = parse_response(content["text"])
         content["text"] = res
 
-        if isinstance(function_call, dict):
+        if isinstance(function_call, dict) and "arguments" in function_call:
            finish_reason = "function_call"
            function_call = FunctionCallResponse(**function_call)
 
@@ -126,6 +126,7 @@ async def chat_completion_stream_generator(
     https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
     """
     _id = f"chatcmpl-{secrets.token_hex(12)}"
+    use_tool = bool(gen_params["functions"] is not None)
     for i in range(n):
         # First chunk with role
         choice_data = ChatCompletionResponseStreamChoice(
@@ -160,14 +161,14 @@ async def chat_completion_stream_generator(
            function_call = None
            if finish_reason == "function_call" and "chatglm3" in config.MODEL_NAME.lower():
                try:
-                    function_call = process_response_v3(decoded_unicode, use_tool=True)
+                    function_call = process_response_v3(decoded_unicode, use_tool=use_tool)
                except:
                    logger.warning("Failed to parse tool call")
 
            elif finish_reason == "function_call" and "qwen" in config.MODEL_NAME.lower():
                _, function_call = parse_response(decoded_unicode)
 
-            if isinstance(function_call, dict):
+            if isinstance(function_call, dict) and "arguments" in function_call:
                function_call = FunctionCallResponse(**function_call)
 
            delta = DeltaMessage(
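Both hunks tighten the same guard: a parsed function_call dict without an "arguments" key no longer reaches FunctionCallResponse(**function_call). A minimal sketch of the failure the guard avoids, assuming FunctionCallResponse is a pydantic model with a required arguments field (its definition is not part of this diff):

```python
# Sketch only: FunctionCallResponse lives elsewhere in the repo; here we assume
# a pydantic model whose "arguments" field is required.
from pydantic import BaseModel

class FunctionCallResponse(BaseModel):
    name: str
    arguments: str

parsed = {"name": "get_weather"}  # hypothetical malformed call, no arguments

# Old guard: isinstance(parsed, dict) passes, and construction raises a
# ValidationError. The new guard also requires "arguments", so the malformed
# call is skipped instead of crashing the request.
if isinstance(parsed, dict) and "arguments" in parsed:
    function_call = FunctionCallResponse(**parsed)
```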
1 change: 1 addition & 0 deletions examples/chatglm3/tool_using.py
@@ -65,6 +65,7 @@ def run_conversation(query: str, stream=False, functions=None, max_retry=5):
        params["messages"].append(
            {
                "role": "assistant",
+                "function_call": function_call,
                "content": output
            }
        )
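The assistant turn in the retry loop now records the function_call it made, so the model sees its own tool invocation when the conversation is sent back. Illustratively, the history then follows the familiar OpenAI function-calling shape (get_weather and its payload are made up for the example):

```python
messages = [
    {"role": "user", "content": "What's the weather in Beijing?"},
    {
        "role": "assistant",
        "function_call": {"name": "get_weather", "arguments": '{"city": "Beijing"}'},
        "content": "",
    },
    {"role": "function", "name": "get_weather", "content": '{"temperature": 21}'},
]
```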
4 changes: 3 additions & 1 deletion streamlit-demo/.env
@@ -4,4 +4,6 @@ TOOL_CHAT_API_BASE = "http://192.168.20.59:7891/v1" # tool-chat model API endpoint
 EMBEDDING_API_BASE = "http://192.168.0.53:7891/v1" # embedding model API endpoint (optional)
 API_KEY = "xxx" # no configuration needed by default
 EMBEDDING_NAME = "" # path to a local embedding model (optional; set either EMBEDDING_API_BASE or EMBEDDING_NAME, not both)
-SERPAPI_API_KEY = "" # required for the search feature
+SERPAPI_API_KEY = "" # required for the search feature
+IPYKERNEL = "llm" # name of the Python kernel
+INTERPRETER_CHAT_API_BASE = "http://192.168.20.59:7891/v1" # code interpreter model API endpoint (optional)
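The demo's streamlit_app.py below reads INTERPRETER_CHAT_API_BASE with os.getenv, so these values must reach the process environment. A minimal sketch, assuming python-dotenv is what loads this file (the loading code itself is not shown in this commit):

```python
# Minimal sketch, assuming the demo loads .env with python-dotenv; the actual
# loading code is not part of this diff.
import os

from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from .env into os.environ
api_base = os.getenv("INTERPRETER_CHAT_API_BASE", "")
kernel_name = os.getenv("IPYKERNEL", "llm")
```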
6 changes: 6 additions & 0 deletions streamlit-demo/README.md
@@ -8,3 +8,9 @@ streamlit run streamlit_app.py
 
 **See [.env](.env) for environment variable configuration**
 
+## Code Interpreter (based on the ChatGLM3 model) [beta]
+
+```shell
+ipython kernel install --name llm --user
+```
+
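The kernel name passed to --name must match the IPYKERNEL value in .env. Standard Jupyter tooling (not specific to this demo) can confirm the kernel registered:

```shell
jupyter kernelspec list   # an "llm" entry should appear
```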
5 changes: 4 additions & 1 deletion streamlit-demo/streamlit_app.py
@@ -7,7 +7,7 @@
 
 def main():
     from streamlit_gallery.apps import gallery
-    from streamlit_gallery.components import chat, doc_chat, sql_chat, search_chat, tool_chat
+    from streamlit_gallery.components import chat, doc_chat, sql_chat, search_chat, tool_chat, code_interpreter
 
     page = page_group("p")
 
@@ -32,6 +32,9 @@ def main():
        if os.getenv("TOOL_CHAT_API_BASE", ""):
            page.item("Tool Chat", tool_chat)
 
+        if os.getenv("INTERPRETER_CHAT_API_BASE", ""):
+            page.item("Code Interpreter", code_interpreter)
+
        with st.expander("🐧 PARAMETERS", False):
            max_tokens = st.slider("MaxTokens", 20, 4096, 1024)
            temperature = st.slider("Temperature", 0.0, 1.0, 0.9)
1 change: 1 addition & 0 deletions streamlit-demo/streamlit_gallery/components/__init__.py
@@ -3,3 +3,4 @@
 from .sql_chat.streamlit_app import main as sql_chat
 from .search_chat.streamlit_app import main as search_chat
 from .tool_chat.streamlit_app import main as tool_chat
+from .code_interpreter.streamlit_app import main as code_interpreter
121 changes: 121 additions & 0 deletions streamlit-demo/streamlit_gallery/components/code_interpreter/streamlit_app.py
@@ -0,0 +1,121 @@
+import os
+
+import openai
+import streamlit as st
+
+from .utils import CodeKernel, extract_code, execute, postprocess_text
+
+
+@st.cache_resource
+def get_kernel():
+    return CodeKernel()
+
+
+# System prompt (in Chinese), roughly: "You are an intelligent AI assistant named
+# ChatGLM, connected to a computer but without internet access. When solving tasks
+# with Python you can run code and observe the results; if a run errors, improve
+# the code as best you can. You can work with files the user uploads; the default
+# storage path is /mnt/data/."
+SYSTEM_MESSAGE = [
+    {
+        "role": "system",
+        "content": "你是一位智能AI助手,你叫ChatGLM,你连接着一台电脑,但请注意不能联网。在使用Python解决任务时,你可以运行代码并得到结果,如果运行结果有错误,你需要尽可能对代码进行改进。你可以处理用户上传到电脑上的文件,文件默认存储路径是/mnt/data/。"
+    }
+]
+
+
+def chat_once(message_placeholder):
+    params = dict(
+        model="chatglm3",
+        messages=SYSTEM_MESSAGE + st.session_state.messages,
+        stream=True,
+        max_tokens=st.session_state.get("max_tokens", 512),
+        temperature=st.session_state.get("temperature", 0.9),
+    )
+    response = openai.ChatCompletion.create(**params)
+
+    display = ""
+    # Up to 5 rounds of: generate -> execute emitted code -> feed results back.
+    for _ in range(5):
+        full_response = ""
+        for chunk in response:
+            content = chunk.choices[0].delta.get("content", "")
+            full_response += content
+            display += content
+            message_placeholder.markdown(postprocess_text(display) + "▌")
+
+            if chunk.choices[0].finish_reason == "stop":
+                message_placeholder.markdown(postprocess_text(display))  # final render, drop the cursor
+                st.session_state.messages.append(
+                    {
+                        "role": "assistant",
+                        "content": full_response
+                    }
+                )
+                return
+
+            elif chunk.choices[0].finish_reason == "function_call":
+                try:
+                    code = extract_code(full_response)
+                except Exception:
+                    continue
+
+                with message_placeholder:
+                    with st.spinner("Executing code..."):
+                        try:
+                            res_type, res = execute(code, get_kernel())
+                        except Exception as e:
+                            st.error(f"Error when executing code: {e}")
+                            return
+
+                if res_type == "text":
+                    res = postprocess_text(res)
+                    display += "\n" + res
+                    message_placeholder.markdown(postprocess_text(display) + "▌")
+                elif res_type == "image":
+                    st.image(res)
+
+                st.session_state.messages.append(
+                    {
+                        "role": "assistant",
+                        "content": full_response,
+                        "function_call": {"name": "interpreter", "arguments": ""},
+                    }
+                )
+                st.session_state.messages.append(
+                    {
+                        "role": "function",
+                        "content": "[Image]" if res_type == "image" else res,  # result returned by the function call
+                    }
+                )
+
+                break
+
+        # Resend the conversation, now including the execution result.
+        params["messages"] = st.session_state.messages
+        response = openai.ChatCompletion.create(**params)
+
+
+def main():
+    st.title("💬 Code Interpreter")
+
+    openai.api_base = os.getenv("INTERPRETER_CHAT_API_BASE", "http://192.168.20.59:7891/v1")
+    openai.api_key = os.getenv("API_KEY", "xxx")
+
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    for message in st.session_state.messages:
+        role = message["role"]
+        if role in ["user", "function"]:
+            with st.chat_message("user"):
+                st.markdown(message["content"])
+        else:
+            with st.chat_message("assistant"):
+                st.markdown(postprocess_text(message["content"]))
+
+    if prompt := st.chat_input("What is up?"):
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        with st.chat_message("assistant"):
+            message_placeholder = st.empty()
+            chat_once(message_placeholder)
+
+
+if __name__ == "__main__":
+    main()
(3 more changed files in this commit are not shown here.)
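The file above imports CodeKernel, extract_code, execute, and postprocess_text from a .utils module that is among the files not rendered on this page. A rough sketch only — names, regex, and kernel wiring are assumptions, not the repo's actual code — of how extract_code might pull the fenced block out of the model's reply, and how CodeKernel might start a persistent Jupyter kernel named by IPYKERNEL:

````python
import os
import re

from jupyter_client.manager import start_new_kernel


def extract_code(text: str) -> str:
    """Hypothetical stand-in for .utils.extract_code: return the last
    fenced code block found in a model response."""
    blocks = re.findall(r"```(?:\w*\n)?(.*?)```", text, re.DOTALL)
    if not blocks:
        raise ValueError("no code block found in response")
    return blocks[-1].strip()


class CodeKernel:
    """Hypothetical stand-in for .utils.CodeKernel: a persistent Jupyter
    kernel whose name comes from the IPYKERNEL environment variable."""

    def __init__(self):
        self.manager, self.client = start_new_kernel(
            kernel_name=os.getenv("IPYKERNEL", "llm")
        )
````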
