-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_qa.sh
82 lines (75 loc) · 2.42 KB
/
predict_qa.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/bin/bash
# Copyright 2020 Google and DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Script to obtain predictions using a trained model on XQuAD, MLQA, TyDi QA,
# and the chaii Hindi/Tamil sets.
#
# Usage:
#   predict_qa.sh [MODEL] [MODEL_TYPE] MODEL_PATH [TGT] [GPU] [DATA_DIR] \
#                 [PREDICTIONS_DIR] [PREDICT_FILE_NAME]
#
#   MODEL             label printed in the log banner only
#                     (default: bert-base-multilingual-cased)
#   MODEL_TYPE        passed to run_squad.py --model_type (default: bert)
#   MODEL_PATH        trained model directory; required, must exist
#   TGT               xquad | mlqa | tydiqa | chaii_hi | chaii_ta
#                     (default: xquad)
#   GPU               value for CUDA_VISIBLE_DEVICES (default: 0)
#   DATA_DIR          dataset root (default: $PWD/download/)
#   PREDICTIONS_DIR   prediction output root (default: $PWD/predictions)
#   PREDICT_FILE_NAME chaii_* only: file under DATA_DIR to predict on
#                     (default: dev.<lang>.qa.jsonl)
#
# Exit status: 1 if MODEL_PATH is missing, TGT is unknown, the output
# directory cannot be created, or any per-language prediction run failed;
# 0 otherwise.
set -u

REPO=$PWD
MODEL=${1:-bert-base-multilingual-cased}
MODEL_TYPE=${2:-bert}
MODEL_PATH=${3:-}
TGT=${4:-xquad}
GPU=${5:-0}
DATA_DIR=${6:-"$REPO/download/"}
PREDICTIONS_DIR=${7:-"$REPO/predictions"}
PREDICT_FILE_NAME=${8:-}

if [[ ! -d "${MODEL_PATH}" ]]; then
  echo "Model path ${MODEL_PATH} does not exist." >&2
  exit 1
fi

# Resolve the per-target data directory, output directory and language list.
# chaii data and predictions live directly under DATA_DIR / PREDICTIONS_DIR.
DIR=${DATA_DIR}/${TGT}/
PRED_DIR=${PREDICTIONS_DIR}/${TGT}/
case "${TGT}" in
  xquad) langs=(en es de el ru tr ar vi th zh hi) ;;
  mlqa) langs=(en es de ar hi vi zh) ;;
  tydiqa) langs=(en ar bn fi id ko ru sw te) ;;
  chaii_hi)
    DIR=${DATA_DIR}
    PRED_DIR=${PREDICTIONS_DIR}
    langs=(hi)
    ;;
  chaii_ta)
    DIR=${DATA_DIR}
    PRED_DIR=${PREDICTIONS_DIR}
    langs=(ta)
    ;;
  *)
    echo "Unknown target '${TGT}'; expected xquad, mlqa, tydiqa, chaii_hi or chaii_ta." >&2
    exit 1
    ;;
esac

# Create the output directory only after the target-specific override above,
# so chaii runs create the directory they actually write to.
if ! mkdir -p "${PRED_DIR}"; then
  echo "Cannot create prediction directory ${PRED_DIR}." >&2
  exit 1
fi

echo "************************"
echo "${MODEL}"
echo "************************"
echo
echo "Predictions on ${TGT}"

failed=0
for lang in "${langs[@]}"; do
  echo " ${lang} "
  case "${TGT}" in
    xquad) TEST_FILE=${DIR}/xquad.${lang}.json ;;
    mlqa) TEST_FILE=${DIR}/MLQA_V1/test/test-context-${lang}-question-${lang}.json ;;
    tydiqa) TEST_FILE=${DIR}/tydiqa-goldp-v1.1-dev/tydiqa.${lang}.dev.json ;;
    chaii_hi | chaii_ta)
      # chaii has a single language, so the default file name is per-language.
      TEST_FILE=${DIR}/${PREDICT_FILE_NAME:-"dev.${lang}.qa.jsonl"}
      ;;
  esac

  # run_squad.py output is discarded to keep this log to one line per
  # language, but a non-zero exit status is still reported (and propagated
  # via the final exit status) so failures are not silent.
  # NOTE(review): --do_lower_case is passed even for cased models such as the
  # default bert-base-multilingual-cased — confirm this is intended.
  if ! CUDA_VISIBLE_DEVICES=${GPU} python third_party/run_squad.py \
    --model_type "${MODEL_TYPE}" \
    --model_name_or_path "${MODEL_PATH}" \
    --do_eval \
    --do_lower_case \
    --eval_lang "${lang}" \
    --predict_file "${TEST_FILE}" \
    --output_dir "${PRED_DIR}" &> /dev/null; then
    echo "Prediction failed for language ${lang} (test file: ${TEST_FILE})." >&2
    failed=1
  fi
done

exit "${failed}"