-
Notifications
You must be signed in to change notification settings - Fork 16
/
run.sh
36 lines (35 loc) · 1.16 KB
/
run.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# model training
#
# Runs joint_training.py on GPU 0 only (CUDA_VISIBLE_DEVICES=0). It starts
# from the checkpoint models/reddit_generator.pkl (--init_checkpoint) and the
# model file models/reddit_retriever.pkl (--model_file). Output goes to
# outputs/joint_reddit (--output_dir).
#
# Abort on the first failing command (and on any unset variable). Without
# this, a failed training run is masked by the exit status of whatever
# command runs last in this script.
set -eu

# Every continuation line must end in " \" (space, then backslash).
# Backslash-newline is deleted outright, so a missing space would glue this
# argument onto the next one (e.g. "configs--init_checkpoint").
CUDA_VISIBLE_DEVICES=0 python joint_training.py \
  --model_name_or_path configs \
  --init_checkpoint models/reddit_generator.pkl \
  --train_input_file data/reddit_train.db \
  --eval_input_file data/reddit_test.txt \
  --output_dir outputs/joint_reddit \
  --file_suffix joint_reddit \
  --train_batch_size 1 \
  --gradient_accumulation_steps 1 \
  --eval_batch_size 1 \
  --num_optim_steps 16000 \
  --encoder_model_type ance_roberta \
  --pretrained_model_cfg bert-base-uncased \
  --model_file models/reddit_retriever.pkl \
  --ctx_file data/wiki.txt \
  --num_shards 1 \
  --batch_size 128 \
  --n_docs 2 \
  --encoding \
  --load_trained_model
# evaluating checkpoint (ance_roberta encoder)
#
# Runs eval_checkpoint.py on GPU 0 only (CUDA_VISIBLE_DEVICES=0) in "rank"
# mode. It loads the model file models/reddit_retriever.pkl with encoder type
# ance_roberta. It evaluates --qa_file data/2k_positive.txt against
# --ctx_file data/10k.txt, with --n_docs 50 and batch size 64, as shard 0 of 1.
#
# NOTE(review): this loads models/reddit_retriever.pkl, not anything written
# to the training step's --output_dir (outputs/joint_reddit). Confirm whether
# the jointly trained checkpoint was meant to be evaluated here instead.
# NOTE(review): --pretrained_model_cfg is bert-base-uncased while
# --encoder_model_type is ance_roberta. Confirm this pairing is intended.
CUDA_VISIBLE_DEVICES=0 python eval_checkpoint.py \
  --eval_mode rank \
  --encoder_model_type ance_roberta \
  --pretrained_model_cfg bert-base-uncased \
  --model_file models/reddit_retriever.pkl \
  --qa_file data/2k_positive.txt \
  --ctx_file data/10k.txt \
  --n_docs 50 \
  --batch_size 64 \
  --shard_id 0 \
  --num_shards 1 \
  --load_trained_model \
  --encoding