Commit 836eb63 · 0 parents · Showing 17 changed files with 2,729 additions and 0 deletions.
File: .gitignore

```
.DS_Store
.idea/
.env
data/
text_completion_data/
```
File: README.md
# Stable Sequential Unlearning

This is the code and repo for Stable Sequential Unlearning (SSU).

## Installation
Install the packages listed in [requirements.txt](requirements.txt). We strongly recommend using a Conda environment. You will also need a .env file to store your HF_ACCESS_TOKEN, which is needed to download the Llama models.

## Dataset Setup

To begin, obtain the `.txt` versions of the books. You can either purchase them or download them from public sources such as [Project Gutenberg](https://gutenberg.org/), or you can crawl and preprocess the books by following [gutenberg](https://github.com/pgcorpus/gutenberg).

### Directory Structure
1. **Create a main directory called {PATH_TO_DATA}**: This will store the `.txt` files of the books.
2. **Organize files by time steps**: Inside the main directory, create subdirectories for each time step, named `time_step_{x}` (e.g., `time_step_1`, `time_step_2`, etc.). Place the respective `.txt` files in these subdirectories; an example layout is sketched below.
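
For concreteness, a layout with two time steps might look like the following (the book filenames are illustrative placeholders):

```
{PATH_TO_DATA}/
├── time_step_1/
│   ├── book_a.txt
│   └── book_b.txt
└── time_step_2/
    └── book_c.txt
```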

### Commands to Run

Once the directory structure is set up, execute the following commands in sequence:

```bash
python create_training_data_csv.py --input_dir {PATH_TO_DATA}
python create_training_data_bf.py
python generate_json_train_test.py
python generate_json.py
python combine_previous_json.py
```

## Fine-tuning
To fine-tune the model, use the [fine_tune_books.py](fine_tune_books.py) script; an example invocation is sketched after the parameter list below.

The script [fine_tune_books.py](fine_tune_books.py) requires the following parameters:

* --model_dir: Specifies the directory where your models are saved locally. This is essential for loading pre-trained or fine-tuned models.
* --data_dir: Specifies the directory containing all the books you have saved. This directory will be used to load the data for the unlearning process.
* --book_corpus_norm_data: Specifies the directory of the book corpus that does not contain the books you want to unlearn. This is used for GA-based methods.
* --Lora: A flag to use LoRA (Low-Rank Adaptation) parameterization. If set, the script will use LoRA for fine-tuning. Default is False.
* --time_step_num: An optional parameter to specify the time step number. If not specified, the script will fine-tune on all the books in the data directory.
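
A hypothetical invocation using the flags above (all paths are placeholders, and flag syntax assumes store_true-style flags):

```bash
python fine_tune_books.py \
    --model_dir /path/to/models \
    --data_dir /path/to/books \
    --book_corpus_norm_data /path/to/norm_corpus \
    --Lora \
    --time_step_num 1
```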

You also need to adjust the fine-tuning config file [fine_tune_config.py](config/fine_tune_config.py):
* base_model_name: Specifies the base model name. This is the base model you want to unlearn.
* use_quantization: True if you want to use quantization, False otherwise. (Please set this to False, because TV subtraction will be inaccurate under quantization.)
* time_step_num: The time step number you want to fine-tune.
* random_loss_epsilon: Specifies the epsilon value for the random labeling loss.
* num_std_dev: 0 if you just want to use the mean, or the number of standard deviations away from the mean you want to use for the saliency-based weight update.
* batch_size: The batch size for fine-tuning.
* lr: The learning rate for fine-tuning.
* max_unlearn_steps: The number of epochs for each unlearning step.
* gradient_accumulation_steps: The number of gradient accumulation steps.
* ga_factual_epsilon: The epsilon value for the factual loss term in GA Difference.
* npo_beta: The beta value for NPO.
* random_loss_pairs: The number of mismatch pairs for SSU.
* intervention: The intervention method for fine-tuning.

Available options are "unlearning_ga_none" (GA), "unlearning_npo_none" (NPO), "unlearning_ga_mismatch" (Gradient Difference), "unlearning_tv_none" (Pure TV), and "unlearning_tv_ssu" (SSU). Note that when running NPO, you should first obtain an oracle model; in that case, use "unlearning_npo_ref" as the intervention method.

In addition, if you are interested in running ablation studies, you can use the following interventions: "unlearning_tv_ssu_no_weight_saliency" and "unlearning_tv_ssu_no_random_loss".

Lastly, in order to fine-tune a model on $D_f$, use the intervention "unlearning_gd_none".
## Evaluation

Please install [CoTaEval](https://github.com/boyiwei/CoTaEval/tree/main) to download the MMLU dataset, set up the Bloom filters for MemFree Decoding, and set up the MT-Bench running environment. For MT-Bench, you will need to slightly modify [gen_model_answer.py](https://github.com/boyiwei/CoTaEval/blob/main/eval/FastChat_new/fastchat/llm_judge/gen_model_answer.py) to use the unlearned model.

After setting up the evaluation environment, you can run the [evaluate_unlearn.py](evaluate_unlearn.py) script to evaluate the unlearning performance; an example invocation follows the parameter list below.

The script [evaluate_unlearn.py](evaluate_unlearn.py) requires the following parameters:
* --base_dir: The path to this package's root directory.
* --model_dir: Specifies the directory where your models are saved locally. This is essential for loading pre-trained or fine-tuned models.
* --time_step_num: The time step number you want to evaluate.
* --single_book: A flag to evaluate a single book at each time step. (If set to False, please update the file_path variable in the code.)
* --use_all: A flag to use the entire unlearning dataset at each time step. (You can choose to unlearn only part of it by splitting it into 'training' and 'testing' sets to save computational resources.)
* --eval_mode: True if we only evaluate $D_{nor}$.
* --eval_mmlu_only: True if we only evaluate MMLU.
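
A hypothetical invocation using the flags above (paths and values are placeholders, and flag syntax assumes store_true-style flags):

```bash
python evaluate_unlearn.py \
    --base_dir /path/to/this/repo \
    --model_dir /path/to/models \
    --time_step_num 1 \
    --single_book
```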

You also need to adjust the evaluation config file [metrics_config.py](config/metrics_config.py):
* model_name: Specifies the base model name. This is the base model you want to unlearn.
* is_instruct_model: True if model_name is an instruct model, False otherwise.
* use_quantization: True if you want to use quantization, False otherwise. (Please set it to False.)
* train_or_test: Select the training or testing set. (Ignore this if you have a custom file_path.)
* model_dir: Specifies the directory where your models are saved locally. This is essential for loading pre-trained or fine-tuned models.
* datatype: Specifies the datatype for evaluation. Default is 'gutenberg_books'.
* num_test: If use_all is set to False, you need to specify the number of testing examples.
* eval_general: True if you want to evaluate MMLU after measuring performance on the unlearned books. Default is False.
* n: The choice of n-grams for MemFree Decoding.
* no_context: True if you want to provide no context for MemFree Decoding in the system prompt.
* no_overwrite: True if you do not want to overwrite the existing results.
* acs_threshold: The threshold for ACS.
* intervention: The intervention method for evaluation.

Specifically, the available intervention methods are: "sys_prompt-sys_none", "sys_prompt-sys_a", "sys_prompt-dbrx", "unlearning_ga_none", "unlearning_npo_none", "unlearning_ga_idk", "unlearning_ga_mismatch", "unlearning_tv_none", "unlearning_tv_ssu", and "mem_free_tokenized_consecutive".

If you want to run ablation studies, you can set intervention to "unlearning_tv_ssu_no_weight_saliency" or "unlearning_tv_ssu_no_random_loss".
File: combine_previous_json.py (filename inferred from the README's command list)
```python
import os
import json


def combine_previous_test_jsons(output_dir, current_time_step_num):
    combined_test_qa_pairs = []

    # Iterate over the previous time steps to gather all test QA pairs
    for time_step_num in range(1, current_time_step_num):
        previous_json_file_path = os.path.join(output_dir, f'time_step_{time_step_num}',
                                               f'time_step_{time_step_num}_test_dataset_unlearn.json')

        # Check if the file exists before attempting to open it
        if os.path.exists(previous_json_file_path):
            with open(previous_json_file_path, 'r', encoding='utf-8') as json_file:
                test_qa_pairs = json.load(json_file)
                combined_test_qa_pairs.extend(test_qa_pairs)
        else:
            print(f"Warning: {previous_json_file_path} does not exist and will be skipped.")

    # Define the output JSON file path for the combined test data
    combined_json_file_path = os.path.join(output_dir, f'time_step_{current_time_step_num}',
                                           f'time_step_{current_time_step_num}_combined_previous_tests.json')

    # Ensure the output directory for this time step exists before writing
    os.makedirs(os.path.dirname(combined_json_file_path), exist_ok=True)

    # Write the combined list of test QA pairs to a JSON file
    with open(combined_json_file_path, 'w', encoding='utf-8') as combined_json_file:
        json.dump(combined_test_qa_pairs, combined_json_file, ensure_ascii=False, indent=4)

    print(f"Combined test JSON file for time_step_{current_time_step_num} has been saved to {combined_json_file_path}")


# Combine tests for every time step after the first (time steps 2 through 10)
for time_step_num in range(2, 11):
    output_dir = 'data_csv_single'
    combine_previous_test_jsons(output_dir, time_step_num)
```
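
For reference, each per-time-step test JSON appears to be a flat list of QA pairs; the shape below is inferred from the dataset-generation script later in this commit, so treat it as an assumption rather than a documented schema:

```json
[
    {
        "question": "First ~200 tokens of a sampled book passage ...",
        "answer": "The next ~150 tokens continuing that passage ..."
    }
]
```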
File: config/fine_tune_config.py (filename inferred from the README)
```python
import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

ssu_lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    inference_mode=False,
    lora_dropout=0.01,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM"
)

# Define LoraConfig (currently identical to ssu_lora_config)
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    inference_mode=False,
    lora_dropout=0.01,
    bias="none",
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM"
)

# Add any other configurations here
config = {
    "base_model_name": "mistralai/Mistral-7B-Instruct-v0.3",
    'use_quantization': False,
    'intervention': 'unlearning_ga_none',
    'time_step_num': 3,
    "random_loss_epsilon": 0.5,
    'num_std_dev': 1,
    'batch_size': 2,
    'lr': 1e-5,
    'max_unlearn_epochs': 1,
    "gradient_accumulation_steps": 2,
    ####################### GA-based methods only ###########################
    'ga_factual_epsilon': 0.5,
    ####################### NPO only ###########################
    'npo_beta': 0.4,
    ####################### SSU only ###########################
    "random_loss_pairs": 3,
}

initial_intervention = [
    "unlearning_gd_none",
]

available_intervention = [
    "sys_prompt-sys_none",
    "sys_prompt-sys_a",
    "unlearning_ga_none",
    "unlearning_npo_none",
    "unlearning_ga_idk",
    "unlearning_ga_mismatch",
    "unlearning_tv_none",
    "unlearning_tv_ssu",
    "mem_free_tokenized_consecutive",
]

oracle_model_for_npo = [
    "unlearning_npo_ref",
]

ablation_intervention = [
    "unlearning_tv_ssu_no_weight_saliency",
    "unlearning_tv_ssu_no_random_loss",
]
```
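
As a rough illustration of how a training script might consume this config, here is a minimal sketch; the import path and the use of `get_peft_model` are assumptions, not code from this commit:

```python
# Minimal sketch (assumed usage; not part of this commit).
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model

from config.fine_tune_config import config, ssu_lora_config, bnb_config  # assumed import path

model = AutoModelForCausalLM.from_pretrained(
    config["base_model_name"],
    # Only pass the quantization config when requested; the README advises
    # keeping use_quantization False so task-vector subtraction stays exact.
    quantization_config=bnb_config if config["use_quantization"] else None,
)
tokenizer = AutoTokenizer.from_pretrained(config["base_model_name"])

# Wrap the base model with the LoRA adapters defined above (e.g., when --Lora is set).
model = get_peft_model(model, ssu_lora_config)
model.print_trainable_parameters()
```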
File: config/metrics_config.py (filename inferred from the README)
```python
import torch
from transformers import BitsAndBytesConfig

# bnb_config = BitsAndBytesConfig(
#     load_in_8bit=True,
#     bnb_8bit_use_double_quant=True,
#     bnb_8bit_quant_type="nf4",
#     bnb_8bit_compute_dtype=torch.float16,
# )

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    # bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

memFree_Prompt_Config = {
    # "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "model_name": "mistralai/Mistral-7B-Instruct-v0.3",
    "is_instruct_model": True,  # Whether the model is an instruction-following or a chat model
    "use_quantization": False,  # Don't use quantization
    "train_or_test": "train",
    "model_dir": "/cache",  # The directory where the model is saved
    "datatype": "gutenberg_books",
    "n": 6,  # The choice of n-grams for MemFree decoding
    "num_tests": 200,  # The number of QA pairs to run for evaluation
    "intervention": "unlearning_ga_none",
    'eval_general': False,

    # "intervention": "sys_prompt-sys_a",
    "acs_threshold": 50,  # for the non-consecutive case
    "no_context": False,
    "no_overwrite": False,
}

available_intervention = [
    "sys_prompt-sys_none",
    "sys_prompt-sys_a",
    "sys_prompt-dbrx",
    "unlearning_ga_none",
    "unlearning_npo_none",
    "unlearning_ga_idk",
    "unlearning_ga_mismatch",
    "unlearning_tv_none",
    "unlearning_tv_ssu",
    "mem_free_tokenized_consecutive",
]

ablation_intervention = [
    "unlearning_tv_ssu_no_weight_saliency",
    "unlearning_tv_ssu_no_random_loss",
]
```
File: task-vector application config (filename not captured in this diff)
```python
import torch
from peft import LoraConfig
from transformers import BitsAndBytesConfig

# bnb_config = BitsAndBytesConfig(
#     load_in_16bit=True,
#     bnb_16bit_quant_type="nf16",
#     bnb_16bit_compute_dtype=torch.float16,
#     bnb_16bit_use_double_quant=True,
# )

bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    bnb_8bit_use_double_quant=True,
    bnb_8bit_quant_type="nf4",
    bnb_8bit_compute_dtype=torch.float16,
)

config = {
    'use_fine_tuned_model': False,
    'base_model_name': 'meta-llama/Meta-Llama-3-8B',
    ################################### only if use_fine_tuned_model is True ###################################
    'fine_tuned_model_name': 'llama3-8b-harry-potter',
    'fine_tuned_filename': 'llama_3_8b_hp_checkpoint_base_200.pth',
    ############################################################################################################
    'tv_ft_filename': 'llama3_tv_random_loss_weight_saliency.pth',
    'save_file_name': 'llama3_tv_random_loss_weight_saliency_saved.pth',
    'show_sample_output': False
}
```
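
For context, the `tv_ft_filename` / `save_file_name` entries suggest this config drives the task-vector negation step used by Pure TV and SSU. The sketch below is an assumption based on the standard task-vector formulation, $\theta_{unlearned} = \theta_{base} - (\theta_{ft} - \theta_{base})$; it is not code from this commit, and the import path is hypothetical:

```python
# Hypothetical sketch of task-vector negation (assumed, not from this commit).
import torch
from transformers import AutoModelForCausalLM

from config.tv_config import config  # hypothetical import path for the dict above

base = AutoModelForCausalLM.from_pretrained(config['base_model_name'])
base_state = base.state_dict()

# Checkpoint fine-tuned on the books to forget (D_f); assumes the .pth file
# stores a plain state_dict with the same keys as the base model.
ft_state = torch.load(config['tv_ft_filename'], map_location='cpu')

# theta_unlearned = theta_base - (theta_ft - theta_base)
unlearned_state = {
    name: base_state[name] - (ft_state[name] - base_state[name])
    for name in base_state
}

torch.save(unlearned_state, config['save_file_name'])
```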
File: text-completion dataset generator (filename not captured in this diff)
```python
import os
import json
import random
import argparse

# Requires the NLTK 'punkt' tokenizer data: run nltk.download('punkt') once.
from nltk.tokenize import sent_tokenize, word_tokenize


def generate_dataset(book_path, output_dir, book_name, num_samples=50):
    with open(book_path, 'r', encoding='utf-8') as file:
        book_text = file.read()

    sentences = sent_tokenize(book_text)
    tokenized_sentences = [word_tokenize(sentence) for sentence in sentences]

    dataset = []
    for _ in range(num_samples):
        # Pick a random starting sentence, leaving room for an answer continuation.
        start_index = random.randint(0, len(tokenized_sentences) - 2)
        question_tokens = []
        answer_tokens = []

        # Build a ~200-token question from consecutive sentences.
        while len(question_tokens) < 200 and start_index < len(tokenized_sentences) - 1:
            question_tokens.extend(tokenized_sentences[start_index])
            start_index += 1

        if len(question_tokens) >= 200:
            question_tokens = question_tokens[:200]

        # Build a ~150-token answer from the sentences that follow the question.
        answer_end_index = start_index
        while len(answer_tokens) < 150 and answer_end_index < len(tokenized_sentences):
            answer_tokens.extend(tokenized_sentences[answer_end_index])
            answer_end_index += 1

        answer_tokens = answer_tokens[:150]

        question = ' '.join(question_tokens)
        answer = ' '.join(answer_tokens)

        # Replace the mojibake "â" symbol with a single quotation mark
        question = question.replace('â', "'")
        answer = answer.replace('â', "'")

        dataset.append({
            'question': question,
            'answer': answer
        })

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    output_file = os.path.join(output_dir, f'{book_name}.json')
    with open(output_file, 'w', encoding='utf-8') as file:
        json.dump(dataset, file, indent=4, ensure_ascii=False)

    print(f"Dataset generated and saved to {output_file}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process text completion data.')
    parser.add_argument('--book_path', type=str, help='the path to the book')
    parser.add_argument('--output_dir', type=str, help='the path to the output directory')
    parser.add_argument("--book_name", type=str, help="the name of the book")
    # default=50 so omitting the flag does not override the function default with None
    parser.add_argument("--num_samples", type=int, default=50, help="number of samples being generated")
    args = parser.parse_args()
    generate_dataset(args.book_path, args.output_dir, args.book_name, args.num_samples)
```
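
A hypothetical usage example; this file's name is not visible in the capture, so `create_text_completion_data.py` is a placeholder, and the paths echo the directories in .gitignore:

```bash
# Script filename is a placeholder; substitute the actual name from the repo.
python create_text_completion_data.py \
    --book_path data/time_step_1/book_a.txt \
    --output_dir text_completion_data \
    --book_name book_a \
    --num_samples 50
```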