🎉 RefAug has been accepted to EMNLP 2024 (main conference)! We have updated the arXiv paper to the camera-ready version. Check it out!
This is the code repo for the paper Learn Beyond The Answer: Training Language Models with Reflection for Mathematical Reasoning. Our work introduces reflective augmentation, a novel technique that cultivates a deeper understanding of the training problems, so as to enhance performance not only in standard single-round QA settings but also in more complex scenarios that require reflective thinking. Please refer to our paper for more details!
Our models are tested on A100 nodes with CUDA 11.7 and Python 3.9. Please refer to `requirements.txt` for the Python environment we used.
- The original training data from GSM8k and MATH: `data/original/train.json`.
- The test data for standard math reasoning tasks: `data/original/test.json`. If you are training a RefAug model, please make a copy of this `test.json` into the `RefAug` directory. The original unprocessed data from MathInstruct are in `data/MathInstruct`.
- The training data for RefAug: `data/RefAug/train.json`. This is generated using the script `src/data/get_reflection_openai.py`. If you want to use an open-source model to generate the RefAug data, check `src/data/get_reflection_hf.py`.
- The test data for MathChat tasks: `data/original/test-multiturn-followup.json` and `data/original/test-multiturn-error-correct.json`.
- The test data for the math subset of MINT: `data/original/test-mint-original-prompt.json`. (A quick way to inspect these JSON files is sketched below.)
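The exact schema is best checked by opening the files directly; the snippet below is a minimal sketch for doing so, assuming each file is a JSON array of example dicts (an assumption on our part, not something guaranteed by this README).

```python
import json

# Illustrative sketch: peek at one of the data files listed above.
# Assumption: the file is a JSON array of example dicts; check the actual
# files for the exact field names.
with open("data/RefAug/train.json") as f:
    train_data = json.load(f)

print(f"{len(train_data)} training examples")
print("Fields in the first example:", list(train_data[0].keys()))
```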
- Training: please refer to `scripts/train.sh`.
- Inference: the training script already includes inference, and there is also a separate script, `scripts/inference.sh`.
- After inference, for RefAug models, remove the generated reflective section using `src/evaluate/remove_followup.py`. Another option is to add `Reflection:` as a termination string during model decoding. (See the sketch below for the idea behind this post-processing step.)
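The post-processing idea is simple: anything generated after the `Reflection:` marker is discarded before scoring. The snippet below is an illustrative sketch of that idea, not a copy of `src/evaluate/remove_followup.py`.

```python
# Illustrative sketch (not the repo's src/evaluate/remove_followup.py):
# drop everything from the "Reflection:" marker onward so that only the
# original answer section is scored.
def strip_reflection(generation: str, marker: str = "Reflection:") -> str:
    """Return the text before the reflective section, if the marker is present."""
    idx = generation.find(marker)
    return generation if idx == -1 else generation[:idx].rstrip()

raw = "The answer is 42.\nReflection: An alternative approach would be..."
print(strip_reflection(raw))  # -> "The answer is 42."
```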
- For evaluating standard math reasoning tasks, we largely follow the implementation of MathInstruct. Please check `src/evaluate/eval_mathinstruct.py`. Note that since MMLU and SAT are multiple-choice tasks, please first extract the predicted option using `src/evaluate/gpt_extract_answer.py` before calling `eval_mathinstruct.py`. (A minimal sketch of the answer-matching idea is given after this list.)
- For evaluating MathChat tasks, check `src/evaluate/eval_multiturn_gsm.py` for follow-up QA and `src/evaluate/eval_error_correction_gsm.py` for error correction. After inference, re-running the script with the same arguments will skip inference and directly show the results.
- For evaluating MINT, check `src/evaluate/mint/eval_mint.py`. After inference, re-running the script with the same arguments will skip inference and directly show the results.
- For error analysis on the GSM8k test set, check `src/evaluate/error_analysis.py`.
- For the contamination test on GSM8k and MATH, check `src/evaluate/check_overlap.py`.
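For intuition, here is a minimal sketch of the final-answer matching that standard math evaluation performs. It is illustrative only: the actual logic lives in `src/evaluate/eval_mathinstruct.py` (following MathInstruct) and handles many more answer formats, including the GPT-based option extraction mentioned above.

```python
import re

def extract_final_number(text):
    """Grab the last number that appears in the model output (illustrative)."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return numbers[-1] if numbers else None

def is_correct(prediction, reference):
    """Compare the predicted final number with the reference answer."""
    pred = extract_final_number(prediction)
    return pred is not None and float(pred) == float(reference)

print(is_correct("So the total is 18 apples. The answer is 18.", "18"))  # True
```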
If you find our data or code useful, please kindly cite our paper:
@article{zhang2024refaug,
title={Learn Beyond The Answer: Training Language Models with Reflection for Mathematical Reasoning},
author={Zhang, Zhihan and Liang, Zhenwen and Yu, Wenhao and Yu, Dian and Jia, Mengzhao and Yu, Dong and Jiang, Meng},
journal={ArXiv preprint},
volume={2406.12050},
year={2024}
}