From 09239a8c019edf86fa9adbbd718a3345967ab354 Mon Sep 17 00:00:00 2001
From: "Jonathan C. McKinney"
Date: Mon, 22 May 2023 13:34:57 -0700
Subject: [PATCH 1/2] Control chats, save and export and otherwise manage

---
 gradio_runner.py | 123 ++++++++++++++++++++++++++++++++++++++++++++---
 gradio_themes.py |   1 +
 2 files changed, 117 insertions(+), 7 deletions(-)

diff --git a/gradio_runner.py b/gradio_runner.py
index 89c5b742a..bc0ca4781 100644
--- a/gradio_runner.py
+++ b/gradio_runner.py
@@ -1,7 +1,9 @@
 import copy
 import functools
 import inspect
+import json
 import os
+import random
 import sys
 import traceback
 import uuid
@@ -13,7 +15,7 @@ from prompter import Prompter, \
     prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower
 from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
-    ping, get_short_name, get_url
+    ping, get_short_name, get_url, makedirs
 from finetune import generate_prompt
 from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
     inputs_kwargs_list, get_cutoffs, scratch_base_dir
@@ -160,6 +162,7 @@ def _postprocess_chat_messages(self, chat_message: str):
         model_options_state = gr.State([model_options])
         lora_options_state = gr.State([lora_options])
         my_db_state = gr.State([None, None])
