From 7402b4cf90cc7e3dd9d8d229ac94aa83a356cf61 Mon Sep 17 00:00:00 2001 From: Mingzhi Zheng Date: Tue, 29 Oct 2019 10:32:38 -0700 Subject: [PATCH 1/2] fix issues in new quac-kd runner (cont.) --- examples/run_quac_kd.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/run_quac_kd.py b/examples/run_quac_kd.py index e3483369d2cfb6..11ec4152eaee5f 100644 --- a/examples/run_quac_kd.py +++ b/examples/run_quac_kd.py @@ -361,6 +361,7 @@ def predict(args, model, tokenizer, prefix=""): result = result_lookup[feature.unique_id] feature.start_targets = _compute_softmax(result.kd_start_logits) feature.end_targets = _compute_softmax(result.kd_end_logits) + updated_features.append(feature) torch.save(updated_features, updated_features_file) @@ -471,6 +472,8 @@ def main(): help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") + parser.add_argument("--do_predict", action='store_true', + help="Whether to run predict on the dev set.") parser.add_argument("--evaluate_during_training", action='store_true', help="Rul evaluation during training at each logging step.") parser.add_argument("--do_lower_case", action='store_true', From 9c58687db4c58d36e34d831298506e813882b77d Mon Sep 17 00:00:00 2001 From: Mingzhi Zheng Date: Tue, 29 Oct 2019 11:03:08 -0700 Subject: [PATCH 2/2] fix issues in quac runner --- examples/utils_quac.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/utils_quac.py b/examples/utils_quac.py index 45c4cec28ea48d..aeb926267777df 100644 --- a/examples/utils_quac.py +++ b/examples/utils_quac.py @@ -948,9 +948,9 @@ def write_predictions_v2(all_examples, all_features, all_results, n_best_size, if null_score_threshold is not None: if nbest_json[0]["text"] == "CANNOTANSWER" and nbest_json[0]["probability"] > null_score_threshold: - all_predictions[qas_id] = "CANNOTANSWER" + all_predictions[example.qas_id] = "CANNOTANSWER" else: - all_predictions[qas_id] = nbest_json[0]["text"] if nbest_json[0]["text"] != "CANNOTANSWER" else nbest_json[1]["text"] + all_predictions[example.qas_id] = nbest_json[0]["text"] if nbest_json[0]["text"] != "CANNOTANSWER" else nbest_json[1]["text"] else: if not version_2_with_negative: all_predictions[example.qas_id] = nbest_json[0]["text"]