-
Notifications
You must be signed in to change notification settings - Fork 8
/
train_bert.sh
30 lines (26 loc) · 914 Bytes
/
train_bert.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#!/bin/bash
#
# train_bert.sh - train and evaluate the BERT model for ComQA.
#
# Usage: run from the repository root, i.e. the directory that contains the
# process/, train/, evaluation/, data/ and model/ sub-directories:
#
#   ./train_bert.sh
#
# Steps:
#   1. (disabled) feature extraction with process/process_bert.py
#   2. training with train/bert.py through torch.distributed.launch,
#      4 processes per node, 10 epochs, starting from model/bert.base.th
#   3. evaluation with evaluation/evaluate_bert.py on the test features
#
# The script aborts at the first failing step so that evaluation never runs
# against a missing or stale checkpoint.

set -euo pipefail

# Print an error message to stderr and abort the script.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# All paths are anchored at the directory the script is started from.
base_dir=$(pwd)
jobname='bert for ComQA'
train_features_path="${base_dir}/data/train.bert.obj"
dev_features_path="${base_dir}/data/dev.bert.obj"
test_features_path="${base_dir}/data/test.bert.obj"
model_save_path="${base_dir}/model/bert.comqa.base.th"

printf '%s\n' "$jobname"
echo "start processing data"
printf '%s\n' "$train_features_path"
printf '%s\n' "$dev_features_path"
printf '%s\n' "$test_features_path"

# Absolute, checked directory changes: a failed cd must not let the following
# steps run from an unexpected working directory.
cd "${base_dir}/process" || die "cannot cd to ${base_dir}/process"
# Feature extraction is currently disabled; uncomment to regenerate the
# *.bert.obj feature files before training.
#python3 process_bert.py --train="${train_features_path}" --dev="${dev_features_path}" --test="${test_features_path}"

cd "${base_dir}/train" || die "cannot cd to ${base_dir}/train"
echo "start training"
# Exported so that every process started by the launcher inherits it.
export OMP_NUM_THREADS=2
python3.6 -m torch.distributed.launch --nproc_per_node=4 bert.py \
  --train_file_path="${train_features_path}" \
  --dev_file_path="${dev_features_path}" \
  --model_save_path="${model_save_path}" \
  --epoch=10 \
  --pretrain_model="${base_dir}/model/bert.base.th"

cd "${base_dir}/evaluation" || die "cannot cd to ${base_dir}/evaluation"
echo "start evaluation"
python3 evaluate_bert.py "${test_features_path}" "${model_save_path}"