Code for the paper Improving Weak-to-Strong Generalization with Reliability-Aware Alignment
Please create the environment from the environment.yml file:
conda env create -f environment.yml
Then activate the conda environment to run the code.
We ran experiments on four datasets: HellaSwag, MMLU, ETHICS (commonsense subset), and GSM8K. To reproduce the results for a dataset, enter the corresponding dataset directory.
In addition, to run the code, please download the pre-trained models from Hugging Face into the models directory and name them as follows (one way to do this is sketched after the tree):
models
└───Llama-2-7b-chat-hf
│ files of the model and tokenizer
└───Llama-2-13b-chat-hf
│ files of the model and tokenizer
└───Llama-3-8B
│ files of the model and tokenizer
└───mistralai_Mistral-7B-Instruct-v0.2
│ files of the model and tokenizer
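For example, the models directory can be populated with huggingface_hub. This is only a sketch, assuming the package is installed and you have access to the (possibly gated) repositories; the repo IDs below are the public Hugging Face names and are assumptions, so substitute the checkpoints you actually use.

```python
# Sketch: download checkpoints into models/ with the folder names expected above.
# Repo IDs are assumptions -- replace them with the checkpoints you actually use.
from huggingface_hub import snapshot_download

checkpoints = {
    "meta-llama/Llama-2-7b-chat-hf": "models/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-13b-chat-hf": "models/Llama-2-13b-chat-hf",
    "meta-llama/Meta-Llama-3-8B": "models/Llama-3-8B",
    "mistralai/Mistral-7B-Instruct-v0.2": "models/mistralai_Mistral-7B-Instruct-v0.2",
}
for repo_id, local_dir in checkpoints.items():
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
```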
The first step in weak-to-strong generalization is to train a weak supervisor on the training data. To create the weak supervisor, go to the corresponding dataset directory and first create two folders for the trained models and results:
mkdir models results
Then run
python create_weak_supervisor_on_{dataset_name}.py --weak_model_name {weak_model_name}
where {weak_model_name} should be chosen from ["Mistral-7B", "LLAMA2-7B"]. A conceptual sketch of this fine-tuning step follows.
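Conceptually, this step is supervised fine-tuning of the weak model on the ground-truth training data. The following is only a minimal sketch of that idea, not the repository's implementation; the model path, training pair, and hyperparameters are placeholders.

```python
# Minimal SFT sketch: fine-tune a causal LM on (question, answer) pairs.
# Illustrative only -- the dataset scripts handle prompting, splits, and
# hyperparameters per dataset.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "models/Llama-2-7b-chat-hf"  # placeholder weak model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

pairs = [("Question: Is 7 a prime number? Answer:", " Yes")]  # placeholder data

model.train()
for question, answer in pairs:
    enc = tokenizer(question + answer, return_tensors="pt").to(model.device)
    # Standard causal-LM loss over the whole sequence; the real scripts may
    # mask the prompt tokens so only the answer contributes to the loss.
    loss = model(**enc, labels=enc["input_ids"]).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```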
In the naive weak-to-strong generalization setting, we prompt the weak supervisor for weak labels and fine-tune (SFT) the strong model on these naive weak labels. To reproduce the results, go to the corresponding dataset directory and run
python baseline_SFT_weak_supervision.py --weak_model_name {weak_model_name} --strong_model_name {strong_model_name}
where {weak_model_name} should be chosen from ["Mistral-7B", "LLAMA2-7B"] and {strong_model_name} from ["Mistral-7B", "LLAMA2-13B", "LLAMA3-8B"]. A sketch of how weak labels can be generated is given below.
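For intuition, generating a weak label amounts to letting the weak supervisor answer a held-out question; the strong model is then fine-tuned on the resulting (question, weak answer) pairs with the same kind of SFT loop sketched above. This is a hedged illustration, not the script's exact code; the path and prompt format are placeholders.

```python
# Sketch: query the trained weak supervisor for a "weak label" on one question.
from transformers import AutoModelForCausalLM, AutoTokenizer

weak_path = "models/Llama-2-7b-chat-hf"  # placeholder path to the trained weak supervisor
tokenizer = AutoTokenizer.from_pretrained(weak_path)
weak_model = AutoModelForCausalLM.from_pretrained(weak_path).cuda()

def weak_label(question: str) -> str:
    inputs = tokenizer(question, return_tensors="pt").to(weak_model.device)
    output = weak_model.generate(**inputs, max_new_tokens=32, do_sample=False)
    # Decode only the newly generated tokens (the answer), not the prompt.
    return tokenizer.decode(output[0, inputs["input_ids"].shape[1]:],
                            skip_special_tokens=True)

print(weak_label("Question: Is the Earth larger than the Moon? Answer:"))
```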
The ceiling performance is obtained by fine-tuning the strong model on the ground-truth labels of the validation data. To run the code, use
python baseline_SFT_on_groundtruth.py --strong_model_name {strong_model_name}
where {strong_model_name} should be chosen from ["Mistral-7B", "LLAMA2-13B", "LLAMA3-8B"].
In our first method (uncertainty filtering), we prompt the weak supervisor with multiple variations of each question and collect the multiple answers it returns. We then measure the uncertainty of these answers with an entropy score; high-entropy answers are filtered out, and only the low-entropy answers are kept for SFT of the strong model (see the sketch after the commands). Run
python our_method_uncertainty_filtering.py --weak_model_name {weak_model_name} --strong_model_name {strong_model_name}
or, for a sampled dataset,
python our_method_uncertainty_filtering.py --weak_model_name {weak_model_name} --strong_model_name {strong_model_name} --do_sampling
where {weak_model_name} should be chosen from ["Mistral-7B", "LLAMA2-7B"] and {strong_model_name} from ["Mistral-7B", "LLAMA2-13B", "LLAMA3-8B"].
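The core of the filtering step can be illustrated with a small sketch: sample several answers per question from the weak supervisor, score the empirical answer distribution by its entropy, and keep only the low-entropy questions. The threshold and example answers below are placeholders; the script computes the actual scores.

```python
# Sketch of entropy-based filtering over multiple weak answers per question.
import math
from collections import Counter

def answer_entropy(answers):
    """Shannon entropy of the empirical distribution over sampled answers."""
    counts = Counter(answers)
    total = sum(counts.values())
    return -sum((c / total) * math.log(c / total) for c in counts.values())

weak_answers = {
    "q1": ["A", "A", "A", "A"],  # consistent answers -> low entropy -> keep
    "q2": ["A", "B", "C", "D"],  # inconsistent answers -> high entropy -> drop
}

threshold = 0.5  # placeholder cutoff
kept = {q: ans for q, ans in weak_answers.items() if answer_entropy(ans) <= threshold}
print(kept)  # only q1 survives the filter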
In our second method (loss reweighting), we reuse the multiple answers from the previous prompting step and reweight the loss of each (question, weak answer) sample by the reliability score of its weak answer (see the sketch after the commands). Note that you must run method 1 before running the script for method 2, because method 2 reuses the file {dataset}validation_with_entropy_from{weak_model_name}.csv generated by method 1. To reproduce the results of method 2, run
python our_method_reweighting.py --weak_model_name {weak_model_name} --strong_model_name {strong_model_name}
or, for a sampled dataset,
python our_method_reweighting.py --weak_model_name {weak_model_name} --strong_model_name {strong_model_name} --do_sampling
where {weak_model_name} should be chosen from ["Mistral-7B", "LLAMA2-7B"] and {strong_model_name} from ["Mistral-7B", "LLAMA2-13B", "LLAMA3-8B"].
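The reweighting idea can likewise be sketched as a per-example weighted loss: each (question, weak answer) sample's cross-entropy is scaled by its reliability score before averaging. The values below are placeholders standing in for the paper's reliability scores, and the sketch is not the script's exact implementation.

```python
# Sketch: reliability-weighted cross-entropy over (question, weak answer) samples.
import torch
import torch.nn.functional as F

def reweighted_loss(logits, labels, reliability):
    """Per-example cross-entropy scaled by the reliability of each weak label.

    logits: (batch, num_classes); labels: (batch,); reliability: (batch,)
    """
    per_example = F.cross_entropy(logits, labels, reduction="none")
    return (reliability * per_example).mean()

logits = torch.randn(4, 3)                          # dummy model outputs
labels = torch.tensor([0, 1, 2, 0])                 # dummy weak labels
reliability = torch.tensor([0.9, 0.2, 0.7, 1.0])    # placeholder reliability scores
print(reweighted_loss(logits, labels, reliability))
```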
To reproduce Figures 3 and 5 from the paper, please use the Visualization_on_entropy_uncertainty.ipynb notebook.
To reproduce Figures 4 and 6 from the paper, please use the Visualization_on_probability_reliability_score.ipynb notebook.