From b5e2cc44122e387707a4e36c93bddf09c9a339af Mon Sep 17 00:00:00 2001 From: nsantacruz Date: Mon, 17 Jun 2024 11:27:43 +0300 Subject: [PATCH] feat: try to fine tune model to correct topic prompts --- .../topic_prompt/find_topic_prompt_edits.py | 64 +++++++++++++++++ experiments/topic_prompt/fine_tune.py | 69 +++++++++++++++++++ experiments/topic_prompt/project.yml | 49 +++++++++++++ 3 files changed, 182 insertions(+) create mode 100644 experiments/topic_prompt/find_topic_prompt_edits.py create mode 100644 experiments/topic_prompt/fine_tune.py create mode 100644 experiments/topic_prompt/project.yml diff --git a/experiments/topic_prompt/find_topic_prompt_edits.py b/experiments/topic_prompt/find_topic_prompt_edits.py new file mode 100644 index 0000000..edae1c8 --- /dev/null +++ b/experiments/topic_prompt/find_topic_prompt_edits.py @@ -0,0 +1,64 @@ +""" +Goal is to find topic prompts that were edited by the learning team +""" +import django +django.setup() +from sefaria.model import * +from basic_langchain.chat_models import ChatOpenAI +from basic_langchain.schema import SystemMessage, HumanMessage +from util.general import get_by_xml_list, get_by_xml_tag, run_parallel +import random + +random.seed(34243) + + +def find_substantive_difference(texts): + a, b = texts + system = SystemMessage(content="Given two paragraphs, output the essential difference between them as a paraphrase." + " The paraphrase should include the whole phrase with the difference." + " If there are multiple differences, output each phrase with a difference." + " Input text A wrapped in tags and text B wrapped in tags." + " Output each difference in a tag." + " Inside each tag include the original paraphrase from text A in" + " tags and the original paraphrase from text B in tags.") + human = HumanMessage(content=f"{a}\n{b}") + llm = ChatOpenAI(model="gpt-4o", temperature=0) + response = llm([system, human]) + raw_differences = get_by_xml_list(response.content, "difference") + differences = [] + for diff in raw_differences: + a = get_by_xml_tag(diff, 'text_a') + b = get_by_xml_tag(diff, 'text_b') + differences.append((a, b)) + return differences + + +def find_substantive_difference_testing(): + a = """Rabbi Meir's ordination was a daring act of defiance against Roman persecution, ensuring the survival of Jewish legal traditions. The Epistle of Rav Sherira Gaon recounts the bravery of Rabbi Yehudah ben Bava, who risked his life to ordain rabbis during a time of severe Roman oppression.""" + b = """Rabbi Meir's ordination was a daring act of defiance against Roman persecution, ensuring the survival of Jewish legal traditions. The 10th-century Epistle of Rav Sherira Gaon recounts the bravery of Rabbi Yehudah ben Bava, who risked his life to ordain several rabbis, among them Rabbi Meir, during a time of severe Roman oppression.""" + diffs = find_substantive_difference(a, b) + for d in diffs: + print(d) + + +def find_differences_in_all_links(): + count = 0 + prompts_with_diffs = [] + for link in RefTopicLinkSet({"descriptions.en.ai_prompt": {"$exists": True}}): + description = link.descriptions['en'] + if description['ai_prompt'] != description['prompt']: + count += 1 + prompts_with_diffs += [(description['ai_prompt'], description['prompt'])] + diffs_list = run_parallel(prompts_with_diffs, find_substantive_difference, max_workers=30, desc='find diffs') + for diffs in diffs_list: + for a_diff, b_diff in diffs: + print('----') + print(a_diff) + print(b_diff) + + print(count) + + +if __name__ == '__main__': + # find_substantive_difference_testing() + find_differences_in_all_links() diff --git a/experiments/topic_prompt/fine_tune.py b/experiments/topic_prompt/fine_tune.py new file mode 100644 index 0000000..f35fb4e --- /dev/null +++ b/experiments/topic_prompt/fine_tune.py @@ -0,0 +1,69 @@ +import django +django.setup() +from sefaria.model import * +from langchain.schema import HumanMessage, SystemMessage, AIMessage +from langchain_community.adapters.openai import convert_message_to_dict +from srsly import write_jsonl +from sklearn.model_selection import train_test_split + +def get_prompts_with_diffs(): + prompts_with_diffs = [] + for link in RefTopicLinkSet({"descriptions.en.ai_prompt": {"$exists": True}}): + description = link.descriptions['en'] + if description['ai_prompt'] != description['prompt']: + prompts_with_diffs += [(description['ai_prompt'], description['prompt'])] + return prompts_with_diffs + + +class GptPromptTrainingGenerator: + + @staticmethod + def generate(input_toprompt, gold_standard_toprompt=None): + """ + Generate a list of messages to feed to GPT to either train or run on + :return: + """ + example = GptPromptTrainingGenerator.generate_one(input_toprompt, gold_standard_toprompt) + return GptPromptTrainingGenerator.serialize_messages(example) + + @staticmethod + def generate_one(input_toprompt, gold_standard_toprompt=None): + return GptPromptTrainingGenerator._generate_one_chat_format(input_toprompt, gold_standard_toprompt) + + @staticmethod + def _generate_one_chat_format(input_toprompt, gold_standard_toprompt=None): + messages = [ + SystemMessage(content=GptPromptTrainingGenerator._create_system_prompt()), + HumanMessage(content=GptPromptTrainingGenerator._create_prompt(input_toprompt)), + ] + if gold_standard_toprompt: + messages += [AIMessage(content=GptPromptTrainingGenerator._create_completion(gold_standard_toprompt))] + return messages + + @staticmethod + def serialize_messages(messages): + return {"messages": [convert_message_to_dict(message) for message in messages]} + + @staticmethod + def _create_system_prompt(): + return "You are Jewish scholar knowledgeable in all Torah texts. Your goal is to take a description of a Jewish source and rewrite it, adhering religiously to your style guide." + + @staticmethod + def _create_prompt(input_toprompt): + return input_toprompt + + @staticmethod + def _create_completion(gold_standard_toprompt): + return gold_standard_toprompt + + +def save_fine_tune_training_set(prompts_with_diffs): + write_jsonl("output/fine_tune_training_set.jsonl", training_set) + + +if __name__ == '__main__': + prompts_with_diffs = get_prompts_with_diffs() + fine_tune_data = [GptPromptTrainingGenerator.generate(a, b) for (a, b) in prompts_with_diffs] + training_data, validation_data = train_test_split(fine_tune_data, random_state=613, train_size=0.999) + write_jsonl("output/fine_tune_training_set.jsonl", training_data) + write_jsonl("output/fine_tune_validation_set.jsonl", validation_data) \ No newline at end of file diff --git a/experiments/topic_prompt/project.yml b/experiments/topic_prompt/project.yml new file mode 100644 index 0000000..bb877d7 --- /dev/null +++ b/experiments/topic_prompt/project.yml @@ -0,0 +1,49 @@ +vars: + script_dir: "../../app/util/fine_tune" + base_model: "gpt-3.5-turbo-0125" + output_model_suffix: "topic-prompt" +env: + openai_api_key: OPENAI_API_KEY +workflows: + all: + - upload_files + - create_fine_tune + +commands: + - name: delete_all_files + script: + - "python ${vars.script_dir}/delete_all_files.py" + + - name: delete_all_fine_tunes + script: + - "python ${vars.script_dir}/delete_all_fine_tunes.py" + + - name: delete_last_job + script: + - "python ${vars.script_dir}/delete_last_fine_tune_job.py" + + - name: upload_files + deps: + - 'output/fine_tune_training_set.jsonl' + - 'output/fine_tune_validation_set.jsonl' + outputs: + - 'output/fine_tune_file_ids.json' + script: + - 'python ${vars.script_dir}/upload_fine_tune_files.py output/fine_tune_training_set.jsonl output/fine_tune_validation_set.jsonl' + + + - name: create_fine_tune + deps: + - 'output/fine_tune_file_ids.json' + script: + - 'python ${vars.script_dir}/create_fine_tune.py ${vars.base_model} ${vars.output_model_suffix}' + + - name: fine_tune_status + script: + - 'python ${vars.script_dir}/fine_tune_status.py output/fine_tune_status.json' + + - name: fine_tune_stats + outputs: + - 'output/fine_tune_stats.csv' + script: + - 'python ${vars.script_dir}/fine_tune_stats.py output/fine_tune_stats.csv'