+        chat_state = gr.State({})
         gr.Markdown(f"""
             {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
@@ -179,12 +182,12 @@ def _postprocess_chat_messages(self, chat_message: str):
         with gr.Row():
             col_nochat = gr.Column(visible=not kwargs['chat'])
             with col_nochat:  # FIXME: for model comparison, and check rest
-                text_output_nochat = gr.Textbox(lines=5, label=output_label0)
+                text_output_nochat = gr.Textbox(lines=5, label=output_label0).style(show_copy_button=True)
                 instruction_nochat = gr.Textbox(
                     lines=kwargs['input_lines'],
                     label=instruction_label_nochat,
                     placeholder=kwargs['placeholder_instruction'],
-                )
+                ).style(show_copy_button=True)
                 iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
                                            placeholder=kwargs['placeholder_input'])
                 submit_nochat = gr.Button("Submit")
@@ -213,7 +216,7 @@ def _postprocess_chat_messages(self, chat_message: str):
                             submit = gr.Button(value='Submit').style(full_width=False, size='sm')
                             stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
                     with gr.Row():
-                        clear = gr.Button("New Conversation")
+                        clear = gr.Button("New/Save Conversation")
                         flag_btn = gr.Button("Flag")
                         if not kwargs['auto_score']:  # FIXME: For checkbox model2
                             with gr.Column(visible=kwargs['score_model']):
@@ -241,6 +244,16 @@ def _postprocess_chat_messages(self, chat_message: str):
                             tgt_lang = gr.Dropdown(list(languages_covered().keys()),
                                                    value=kwargs['tgt_lang'],
                                                    label="Output Language")
+                    radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
+                                           type='value')
+                    with gr.Row():
+                        remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
+                        clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
+                    chats_row = gr.Row(visible=True).style(equal_height=False)
+                    with chats_row:
+                        export_chats_btn = gr.Button(value="Export Chats")
+                        chats_file = gr.File(interactive=False, label="Download File")
+
                 with gr.TabItem("Data Source"):
                     langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
                                                from_str=True)
@@ -482,7 +495,8 @@ def upload_file(files, x):
                 with gr.Column():
                     with gr.Row():
                         system_btn = gr.Button(value='Get System Info')
-                        system_text = gr.Textbox(label='System Info', interactive=False)
+                        system_text = gr.Textbox(label='System Info', interactive=False).style(
+                            show_copy_button=True)

                 with gr.Row():
                     zip_btn = gr.Button("Zip")
@@ -974,9 +988,104 @@ def clear_instruct():
                 .then(**score_args_submit, api_name='undo_score' if allow_api else None) \
                 .then(**score_args2_submit, api_name='undo_score2' if allow_api else None)

+        # MANAGE CHATS
+        def dedup(short_chat, short_chats):
+            if short_chat not in short_chats:
+                return short_chat
+            for i in range(1, 1000):
+                short_chat_try = short_chat + "_" + str(i)
+                if short_chat_try not in short_chats:
+                    return short_chat_try
+            # fallback and hope for best
+            short_chat = short_chat + "_" + str(random.random())
+            return short_chat
+
+        def get_short_chat(x, short_chats, short_len=20, words=4):
+            if x and len(x[0]) == 2 and x[0][0] is not None:
+                short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
+                short_chat = dedup(short_chat, short_chats)
+            else:
+                short_chat = None
+            return short_chat
+
+        def is_chat_same(x, y):
+            # <p> etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation
+            is_same = True
+            # length of conversation has to be same
+            if len(x) != len(y):
+                return False
+            for stepx, stepy in zip(x, y):
+                if len(stepx) != len(stepy):
+                    # something off with a conversation
+                    return False
+                if len(stepx) != 2:
+                    # something off
+                    return False
+                if len(stepy) != 2:
+                    # something off
+                    return False
+                questionx = stepx[0].replace('<p>', '').replace('</p>', '')
+                answerx = stepx[1].replace('<p>', '').replace('</p>', '')
+
+                questiony = stepy[0].replace('<p>', '').replace('</p>', '')
+                answery = stepy[1].replace('<p>', '').replace('</p>', '')
+
+                if questionx != questiony or answerx != answery:
+                    return False
+            return is_same
+
+        def save_chat(chat1, chat2, chat_state1):
+            short_chats = list(chat_state1.keys())
+            for chati in [chat1, chat2]:
+                if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None:
+                    short_chat = get_short_chat(chati, short_chats)
+                    if short_chat:
+                        already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()])
+                        if not already_exists:
+                            chat_state1[short_chat] = chati
+            return chat_state1
+
+        def update_radio_chats(chat_state1):
+            return gr.update(choices=list(chat_state1.keys()), value=None)
+
+        def deselect_radio_chats():
+            return gr.update(value=None)
+
+        def switch_chat(chat_key, chat_state1):
+            chosen_chat = chat_state1[chat_key]
+            return chosen_chat, chosen_chat
+
+        radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2])
+
+        def remove_chat(chat_key, chat_state1):
+            chat_state1.pop(chat_key, None)
+            return chat_state1
+
+        remove_chat_btn.click(remove_chat, inputs=[radio_chats, chat_state], outputs=chat_state) \
+            .then(update_radio_chats, inputs=chat_state, outputs=radio_chats)
+
+        def get_chats1(chat_state1):
+            base = 'chats'
+            makedirs(base, exist_ok=True)
+            filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4()))
+            with open(filename, "wt") as f:
+                f.write(json.dumps(chat_state1, indent=2))
+            return filename
+
+        export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
+                               api_name='export_chats' if allow_api else None)
+
+        clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
+            .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \
+            .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False)
+
         # does both models
-        clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
-            .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
+        clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state,
+                    api_name='save_chat' if allow_api else None) \
+            .then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
+                  api_name='update_chats' if allow_api else None) \
+            .then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \
+            .then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None)
         # NOTE: clear of instruction/iinput for nochat has to come after score,
         # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
         submit_event_nochat = submit_nochat.click(fun,
diff --git a/gradio_themes.py b/gradio_themes.py
index b4e240321..d181a4387 100644
--- a/gradio_themes.py
+++ b/gradio_themes.py
@@ -75,6 +75,7 @@ def __init__(
             body_background_fill_dark="*neutral_900",
             background_fill_primary_dark="*block_background_fill",
             block_radius="0 0 8px 8px",
+            checkbox_label_text_color_selected_dark='#000000',
         )

From 63c20e5fb5f607e8f8026b776f18975bd269e9ac Mon Sep 17 00:00:00 2001
From: "Jonathan C. McKinney"
McKinney" Date: Mon, 22 May 2023 15:20:56 -0700 Subject: [PATCH 2/2] Allow chat import --- gradio_runner.py | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/gradio_runner.py b/gradio_runner.py index bc0ca4781..0362616fc 100644 --- a/gradio_runner.py +++ b/gradio_runner.py @@ -187,7 +187,7 @@ def _postprocess_chat_messages(self, chat_message: str): lines=kwargs['input_lines'], label=instruction_label_nochat, placeholder=kwargs['placeholder_instruction'], - ).style(show_copy_button=True) + ) iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction", placeholder=kwargs['placeholder_input']) submit_nochat = gr.Button("Submit") @@ -216,7 +216,7 @@ def _postprocess_chat_messages(self, chat_message: str): submit = gr.Button(value='Submit').style(full_width=False, size='sm') stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm') with gr.Row(): - clear = gr.Button("New/Save Conversation") + clear = gr.Button("Save, New Conversation") flag_btn = gr.Button("Flag") if not kwargs['auto_score']: # FIXME: For checkbox model2 with gr.Column(visible=kwargs['score_model']): @@ -235,7 +235,7 @@ def _postprocess_chat_messages(self, chat_message: str): score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False) retry = gr.Button("Regenerate") undo = gr.Button("Undo") - with gr.TabItem("Input/Output"): + with gr.TabItem("Chat"): with gr.Row(): if 'mbart-' in kwargs['model_lower']: src_lang = gr.Dropdown(list(languages_covered().keys()), @@ -253,7 +253,13 @@ def _postprocess_chat_messages(self, chat_message: str): with chats_row: export_chats_btn = gr.Button(value="Export Chats") chats_file = gr.File(interactive=False, label="Download File") - + chats_row2 = gr.Row(visible=True).style(equal_height=False) + with chats_row2: + chatsup_output = gr.File(label="Upload Chat File(s)", + file_types=['.json'], + file_count='multiple', + elem_id="warning", elem_classes="feedback") + add_to_chats_btn = gr.Button("Add File(s) to Chats") with gr.TabItem("Data Source"): langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md', from_str=True) @@ -1075,6 +1081,32 @@ def get_chats1(chat_state1): export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False, api_name='export_chats' if allow_api else None) + def add_chats_from_file(file, chat_state1, add_btn): + if isinstance(file, str): + files = [file] + else: + files = file + for file1 in files: + try: + if hasattr(file1, 'name'): + file1 = file1.name + with open(file1, "rt") as f: + new_chats = json.loads(f.read()) + for chat1_k, chat1_v in new_chats.items(): + # ignore chat1_k, regenerate and de-dup to avoid loss + chat_state1 = save_chat(chat1_v, None, chat_state1) + except BaseException as e: + print("Add chats exception: %s" % str(e), flush=True) + return chat_state1, add_btn + + # note for update_user_db_func output is ignored for db + add_to_chats_btn.click(add_chats_from_file, + inputs=[chatsup_output, chat_state, add_to_chats_btn], + outputs=[chat_state, add_to_my_db_btn], queue=False, + api_name='add_to_chats' if allow_api else None) \ + .then(clear_file_list, outputs=chatsup_output, queue=False) \ + .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False) + clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \ .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \ 
.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False)