Guided Stream of Search

This repository contains the code for the paper Guided Stream of Search: Learning to Better Search with Language Models via Optimal Path Guidance.

Computational resources

  • Training: 4 x NVIDIA A100 80GB
  • Inference: 1 x NVIDIA RTX 3090 24GB

Prerequisites

The base directory is set to /home/{user}/guided-stream-of-search. All data, checkpoints, and other files are stored under this base directory. Update this path as needed before running the scripts.
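As an illustration, the expected base directory can be written with `{user}` replaced by your login name. This is a hedged sketch: the `$USER` substitution is an assumption, and the repository's scripts may hard-code the path, in which case you should edit them directly.

```shell
# Sketch: the base directory the scripts expect, with {user} replaced
# by your login name. Update the paths inside the scripts if yours differs.
BASE_DIR="/home/${USER}/guided-stream-of-search"
echo "All data and checkpoints will be stored under: $BASE_DIR"
```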

Environment settings

conda env create --name countdown --file environment.yaml
conda activate countdown
cd stream-of-search
pip install -r requirements.txt
cd ..
cd tril
pip install -e .
pip install flash-attn --no-build-isolation

Note

Please do not modify the package versions. Any changes may cause numerical instability as discussed in this article.

Data generation

The datasets are saved in /home/{user}/guided-stream-of-search/stream-of-search/data.

conda activate countdown
cd stream-of-search
sh script/task/gen_task.sh  # Training
sh script/task/gen_task_final.sh  # Evaluation

Unsupervised pre-training

The checkpoint is saved in /home/{user}/guided-stream-of-search/stream-of-search/output.

conda activate countdown
cd stream-of-search
sh script/gpt2/train_sos.sh

Supervised fine-tuning with self-generated data

The data and checkpoints are saved in /home/{user}/guided-stream-of-search/stream-of-search/output.

conda activate countdown
cd stream-of-search

# Iteration 1
sh script/gpt2/iter1/gen_star_s0.sh
sh script/gpt2/iter1/gen_gsos_rand_s0.sh --start 0
...
sh script/gpt2/iter1/gen_gsos_rand_s0.sh --start 199000
sh script/gpt2/iter1/train_gsos_rand_s0.sh

# Iteration 2
sh script/gpt2/iter2/gen_gsos_rand_s0.sh --start 0
...
sh script/gpt2/iter2/gen_gsos_rand_s0.sh --start 199000
sh script/gpt2/iter2/train_gsos_rand_s0.sh

# Iteration 3
sh script/gpt2/iter3/gen_gsos_rand_s0.sh --start 0
...
sh script/gpt2/iter3/gen_gsos_rand_s0.sh --start 199000
sh script/gpt2/iter3/train_gsos_rand_s0.sh

Note

The data generation process requires a large number of GPUs. We recommend using more than 40 NVIDIA RTX 3090 GPUs and running the scripts in parallel.
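The elided per-shard commands above can be driven with a small loop. A minimal sketch, assuming the `--start` offsets advance in steps of 1000 from 0 to 199000 (the actual step size is not shown above); it runs shards sequentially, whereas in practice you would split the range across GPUs and run chunks in parallel as noted.

```shell
# Sketch only: loop over all generation shards for iteration 1, then train.
# ASSUMPTION: --start advances by 1000 (the step size is elided above).
for start in $(seq 0 1000 199000); do
  sh script/gpt2/iter1/gen_gsos_rand_s0.sh --start "$start"
done
sh script/gpt2/iter1/train_gsos_rand_s0.sh
```

The same pattern applies to iterations 2 and 3 by swapping `iter1` for `iter2` or `iter3`.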

RL fine-tuning

The checkpoint is saved in /home/{user}/guided-stream-of-search/tril/output.

conda activate countdown
cd tril
sh examples/countdown/countdown_ppo_op.sh

Evaluation

The results are saved in the checkpoint directory passed via --ckpt.

Unsupervised pre-training & supervised fine-tuning

conda activate countdown
cd stream-of-search
python eval.py --ckpt {ckpt} --start 0
...
python eval.py --ckpt {ckpt} --start 9000
cd ..
python summary.py --ckpt {ckpt}
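The shard-by-shard evaluation above can likewise be scripted. A hedged sketch, assuming `--start` advances by 1000 from 0 to 9000 (the step size is elided above); the checkpoint path is a placeholder for your own `{ckpt}`.

```shell
# Sketch: evaluate every shard, then aggregate with summary.py.
# ASSUMPTION: --start steps by 1000; CKPT below is a placeholder path.
CKPT="/home/${USER}/guided-stream-of-search/stream-of-search/output/ckpt"

for start in $(seq 0 1000 9000); do
  python eval.py --ckpt "$CKPT" --start "$start"
done
cd ..
python summary.py --ckpt "$CKPT"
```

The same loop works for the RL fine-tuning evaluation, run from the tril directory instead of stream-of-search.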

RL fine-tuning

conda activate countdown
cd tril
python eval.py --ckpt {ckpt} --start 0
...
python eval.py --ckpt {ckpt} --start 9000
cd ..
python summary.py --ckpt {ckpt}

Checkpoints

Checkpoints used in the paper are available at the following links:

Acknowledgements

This repository is built on the following repositories, with some modifications. The specific changes are detailed in the README file of each directory.