feat: try to fine tune model to correct topic prompts
nsantacruz committed Jun 17, 2024
1 parent 13b6833 commit b5e2cc4
Showing 3 changed files with 182 additions and 0 deletions.
64 changes: 64 additions & 0 deletions experiments/topic_prompt/find_topic_prompt_edits.py
@@ -0,0 +1,64 @@
"""
Goal is to find topic prompts that were edited by the learning team
"""
import django
django.setup()
from sefaria.model import *
from basic_langchain.chat_models import ChatOpenAI
from basic_langchain.schema import SystemMessage, HumanMessage
from util.general import get_by_xml_list, get_by_xml_tag, run_parallel
import random

random.seed(34243)


def find_substantive_difference(texts):
    a, b = texts
    system = SystemMessage(content="Given two paragraphs, output the essential difference between them as a paraphrase."
                                   " The paraphrase should include the whole phrase with the difference."
                                   " If there are multiple differences, output each phrase with a difference."
                                   " Input text A wrapped in <text_a> tags and text B wrapped in <text_b> tags."
                                   " Output each difference in a <difference> tag."
                                   " Inside each <difference> tag include the original paraphrase from text A in"
                                   " <text_a> tags and the original paraphrase from text B in <text_b> tags.")
    # Wrap each input in the tag names the system prompt describes.
    human = HumanMessage(content=f"<text_a>{a}</text_a>\n<text_b>{b}</text_b>")
    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    response = llm([system, human])
    raw_differences = get_by_xml_list(response.content, "difference")
    differences = []
    for diff in raw_differences:
        a = get_by_xml_tag(diff, 'text_a')
        b = get_by_xml_tag(diff, 'text_b')
        differences.append((a, b))
    return differences


def find_substantive_difference_testing():
    a = """Rabbi Meir's ordination was a daring act of defiance against Roman persecution, ensuring the survival of Jewish legal traditions. The Epistle of Rav Sherira Gaon recounts the bravery of Rabbi Yehudah ben Bava, who risked his life to ordain rabbis during a time of severe Roman oppression."""
    b = """Rabbi Meir's ordination was a daring act of defiance against Roman persecution, ensuring the survival of Jewish legal traditions. The 10th-century Epistle of Rav Sherira Gaon recounts the bravery of Rabbi Yehudah ben Bava, who risked his life to ordain several rabbis, among them Rabbi Meir, during a time of severe Roman oppression."""
    # find_substantive_difference expects a single (a, b) tuple, as run_parallel passes one item per call.
    diffs = find_substantive_difference((a, b))
    for d in diffs:
        print(d)


def find_differences_in_all_links():
    count = 0
    prompts_with_diffs = []
    for link in RefTopicLinkSet({"descriptions.en.ai_prompt": {"$exists": True}}):
        description = link.descriptions['en']
        if description['ai_prompt'] != description['prompt']:
            count += 1
            prompts_with_diffs += [(description['ai_prompt'], description['prompt'])]
    diffs_list = run_parallel(prompts_with_diffs, find_substantive_difference, max_workers=30, desc='find diffs')
    for diffs in diffs_list:
        for a_diff, b_diff in diffs:
            print('----')
            print(a_diff)
            print(b_diff)

    print(count)


if __name__ == '__main__':
    # find_substantive_difference_testing()
    find_differences_in_all_links()
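
Note: get_by_xml_list, get_by_xml_tag, and run_parallel are imported from util.general, which is not part of this commit. A minimal sketch of the tag-extraction behavior the script relies on (an assumption for illustration, not the project's actual implementation) could look like this:

import re
from typing import List, Optional

def get_by_xml_tag(text: str, tag: str) -> Optional[str]:
    # Assumed behavior: return the contents of the first <tag>...</tag> block, or None if missing.
    match = re.search(fr"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    return match.group(1).strip() if match else None

def get_by_xml_list(text: str, tag: str) -> List[str]:
    # Assumed behavior: return the contents of every <tag>...</tag> block, in order.
    return [m.strip() for m in re.findall(fr"<{tag}>(.*?)</{tag}>", text, re.DOTALL)]
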
69 changes: 69 additions & 0 deletions experiments/topic_prompt/fine_tune.py
@@ -0,0 +1,69 @@
import django
django.setup()
from sefaria.model import *
from langchain.schema import HumanMessage, SystemMessage, AIMessage
from langchain_community.adapters.openai import convert_message_to_dict
from srsly import write_jsonl
from sklearn.model_selection import train_test_split

def get_prompts_with_diffs():
    prompts_with_diffs = []
    for link in RefTopicLinkSet({"descriptions.en.ai_prompt": {"$exists": True}}):
        description = link.descriptions['en']
        if description['ai_prompt'] != description['prompt']:
            prompts_with_diffs += [(description['ai_prompt'], description['prompt'])]
    return prompts_with_diffs


class GptPromptTrainingGenerator:

    @staticmethod
    def generate(input_toprompt, gold_standard_toprompt=None):
        """
        Generate a list of messages to feed to GPT to either train or run on
        :return:
        """
        example = GptPromptTrainingGenerator.generate_one(input_toprompt, gold_standard_toprompt)
        return GptPromptTrainingGenerator.serialize_messages(example)

    @staticmethod
    def generate_one(input_toprompt, gold_standard_toprompt=None):
        return GptPromptTrainingGenerator._generate_one_chat_format(input_toprompt, gold_standard_toprompt)

    @staticmethod
    def _generate_one_chat_format(input_toprompt, gold_standard_toprompt=None):
        messages = [
            SystemMessage(content=GptPromptTrainingGenerator._create_system_prompt()),
            HumanMessage(content=GptPromptTrainingGenerator._create_prompt(input_toprompt)),
        ]
        if gold_standard_toprompt:
            messages += [AIMessage(content=GptPromptTrainingGenerator._create_completion(gold_standard_toprompt))]
        return messages

    @staticmethod
    def serialize_messages(messages):
        return {"messages": [convert_message_to_dict(message) for message in messages]}

    @staticmethod
    def _create_system_prompt():
        return "You are a Jewish scholar knowledgeable in all Torah texts. Your goal is to take a description of a Jewish source and rewrite it, adhering religiously to your style guide."

    @staticmethod
    def _create_prompt(input_toprompt):
        return input_toprompt

    @staticmethod
    def _create_completion(gold_standard_toprompt):
        return gold_standard_toprompt


def save_fine_tune_training_set(prompts_with_diffs):
    training_set = [GptPromptTrainingGenerator.generate(a, b) for (a, b) in prompts_with_diffs]
    write_jsonl("output/fine_tune_training_set.jsonl", training_set)


if __name__ == '__main__':
    prompts_with_diffs = get_prompts_with_diffs()
    fine_tune_data = [GptPromptTrainingGenerator.generate(a, b) for (a, b) in prompts_with_diffs]
    training_data, validation_data = train_test_split(fine_tune_data, random_state=613, train_size=0.999)
    write_jsonl("output/fine_tune_training_set.jsonl", training_data)
    write_jsonl("output/fine_tune_validation_set.jsonl", validation_data)
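
Each element of fine_tune_data is serialized by convert_message_to_dict into OpenAI's chat fine-tuning format, so one line of fine_tune_training_set.jsonl should look roughly like the record below (the content strings are placeholders; the real ones come from the topic-link descriptions):

# Illustrative shape of a single training record, not actual data.
example_record = {
    "messages": [
        {"role": "system", "content": "You are a Jewish scholar knowledgeable in all Torah texts. ..."},
        {"role": "user", "content": "<original AI-generated topic prompt>"},
        {"role": "assistant", "content": "<learning team's edited topic prompt>"},
    ]
}
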
49 changes: 49 additions & 0 deletions experiments/topic_prompt/project.yml
@@ -0,0 +1,49 @@
vars:
  script_dir: "../../app/util/fine_tune"
  base_model: "gpt-3.5-turbo-0125"
  output_model_suffix: "topic-prompt"
env:
  openai_api_key: OPENAI_API_KEY
workflows:
  all:
    - upload_files
    - create_fine_tune

commands:
  - name: delete_all_files
    script:
      - "python ${vars.script_dir}/delete_all_files.py"

  - name: delete_all_fine_tunes
    script:
      - "python ${vars.script_dir}/delete_all_fine_tunes.py"

  - name: delete_last_job
    script:
      - "python ${vars.script_dir}/delete_last_fine_tune_job.py"

  - name: upload_files
    deps:
      - 'output/fine_tune_training_set.jsonl'
      - 'output/fine_tune_validation_set.jsonl'
    outputs:
      - 'output/fine_tune_file_ids.json'
    script:
      - 'python ${vars.script_dir}/upload_fine_tune_files.py output/fine_tune_training_set.jsonl output/fine_tune_validation_set.jsonl'

  - name: create_fine_tune
    deps:
      - 'output/fine_tune_file_ids.json'
    script:
      - 'python ${vars.script_dir}/create_fine_tune.py ${vars.base_model} ${vars.output_model_suffix}'

  - name: fine_tune_status
    script:
      - 'python ${vars.script_dir}/fine_tune_status.py output/fine_tune_status.json'

  - name: fine_tune_stats
    outputs:
      - 'output/fine_tune_stats.csv'
    script:
      - 'python ${vars.script_dir}/fine_tune_stats.py output/fine_tune_stats.csv'

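The layout of this project.yml (vars, env, workflows, and commands with deps/outputs/script) matches the spaCy/Weasel project format, so presumably the whole pipeline is kicked off from this directory with something like `weasel run all` (or `python -m spacy project run all`), which uploads the two JSONL files and then creates the fine-tune job against the gpt-3.5-turbo-0125 base model.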