Control chats, save, export, import and otherwise manage #156

Merged: 2 commits, May 22, 2023
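This PR adds chat management to the Gradio UI: conversations are saved into a new chat_state (a gr.State({}) dict), listed in a "Saved Chats" radio, and can be removed, cleared, exported to a JSON file, or re-imported from uploaded files. As a rough orientation (the keys and messages below are invented for illustration, not taken from the PR), chat_state maps a short, de-duplicated name derived from the first question to the Chatbot history of that conversation:

# Hypothetical contents of chat_state after saving two conversations.
# Keys come from get_short_chat() (first few words of the first question, de-duplicated);
# values are gradio Chatbot histories: lists of [user_message, bot_message] pairs.
# The <p> wrappers reflect HTML the chat component may add (see is_chat_same() below).
chat_state_example = {
    "What is the capital": [
        ["<p>What is the capital of France?</p>", "<p>Paris.</p>"],
    ],
    "Summarize the README": [
        ["<p>Summarize the README in two sentences.</p>", "<p>h2oGPT is an open-source LLM project.</p>"],
    ],
}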
gradio_runner.py: 155 changes (148 additions, 7 deletions)
@@ -1,7 +1,9 @@
import copy
import functools
import inspect
import json
import os
import random
import sys
import traceback
import uuid
@@ -13,7 +15,7 @@
from prompter import Prompter, \
prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
ping, get_short_name, get_url
ping, get_short_name, get_url, makedirs
from finetune import generate_prompt
from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
inputs_kwargs_list, get_cutoffs, scratch_base_dir
@@ -160,6 +162,7 @@ def _postprocess_chat_messages(self, chat_message: str):
model_options_state = gr.State([model_options])
lora_options_state = gr.State([lora_options])
my_db_state = gr.State([None, None])
chat_state = gr.State({})
gr.Markdown(f"""
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}

@@ -179,7 +182,7 @@ def _postprocess_chat_messages(self, chat_message: str):
with gr.Row():
col_nochat = gr.Column(visible=not kwargs['chat'])
with col_nochat: # FIXME: for model comparison, and check rest
text_output_nochat = gr.Textbox(lines=5, label=output_label0)
text_output_nochat = gr.Textbox(lines=5, label=output_label0).style(show_copy_button=True)
instruction_nochat = gr.Textbox(
lines=kwargs['input_lines'],
label=instruction_label_nochat,
@@ -213,7 +216,7 @@ def _postprocess_chat_messages(self, chat_message: str):
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
with gr.Row():
clear = gr.Button("New Conversation")
clear = gr.Button("Save, New Conversation")
flag_btn = gr.Button("Flag")
if not kwargs['auto_score']: # FIXME: For checkbox model2
with gr.Column(visible=kwargs['score_model']):
@@ -232,7 +235,7 @@ def _postprocess_chat_messages(self, chat_message: str):
score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
retry = gr.Button("Regenerate")
undo = gr.Button("Undo")
with gr.TabItem("Input/Output"):
with gr.TabItem("Chat"):
with gr.Row():
if 'mbart-' in kwargs['model_lower']:
src_lang = gr.Dropdown(list(languages_covered().keys()),
@@ -241,6 +244,22 @@ def _postprocess_chat_messages(self, chat_message: str):
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
value=kwargs['tgt_lang'],
label="Output Language")
radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
type='value')
with gr.Row():
remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
chats_row = gr.Row(visible=True).style(equal_height=False)
with chats_row:
export_chats_btn = gr.Button(value="Export Chats")
chats_file = gr.File(interactive=False, label="Download File")
chats_row2 = gr.Row(visible=True).style(equal_height=False)
with chats_row2:
chatsup_output = gr.File(label="Upload Chat File(s)",
file_types=['.json'],
file_count='multiple',
elem_id="warning", elem_classes="feedback")
add_to_chats_btn = gr.Button("Add File(s) to Chats")
with gr.TabItem("Data Source"):
langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
from_str=True)
@@ -482,7 +501,8 @@ def upload_file(files, x):
with gr.Column():
with gr.Row():
system_btn = gr.Button(value='Get System Info')
system_text = gr.Textbox(label='System Info', interactive=False)
system_text = gr.Textbox(label='System Info', interactive=False).style(
show_copy_button=True)

with gr.Row():
zip_btn = gr.Button("Zip")
@@ -974,9 +994,130 @@ def clear_instruct():
.then(**score_args_submit, api_name='undo_score' if allow_api else None) \
.then(**score_args2_submit, api_name='undo_score2' if allow_api else None)

# MANAGE CHATS
def dedup(short_chat, short_chats):
if short_chat not in short_chats:
return short_chat
for i in range(1, 1000):
short_chat_try = short_chat + "_" + str(i)
if short_chat_try not in short_chats:
return short_chat_try
# fallback and hope for best
short_chat = short_chat + "_" + str(random.random())
return short_chat

def get_short_chat(x, short_chats, short_len=20, words=4):
if x and len(x[0]) == 2 and x[0][0] is not None:
short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
short_chat = dedup(short_chat, short_chats)
else:
short_chat = None
return short_chat

def is_chat_same(x, y):
# <p> etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation
is_same = True
# length of conversation has to be same
if len(x) != len(y):
return False
for stepx, stepy in zip(x, y):
if len(stepx) != len(stepy):
# something off with a conversation
return False
if len(stepx) != 2:
# something off
return False
if len(stepy) != 2:
# something off
return False
questionx = stepx[0].replace('<p>', '').replace('</p>', '')
answerx = stepx[1].replace('<p>', '').replace('</p>', '')

questiony = stepy[0].replace('<p>', '').replace('</p>', '')
answery = stepy[1].replace('<p>', '').replace('</p>', '')

if questionx != questiony or answerx != answery:
return False
return is_same

def save_chat(chat1, chat2, chat_state1):
short_chats = list(chat_state1.keys())
for chati in [chat1, chat2]:
if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None:
short_chat = get_short_chat(chati, short_chats)
if short_chat:
already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()])
if not already_exists:
chat_state1[short_chat] = chati
return chat_state1

def update_radio_chats(chat_state1):
return gr.update(choices=list(chat_state1.keys()), value=None)

def deselect_radio_chats():
return gr.update(value=None)

def switch_chat(chat_key, chat_state1):
chosen_chat = chat_state1[chat_key]
return chosen_chat, chosen_chat

radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2])

def remove_chat(chat_key, chat_state1):
chat_state1.pop(chat_key, None)
return chat_state1

remove_chat_btn.click(remove_chat, inputs=[radio_chats, chat_state], outputs=chat_state) \
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats)

def get_chats1(chat_state1):
base = 'chats'
makedirs(base, exist_ok=True)
filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4()))
with open(filename, "wt") as f:
f.write(json.dumps(chat_state1, indent=2))
return filename

export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
api_name='export_chats' if allow_api else None)

def add_chats_from_file(file, chat_state1, add_btn):
if isinstance(file, str):
files = [file]
else:
files = file
for file1 in files:
try:
if hasattr(file1, 'name'):
file1 = file1.name
with open(file1, "rt") as f:
new_chats = json.loads(f.read())
for chat1_k, chat1_v in new_chats.items():
# ignore chat1_k, regenerate and de-dup to avoid loss
chat_state1 = save_chat(chat1_v, None, chat_state1)
except BaseException as e:
print("Add chats exception: %s" % str(e), flush=True)
return chat_state1, add_btn

# note for update_user_db_func output is ignored for db
add_to_chats_btn.click(add_chats_from_file,
inputs=[chatsup_output, chat_state, add_to_chats_btn],
outputs=[chat_state, add_to_my_db_btn], queue=False,
api_name='add_to_chats' if allow_api else None) \
.then(clear_file_list, outputs=chatsup_output, queue=False) \
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False)

clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
.then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False)

# does both models
clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
.then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state,
api_name='save_chat' if allow_api else None) \
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
api_name='update_chats' if allow_api else None) \
.then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \
.then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None)
# NOTE: clear of instruction/iinput for nochat has to come after score,
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
submit_event_nochat = submit_nochat.click(fun,
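The "Export Chats" button serializes the whole chat_state dict to chats/chats_<uuid>.json as pretty-printed JSON, and "Add File(s) to Chats" reads such files back, ignoring the stored keys and re-running save_chat() so re-imports do not create duplicate entries. A small, self-contained sketch of how one might inspect exported files offline (the chats/ directory name matches get_chats1() above; everything else is illustrative):

import glob
import json

# Each exported file is a dict: {short_chat_name: [[user_message, bot_message], ...], ...}
for path in glob.glob("chats/chats_*.json"):
    with open(path, "rt") as f:
        chats = json.load(f)
    for name, history in chats.items():
        print("%s | %s: %d exchange(s)" % (path, name, len(history)))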
gradio_themes.py: 1 change (1 addition, 0 deletions)
@@ -75,6 +75,7 @@ def __init__(
body_background_fill_dark="*neutral_900",
background_fill_primary_dark="*block_background_fill",
block_radius="0 0 8px 8px",
checkbox_label_text_color_selected_dark='#000000',
)

