Merge pull request #156 from h2oai/save_export_chats
Control chats: save, export, import, and otherwise manage them.
pseudotensor authored May 22, 2023
2 parents 4d50f91 + 63c20e5 commit 3417659
Showing 2 changed files with 149 additions and 7 deletions.
155 changes: 148 additions & 7 deletions gradio_runner.py
@@ -1,7 +1,9 @@
import copy
import functools
import inspect
import json
import os
import random
import sys
import traceback
import uuid
@@ -13,7 +15,7 @@
from prompter import Prompter, \
prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
ping, get_short_name, get_url
ping, get_short_name, get_url, makedirs
from finetune import generate_prompt
from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
inputs_kwargs_list, get_cutoffs, scratch_base_dir
@@ -160,6 +162,7 @@ def _postprocess_chat_messages(self, chat_message: str):
model_options_state = gr.State([model_options])
lora_options_state = gr.State([lora_options])
my_db_state = gr.State([None, None])
chat_state = gr.State({})
gr.Markdown(f"""
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
@@ -179,7 +182,7 @@ def _postprocess_chat_messages(self, chat_message: str):
with gr.Row():
col_nochat = gr.Column(visible=not kwargs['chat'])
with col_nochat: # FIXME: for model comparison, and check rest
text_output_nochat = gr.Textbox(lines=5, label=output_label0)
text_output_nochat = gr.Textbox(lines=5, label=output_label0).style(show_copy_button=True)
instruction_nochat = gr.Textbox(
lines=kwargs['input_lines'],
label=instruction_label_nochat,
@@ -213,7 +216,7 @@ def _postprocess_chat_messages(self, chat_message: str):
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
with gr.Row():
clear = gr.Button("New Conversation")
clear = gr.Button("Save, New Conversation")
flag_btn = gr.Button("Flag")
if not kwargs['auto_score']: # FIXME: For checkbox model2
with gr.Column(visible=kwargs['score_model']):
@@ -232,7 +235,7 @@ def _postprocess_chat_messages(self, chat_message: str):
score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
retry = gr.Button("Regenerate")
undo = gr.Button("Undo")
with gr.TabItem("Input/Output"):
with gr.TabItem("Chat"):
with gr.Row():
if 'mbart-' in kwargs['model_lower']:
src_lang = gr.Dropdown(list(languages_covered().keys()),
@@ -241,6 +244,22 @@ def _postprocess_chat_messages(self, chat_message: str):
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
value=kwargs['tgt_lang'],
label="Output Language")
radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
type='value')
with gr.Row():
remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
chats_row = gr.Row(visible=True).style(equal_height=False)
with chats_row:
export_chats_btn = gr.Button(value="Export Chats")
chats_file = gr.File(interactive=False, label="Download File")
chats_row2 = gr.Row(visible=True).style(equal_height=False)
with chats_row2:
chatsup_output = gr.File(label="Upload Chat File(s)",
file_types=['.json'],
file_count='multiple',
elem_id="warning", elem_classes="feedback")
add_to_chats_btn = gr.Button("Add File(s) to Chats")
with gr.TabItem("Data Source"):
langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
from_str=True)
@@ -482,7 +501,8 @@ def upload_file(files, x):
with gr.Column():
with gr.Row():
system_btn = gr.Button(value='Get System Info')
system_text = gr.Textbox(label='System Info', interactive=False)
system_text = gr.Textbox(label='System Info', interactive=False).style(
show_copy_button=True)

with gr.Row():
zip_btn = gr.Button("Zip")
@@ -974,9 +994,130 @@ def clear_instruct():
.then(**score_args_submit, api_name='undo_score' if allow_api else None) \
.then(**score_args2_submit, api_name='undo_score2' if allow_api else None)

# MANAGE CHATS
def dedup(short_chat, short_chats):
if short_chat not in short_chats:
return short_chat
for i in range(1, 1000):
short_chat_try = short_chat + "_" + str(i)
if short_chat_try not in short_chats:
return short_chat_try
# fallback and hope for best
short_chat = short_chat + "_" + str(random.random())
return short_chat

def get_short_chat(x, short_chats, short_len=20, words=4):
if x and len(x[0]) == 2 and x[0][0] is not None:
short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
short_chat = dedup(short_chat, short_chats)
else:
short_chat = None
return short_chat

def is_chat_same(x, y):
# <p> tags etc. get added in chat rendering; strip them to help avoid duplicate entries when starting a new conversation
is_same = True
# length of conversation has to be same
if len(x) != len(y):
return False
for stepx, stepy in zip(x, y):
if len(stepx) != len(stepy):
# something off with a conversation
return False
if len(stepx) != 2:
# something off
return False
if len(stepy) != 2:
# something off
return False
questionx = stepx[0].replace('<p>', '').replace('</p>', '')
answerx = stepx[1].replace('<p>', '').replace('</p>', '')

questiony = stepy[0].replace('<p>', '').replace('</p>', '')
answery = stepy[1].replace('<p>', '').replace('</p>', '')

if questionx != questiony or answerx != answery:
return False
return is_same

def save_chat(chat1, chat2, chat_state1):
short_chats = list(chat_state1.keys())
for chati in [chat1, chat2]:
if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None:
short_chat = get_short_chat(chati, short_chats)
if short_chat:
already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()])
if not already_exists:
chat_state1[short_chat] = chati
return chat_state1

def update_radio_chats(chat_state1):
return gr.update(choices=list(chat_state1.keys()), value=None)

def deselect_radio_chats():
return gr.update(value=None)

def switch_chat(chat_key, chat_state1):
chosen_chat = chat_state1[chat_key]
return chosen_chat, chosen_chat

radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2])

def remove_chat(chat_key, chat_state1):
chat_state1.pop(chat_key, None)
return chat_state1

remove_chat_btn.click(remove_chat, inputs=[radio_chats, chat_state], outputs=chat_state) \
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats)

def get_chats1(chat_state1):
base = 'chats'
makedirs(base, exist_ok=True)
filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4()))
with open(filename, "wt") as f:
f.write(json.dumps(chat_state1, indent=2))
return filename

export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
api_name='export_chats' if allow_api else None)

def add_chats_from_file(file, chat_state1, add_btn):
if isinstance(file, str):
files = [file]
else:
files = file
for file1 in files:
try:
if hasattr(file1, 'name'):
file1 = file1.name
with open(file1, "rt") as f:
new_chats = json.loads(f.read())
for chat1_k, chat1_v in new_chats.items():
# ignore chat1_k, regenerate and de-dup to avoid loss
chat_state1 = save_chat(chat1_v, None, chat_state1)
except BaseException as e:
print("Add chats exception: %s" % str(e), flush=True)
return chat_state1, add_btn

# note: for update_user_db_func the output is ignored for the db
add_to_chats_btn.click(add_chats_from_file,
inputs=[chatsup_output, chat_state, add_to_chats_btn],
outputs=[chat_state, add_to_my_db_btn], queue=False,
api_name='add_to_chats' if allow_api else None) \
.then(clear_file_list, outputs=chatsup_output, queue=False) \
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False)

clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
.then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False)

# does both models
clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
.then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state,
api_name='save_chat' if allow_api else None) \
.then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
api_name='update_chats' if allow_api else None) \
.then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \
.then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None)
# NOTE: clear of instruction/iinput for nochat has to come after score,
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
submit_event_nochat = submit_nochat.click(fun,
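The exported chats file is simply the chat_state dict serialized as JSON: each key is a short name derived by get_short_chat(), and each value is a list of [question, answer] pairs. Below is a minimal sketch of that round trip outside the UI; the example conversation content is illustrative and not taken from the commit.

import json
import os
import uuid

# Illustrative chat_state: short chat name -> list of [question, answer] pairs,
# mirroring what save_chat() accumulates in this commit.
chat_state = {
    "What is H2O.ai": [
        ["What is H2O.ai?", "H2O.ai builds open-source machine learning tools."],
    ],
}

# Export: same shape get_chats1() writes to chats/chats_<uuid>.json.
os.makedirs("chats", exist_ok=True)
filename = os.path.join("chats", "chats_%s.json" % uuid.uuid4())
with open(filename, "wt") as f:
    f.write(json.dumps(chat_state, indent=2))

# Import: same shape add_chats_from_file() expects when re-adding chats.
with open(filename, "rt") as f:
    restored = json.loads(f.read())
assert restored == chat_state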
1 change: 1 addition & 0 deletions gradio_themes.py
@@ -75,6 +75,7 @@ def __init__(
body_background_fill_dark="*neutral_900",
background_fill_primary_dark="*block_background_fill",
block_radius="0 0 8px 8px",
checkbox_label_text_color_selected_dark='#000000',
)


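The single-line change to gradio_themes.py sets the selected-checkbox label text to black in dark mode. As a rough sketch of the same idea against a stock Gradio theme (the Soft theme and the small demo app here are assumptions, not part of the commit):

import gradio as gr

# Sketch: set the same theme variable the commit adds, on a stock theme.
theme = gr.themes.Soft().set(
    checkbox_label_text_color_selected_dark='#000000',  # selected checkbox label text in dark mode
)

with gr.Blocks(theme=theme) as demo:
    gr.CheckboxGroup(choices=["option A", "option B"], label="Example checkboxes")

if __name__ == "__main__":
    demo.launch()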
