-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
118 lines (99 loc) · 2.94 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from __future__ import annotations
import os
import time
import gradio as gr
from src.llm import text_generation
# Default Llama 2 chat prefix: opens the first [INST] turn and wraps the default
# system instructions in <<SYS>> tags. The newlines around the tags follow the
# Llama 2 reference chat template; the prompt builder appends the first user
# message directly after this string, so it must end with the blank line that
# separates the system block from the user text.
METAPROMPT = (
    "<s>[INST] <<SYS>>\n"
    "You are a helpful bot. Answer the user's questions. Respect the user. "
    "Do not provide false information. If you do not know the answer, say I don't know."
    "\n<</SYS>>\n\n"
)
def _build_prompt(msg, hist, system):
if len(hist) > 5:
hist = hist[-5:]
if len(hist) == 0:
return system + f"{msg} [/INST]"
prompt = system + f"{hist[0][0]} [/INST] {hist[0][1]} </s>"
for usr_msg, model_msg in hist[1:]:
prompt += f"<s>[INST] {usr_msg} [/INST] {model_msg} </s>"
prompt += f"<s>[INST] {msg} [/INST]"
return prompt
def user(msg, hist):
    """Append the submitted message as a pending chat turn and clear the textbox.

    Args:
        msg: Text from the prompt textbox.
        hist: Current chatbot history as ``[user_message, model_reply]`` pairs;
            ``None`` is treated as an empty history.

    Returns:
        ``("", new_history)``: an empty string that clears the textbox, and a
        new list (the input is not mutated) ending in ``[msg, None]``. The
        ``None`` reply marks the turn that ``predict`` fills in.
    """
    return "", [*(hist or []), [msg, None]]
def predict(msg, hist, system, temperature=None, max_length=None, top_p=None,
            top_k=None, **kwargs):
    """Stream the model's reply into the last (pending) turn of ``hist``.

    Intended to run right after ``user``. Gradio's ``.then()`` reads the prompt
    textbox only after ``user`` has cleared it, so ``msg`` normally arrives as
    ``""``. The real message is the pending turn ``hist[-1][0]``.

    Args:
        msg: Prompt textbox value. Used only when the pending turn's message is
            empty, or to start a turn when ``hist`` is empty.
        hist: Chat history whose last entry is the pending ``[message, None]``
            turn. It is updated in place as text streams in.
        system: Custom system prompt. Blank uses ``METAPROMPT``. Text without a
            ``<<SYS>>`` tag is wrapped in the Llama 2 system block. Text that
            already contains ``<<SYS>>`` is used as-is.
        temperature, max_length, top_p, top_k: Optional sampling options.
            They are accepted positionally (as Gradio passes ``inputs``) or by
            keyword, and forwarded to ``generate_text`` only when not ``None``.
            ``max_length`` and ``top_k`` are cast to ``int``.
        **kwargs: Any other keyword options, forwarded unchanged.

    Yields:
        ``hist`` after each generated chunk, so the chatbot re-renders.
    """
    if not system or not system.strip():
        system = METAPROMPT
    elif "<<SYS>>" not in system:
        # A bare instruction typed into the UI must still open the first
        # [INST] turn inside <<SYS>> tags; otherwise it would be glued
        # directly onto the user's message.
        system = f"<s>[INST] <<SYS>>\n{system.strip()}\n<</SYS>>\n\n"

    if not hist:
        hist = [[msg, None]]
    user_msg = hist[-1][0] or msg
    # The pending turn is the current message, not part of the history.
    prompt = _build_prompt(user_msg, hist[:-1], system)

    # NOTE(review): assumes text_generation.generate_text accepts these keyword
    # names (they mirror the UI sliders) — confirm against src.llm.
    options = {
        "temperature": temperature,
        "max_length": None if max_length is None else int(max_length),
        "top_p": top_p,
        "top_k": None if top_k is None else int(top_k),
    }
    kwargs.update((name, value) for name, value in options.items() if value is not None)

    hist[-1][1] = ""
    for character in text_generation.generate_text(prompt, **kwargs):
        hist[-1][1] += character
        # Deliberate pause between chunks (presumably a typing effect).
        time.sleep(0.05)
        yield hist
def _predict_with_options(msg, hist, system, temperature, max_length, top_p, top_k):
    """Adapt Gradio's positional ``inputs`` to ``predict``'s keyword options.

    Gradio passes every component listed in ``inputs`` positionally. Forwarding
    the slider values by keyword keeps the sampling options explicit. The
    count-like sliders are cast to ``int`` because slider values may arrive as
    floats.
    """
    yield from predict(
        msg,
        hist,
        system,
        temperature=temperature,
        max_length=int(max_length),
        top_p=top_p,
        top_k=int(top_k),
    )


with gr.Blocks() as demo:
    gr.Markdown("# Llama v2 7b Chatbot")
    chatbot = gr.Chatbot(height=400)
    msg = gr.Textbox(label="Prompt", lines=1, placeholder="Type your prompt here:")
    with gr.Accordion(label="Advanced options", open=False):
        system = gr.Textbox(
            label="System",
            lines=2,
            placeholder="Enter system prompt here:",
        )
        temperature = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.1,
            value=0.1,
            label="Temperature",
        )
        max_length = gr.Slider(
            minimum=1,
            maximum=4096,
            step=1,
            value=2000,
            label="Max Length",
        )
        top_p = gr.Slider(
            minimum=0,
            maximum=1,
            step=0.05,
            value=0.95,
            label="Top P",
        )
        top_k = gr.Slider(
            minimum=1,
            maximum=100,
            step=1,
            value=40,
            label="Top K",
        )
    btn = gr.Button(value="Submit")
    clear = gr.ClearButton(components=[msg, chatbot], value="Clear console")

    # The Submit button and pressing Enter share the same two-step flow:
    # user() moves the message into the chat and clears the textbox
    # (queue=False bypasses the request queue). Then the reply is streamed
    # using the current advanced options. Previously the button silently
    # ignored the options.
    predict_inputs = [msg, chatbot, system, temperature, max_length, top_p, top_k]
    for trigger in (btn.click, msg.submit):
        trigger(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
            _predict_with_options,
            inputs=predict_inputs,
            outputs=chatbot,
        )
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    # Enable the request queue (needed for streamed generator output), then
    # serve on all interfaces. The port comes from the PORT environment
    # variable, defaulting to 7860.
    queued_app = demo.queue()
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        width=400,
        debug=False,
    )