Skip to content

Commit

Permalink
Merge pull request #114 from h2oai/chatsep
Browse files Browse the repository at this point in the history
Specify chat separator
  • Loading branch information
pseudotensor authored May 5, 2023
2 parents e2f9b2d + e94abed commit f1ae287
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 20 deletions.
23 changes: 17 additions & 6 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,10 +746,10 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
assert prompt_type is not None
assert cutoff_len is not None
assert tokenizer is not None
full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, False, False)
tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
if not train_on_inputs:
user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
user_prompt_len = len(tokenized_user_prompt["input_ids"])
if add_eos_token:
Expand All @@ -768,9 +768,11 @@ def get_prompt(prompt_type, chat, context, reduced):
if prompt_type in [-1, "-1", "plain"]:
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
terminate_response = []
chat_sep = ''
elif prompt_type == 'simple_instruct':
promptA = promptB = PreInstruct = PreInput = PreResponse = None
terminate_response = []
chat_sep = '\n'
elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
Expand All @@ -790,6 +792,7 @@ def get_prompt(prompt_type, chat, context, reduced):
terminate_response = ['### End']
else:
terminate_response = None
chat_sep = '\n'
elif prompt_type in [1, "1", "quality"]:
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
Expand All @@ -806,6 +809,7 @@ def get_prompt(prompt_type, chat, context, reduced):
### Response:
"""
terminate_response = None
chat_sep = '\n'
elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
if reduced or context or prompt_type in [2, "2", "human_bot"]:
preprompt = ''
Expand Down Expand Up @@ -835,6 +839,7 @@ def get_prompt(prompt_type, chat, context, reduced):
PreResponse = bot

terminate_response = [start, PreResponse]
chat_sep = '\n'
elif prompt_type in [3, "3", "dai_faq"]:
promptA = ''
promptB = 'Answer the following Driverless AI question.\n'
Expand All @@ -849,11 +854,13 @@ def get_prompt(prompt_type, chat, context, reduced):
### Driverless AI documentation answer:
"""
terminate_response = ['\n\n']
chat_sep = terminate_response
elif prompt_type in [5, "5", "summarize"]:
promptA = promptB = PreInput = ''
PreInstruct = '## Main Text\n\n'
PreResponse = '\n\n## Summary\n\n'
terminate_response = None
chat_sep = '\n'
elif prompt_type in [6, "6", "instruct_vicuna"]:
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
Expand All @@ -868,18 +875,21 @@ def get_prompt(prompt_type, chat, context, reduced):
### Assistant:
"""
terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
chat_sep = '\n'
elif prompt_type in [10, "10", "prompt_answer"]:
preprompt = ''
start = prompt_tokens
promptB = promptA = '%s%s' % (preprompt, start)
PreInstruct = ""
PreInput = None
PreResponse = answer_tokens
terminate_response = [start, PreResponse]
eos = '<|endoftext|>' # neox eos
terminate_response = [start, PreResponse, eos]
chat_sep = eos
else:
raise RuntimeError("No such prompt_type=%s" % prompt_type)

return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep


def generate_prompt(data_point, prompt_type, chat, reduced):
Expand All @@ -891,7 +901,8 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
output = data_point.get('output')
prompt_type = data_point.get('prompt_type', prompt_type)
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
promptA, promptB, PreInstruct, PreInput, PreResponse, \
terminate_response, chat_sep = get_prompt(prompt_type, chat, context, reduced)

prompt = context if not reduced else ''

Expand Down Expand Up @@ -943,7 +954,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
if output:
prompt += f"""{output}"""

return prompt, pre_response, terminate_response
return prompt, pre_response, terminate_response, chat_sep


def inject_newline(prompt_type, prompt):
Expand Down
30 changes: 21 additions & 9 deletions gradio_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys

from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
from prompter import Prompter
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
ping
from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
Expand Down Expand Up @@ -517,9 +518,12 @@ def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
:return:
"""
args_list = list(args)
user_message = args_list[0]
input1 = args_list[1]
context1 = args_list[2]
user_message = args_list[eval_func_param_names.index('instruction')] # chat only
input1 = args_list[eval_func_param_names.index('iinput')] # chat only
context1 = args_list[eval_func_param_names.index('context')]
prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
chat1 = args_list[eval_func_param_names.index('chat')]
stream_output1 = args_list[eval_func_param_names.index('stream_output')]
if input1 and not user_message.endswith(':'):
user_message1 = user_message + ":" + input1
elif input1:
Expand All @@ -529,6 +533,8 @@ def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
if sanitize_user_prompt:
from better_profanity import profanity
user_message1 = profanity.censor(user_message1)
# FIXME: WIP to use desired separator when user enters nothing
prompter = Prompter(prompt_type1, debug=kwargs['debug'], chat=chat1, stream_output=stream_output1)
if user_message1 in ['']:
# e.g. when user just hits enter in textbox,
# else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
Expand Down Expand Up @@ -581,12 +587,18 @@ def bot(*args, retry=False):
context1 = ''
for histi in range(len(history) - 1):
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
'<br>', '\n')
if not context1.endswith('\n'):
context1 += '\n'
if context1 and not context1.endswith('\n'):
context1 += '\n' # ensure if terminates abruptly, then human continues on next line
prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
chat1, reduced=True)
# md -> back to text, maybe not super important if model is trained enough
prompt = prompt.replace('<br>', chat_sep)
context1 += prompt
if not context1.endswith(chat_sep):
context1 += chat_sep

_, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
reduced=True)
if context1 and not context1.endswith(chat_sep):
context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
args_list[0] = instruction1 # override original instruction with history from user
# only include desired chat history
args_list[2] = context1[-kwargs['chat_history']:]
Expand Down
11 changes: 6 additions & 5 deletions prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, re
allowed_repeat_line_length=10):
self.prompt_type = prompt_type
data_point = dict(instruction='', input='', output='')
_, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
_, self.pre_response, self.terminate_response, self.chat_sep = \
generate_prompt(data_point, prompt_type, chat, False)
self.debug = debug
self.chat = chat
self.stream_output = stream_output
Expand All @@ -15,7 +16,7 @@ def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, re

def generate_prompt(self, data_point):
reduced = False
prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
if self.debug:
print("prompt: ", prompt, flush=True)
self.prompt = prompt
Expand All @@ -25,12 +26,12 @@ def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
if isinstance(outputs, str):
outputs = [outputs]
if self.debug:
print("output: ", '\n\n'.join(outputs), flush=True)
print("output:\n", '\n\n'.join(outputs), flush=True)
if prompt is not None:
self.prompt = prompt

def clean_response(response):
meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
for word in meaningless_words:
response = response.replace(word, "")
if sanitize_bot_response:
Expand Down Expand Up @@ -103,5 +104,5 @@ def clean_repeats(response):
# join all outputs, only one extra new line between outputs
output = '\n'.join(outputs)
if self.debug:
print("outputclean: ", '\n\n'.join(outputs), flush=True)
print("outputclean:\n", '\n\n'.join(outputs), flush=True)
return output
17 changes: 17 additions & 0 deletions tests/manual_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def test_chat_context():
# on h2oai/h2ogpt-oasst1-512-20b
instruction = """Rephrase in 5 different ways: “Apple a day keeps the doctor away.”"""
expected_response = """1. “A apple every day will keep you healthy.”
2. “An Apple A Day Keeps The Doctor Away”
3. “One of these apples each and everyday, is all it takes to stay well”
4. “Eat an apple daily for good health!”
5. “If eaten one per day, this fruit can help prevent disease”.
I hope that helps! Let me know if there’s anything else I could do for you today?"""
instruction2 = """Summarize into single sentence."""
expected_response2 = """“The more fruits we eat, the healthier.” - Dr. John Yiamouyiannis (American physician)"""

# NOTE: if something broken, might say something unrelated to first question, e.g.
unexpected_response2 = """I am an AI language model ..."""

raise NotImplementedError("MANUAL TEST FOR NOW")

0 comments on commit f1ae287

Please sign in to comment.