Guided Stream of Search

This is the code for the paper Guided Stream of Search: Learning to Better Search with Language Models via Optimal Path Guidance.

Computational resources

  • Training: 4 x NVIDIA A100 80GB
  • Inference: 1 x NVIDIA RTX 3090 24GB

Prerequisites

The base directory is set to /home/{user}/guided-stream-of-search. All data, checkpoints, and other files are stored under this base directory. Please update this path as needed before running the scripts.
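As an illustrative sketch only, the expected base directory can be captured in a variable; BASE_DIR is a hypothetical convenience name, not something the repository's scripts read (the paths are hard-coded in the scripts themselves):

```shell
# Illustrative only: BASE_DIR mirrors the default layout; the scripts do
# not read this variable -- edit their hard-coded paths to relocate it.
BASE_DIR="/home/${USER:-$(whoami)}/guided-stream-of-search"
echo "Base directory: $BASE_DIR"
```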

Environment settings

conda env create --name countdown --file environment.yaml
conda activate countdown
cd stream-of-search
pip install -r requirements.txt
cd ..
cd tril
pip install -e .
pip install flash-attn --no-build-isolation

Note

Please do not modify the package versions. Any changes may cause numerical instability as discussed in this article.

Data generation

The datasets are saved in /home/{user}/guided-stream-of-search/stream-of-search/data.

conda activate countdown
cd stream-of-search
sh script/task/gen_task.sh  # Training
sh script/task/gen_task_final.sh  # Evaluation

Unsupervised pre-training

The checkpoint is saved in /home/{user}/guided-stream-of-search/stream-of-search/output.

conda activate countdown
cd stream-of-search
sh script/gpt2/train_sos.sh

Supervised fine-tuning with self-generated data

The data and checkpoints are saved in /home/{user}/guided-stream-of-search/stream-of-search/output.

conda activate countdown
cd stream-of-search

# Iteration 1
sh script/gpt2/iter1/gen_star_s0.sh
sh script/gpt2/iter1/gen_gsos_rand_s0.sh --start 0
...
sh script/gpt2/iter1/gen_gsos_rand_s0.sh --start 199000
sh script/gpt2/iter1/train_gsos_rand_s0.sh

# Iteration 2
sh script/gpt2/iter2/gen_gsos_rand_s0.sh --start 0
...
sh script/gpt2/iter2/gen_gsos_rand_s0.sh --start 199000
sh script/gpt2/iter2/train_gsos_rand_s0.sh

# Iteration 3
sh script/gpt2/iter3/gen_gsos_rand_s0.sh --start 0
...
sh script/gpt2/iter3/gen_gsos_rand_s0.sh --start 199000
sh script/gpt2/iter3/train_gsos_rand_s0.sh

Note

The data generation process requires a large number of GPUs. We recommend using more than 40 NVIDIA RTX 3090 GPUs and running the scripts in parallel.
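The chunked commands above can be wrapped in a small launcher. This is a sketch only: it assumes the --start offset advances in steps of 1,000 (check the generation script for the actual chunk size), and STEP, NUM_GPUS, and the round-robin striping of chunks across GPUs are illustrative, not part of the repository:

```shell
# Sketch: stripe generation chunks across local GPUs, one worker per GPU.
# STEP is an assumed chunk size; verify it against the generation script.
STEP=1000
NUM_GPUS=8
LAST=199000
for gpu in $(seq 0 $((NUM_GPUS - 1))); do
  (
    # Each GPU handles every NUM_GPUS-th chunk, starting at its own offset.
    for start in $(seq $((gpu * STEP)) $((NUM_GPUS * STEP)) "$LAST"); do
      CUDA_VISIBLE_DEVICES="$gpu" sh script/gpt2/iter1/gen_gsos_rand_s0.sh --start "$start"
    done
  ) &
done
wait
```

For iterations 2 and 3, swap in the corresponding script path (script/gpt2/iter2/... or script/gpt2/iter3/...).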

RL fine-tuning

The checkpoint is saved in /home/{user}/guided-stream-of-search/tril/output.

conda activate countdown
cd tril
sh examples/countdown/countdown_ppo_op.sh

Evaluation

The results are saved in the checkpoint directory specified by --ckpt.

Unsupervised pre-training & supervised fine-tuning

conda activate countdown
cd stream-of-search
python eval.py --ckpt {ckpt} --start 0
...
python eval.py --ckpt {ckpt} --start 9000
cd ..
python summary.py --ckpt {ckpt}
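The elided evaluation commands can be run in a loop; this is a minimal sketch assuming --start advances in steps of 1,000 (check eval.py for the actual chunking), with a placeholder checkpoint path:

```shell
# Sketch: evaluate all chunks sequentially, then summarize.
# The 1,000-problem step size is an assumption; verify against eval.py.
CKPT=/path/to/checkpoint  # placeholder: substitute your checkpoint directory
for start in $(seq 0 1000 9000); do
  python eval.py --ckpt "$CKPT" --start "$start"
done
cd ..
python summary.py --ckpt "$CKPT"
```

The same loop applies to the RL fine-tuning evaluation below, run from the tril directory instead.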

RL fine-tuning

conda activate countdown
cd tril
python eval.py --ckpt {ckpt} --start 0
...
python eval.py --ckpt {ckpt} --start 9000
cd ..
python summary.py --ckpt {ckpt}

Checkpoints

Checkpoints used in the paper can be found at the following links:

Acknowledgements

This repository is built on the following repositories with some modifications. The specific changes we made are detailed in the README file of each directory.
