-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: try to fine tune model to correct topic prompts
- Loading branch information
1 parent
13b6833
commit b5e2cc4
Showing
3 changed files
with
182 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
Goal is to find topic prompts that were edited by the learning team | ||
""" | ||
import django | ||
django.setup() | ||
from sefaria.model import * | ||
from basic_langchain.chat_models import ChatOpenAI | ||
from basic_langchain.schema import SystemMessage, HumanMessage | ||
from util.general import get_by_xml_list, get_by_xml_tag, run_parallel | ||
import random | ||
|
||
random.seed(34243) | ||
|
||
|
||
def find_substantive_difference(texts): | ||
a, b = texts | ||
system = SystemMessage(content="Given two paragraphs, output the essential difference between them as a paraphrase." | ||
" The paraphrase should include the whole phrase with the difference." | ||
" If there are multiple differences, output each phrase with a difference." | ||
" Input text A wrapped in <text_a> tags and text B wrapped in <text_b> tags." | ||
" Output each difference in a <difference> tag." | ||
" Inside each <difference> tag include the original paraphrase from text A in" | ||
" <text_a> tags and the original paraphrase from text B in <text_b> tags.") | ||
human = HumanMessage(content=f"<text>{a}</text>\n<text>{b}</text>") | ||
llm = ChatOpenAI(model="gpt-4o", temperature=0) | ||
response = llm([system, human]) | ||
raw_differences = get_by_xml_list(response.content, "difference") | ||
differences = [] | ||
for diff in raw_differences: | ||
a = get_by_xml_tag(diff, 'text_a') | ||
b = get_by_xml_tag(diff, 'text_b') | ||
differences.append((a, b)) | ||
return differences | ||
|
||
|
||
def find_substantive_difference_testing(): | ||
a = """Rabbi Meir's ordination was a daring act of defiance against Roman persecution, ensuring the survival of Jewish legal traditions. The Epistle of Rav Sherira Gaon recounts the bravery of Rabbi Yehudah ben Bava, who risked his life to ordain rabbis during a time of severe Roman oppression.""" | ||
b = """Rabbi Meir's ordination was a daring act of defiance against Roman persecution, ensuring the survival of Jewish legal traditions. The 10th-century Epistle of Rav Sherira Gaon recounts the bravery of Rabbi Yehudah ben Bava, who risked his life to ordain several rabbis, among them Rabbi Meir, during a time of severe Roman oppression.""" | ||
diffs = find_substantive_difference(a, b) | ||
for d in diffs: | ||
print(d) | ||
|
||
|
||
def find_differences_in_all_links(): | ||
count = 0 | ||
prompts_with_diffs = [] | ||
for link in RefTopicLinkSet({"descriptions.en.ai_prompt": {"$exists": True}}): | ||
description = link.descriptions['en'] | ||
if description['ai_prompt'] != description['prompt']: | ||
count += 1 | ||
prompts_with_diffs += [(description['ai_prompt'], description['prompt'])] | ||
diffs_list = run_parallel(prompts_with_diffs, find_substantive_difference, max_workers=30, desc='find diffs') | ||
for diffs in diffs_list: | ||
for a_diff, b_diff in diffs: | ||
print('----') | ||
print(a_diff) | ||
print(b_diff) | ||
|
||
print(count) | ||
|
||
|
||
if __name__ == '__main__': | ||
# find_substantive_difference_testing() | ||
find_differences_in_all_links() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import django | ||
django.setup() | ||
from sefaria.model import * | ||
from langchain.schema import HumanMessage, SystemMessage, AIMessage | ||
from langchain_community.adapters.openai import convert_message_to_dict | ||
from srsly import write_jsonl | ||
from sklearn.model_selection import train_test_split | ||
|
||
def get_prompts_with_diffs(): | ||
prompts_with_diffs = [] | ||
for link in RefTopicLinkSet({"descriptions.en.ai_prompt": {"$exists": True}}): | ||
description = link.descriptions['en'] | ||
if description['ai_prompt'] != description['prompt']: | ||
prompts_with_diffs += [(description['ai_prompt'], description['prompt'])] | ||
return prompts_with_diffs | ||
|
||
|
||
class GptPromptTrainingGenerator: | ||
|
||
@staticmethod | ||
def generate(input_toprompt, gold_standard_toprompt=None): | ||
""" | ||
Generate a list of messages to feed to GPT to either train or run on | ||
:return: | ||
""" | ||
example = GptPromptTrainingGenerator.generate_one(input_toprompt, gold_standard_toprompt) | ||
return GptPromptTrainingGenerator.serialize_messages(example) | ||
|
||
@staticmethod | ||
def generate_one(input_toprompt, gold_standard_toprompt=None): | ||
return GptPromptTrainingGenerator._generate_one_chat_format(input_toprompt, gold_standard_toprompt) | ||
|
||
@staticmethod | ||
def _generate_one_chat_format(input_toprompt, gold_standard_toprompt=None): | ||
messages = [ | ||
SystemMessage(content=GptPromptTrainingGenerator._create_system_prompt()), | ||
HumanMessage(content=GptPromptTrainingGenerator._create_prompt(input_toprompt)), | ||
] | ||
if gold_standard_toprompt: | ||
messages += [AIMessage(content=GptPromptTrainingGenerator._create_completion(gold_standard_toprompt))] | ||
return messages | ||
|
||
@staticmethod | ||
def serialize_messages(messages): | ||
return {"messages": [convert_message_to_dict(message) for message in messages]} | ||
|
||
@staticmethod | ||
def _create_system_prompt(): | ||
return "You are Jewish scholar knowledgeable in all Torah texts. Your goal is to take a description of a Jewish source and rewrite it, adhering religiously to your style guide." | ||
|
||
@staticmethod | ||
def _create_prompt(input_toprompt): | ||
return input_toprompt | ||
|
||
@staticmethod | ||
def _create_completion(gold_standard_toprompt): | ||
return gold_standard_toprompt | ||
|
||
|
||
def save_fine_tune_training_set(prompts_with_diffs): | ||
write_jsonl("output/fine_tune_training_set.jsonl", training_set) | ||
|
||
|
||
if __name__ == '__main__': | ||
prompts_with_diffs = get_prompts_with_diffs() | ||
fine_tune_data = [GptPromptTrainingGenerator.generate(a, b) for (a, b) in prompts_with_diffs] | ||
training_data, validation_data = train_test_split(fine_tune_data, random_state=613, train_size=0.999) | ||
write_jsonl("output/fine_tune_training_set.jsonl", training_data) | ||
write_jsonl("output/fine_tune_validation_set.jsonl", validation_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
vars: | ||
script_dir: "../../app/util/fine_tune" | ||
base_model: "gpt-3.5-turbo-0125" | ||
output_model_suffix: "topic-prompt" | ||
env: | ||
openai_api_key: OPENAI_API_KEY | ||
workflows: | ||
all: | ||
- upload_files | ||
- create_fine_tune | ||
|
||
commands: | ||
- name: delete_all_files | ||
script: | ||
- "python ${vars.script_dir}/delete_all_files.py" | ||
|
||
- name: delete_all_fine_tunes | ||
script: | ||
- "python ${vars.script_dir}/delete_all_fine_tunes.py" | ||
|
||
- name: delete_last_job | ||
script: | ||
- "python ${vars.script_dir}/delete_last_fine_tune_job.py" | ||
|
||
- name: upload_files | ||
deps: | ||
- 'output/fine_tune_training_set.jsonl' | ||
- 'output/fine_tune_validation_set.jsonl' | ||
outputs: | ||
- 'output/fine_tune_file_ids.json' | ||
script: | ||
- 'python ${vars.script_dir}/upload_fine_tune_files.py output/fine_tune_training_set.jsonl output/fine_tune_validation_set.jsonl' | ||
|
||
|
||
- name: create_fine_tune | ||
deps: | ||
- 'output/fine_tune_file_ids.json' | ||
script: | ||
- 'python ${vars.script_dir}/create_fine_tune.py ${vars.base_model} ${vars.output_model_suffix}' | ||
|
||
- name: fine_tune_status | ||
script: | ||
- 'python ${vars.script_dir}/fine_tune_status.py output/fine_tune_status.json' | ||
|
||
- name: fine_tune_stats | ||
outputs: | ||
- 'output/fine_tune_stats.csv' | ||
script: | ||
- 'python ${vars.script_dir}/fine_tune_stats.py output/fine_tune_stats.csv' |