-
Notifications
You must be signed in to change notification settings - Fork 0
/
nq_eval.py
484 lines (381 loc) · 16.6 KB
/
nq_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Official evaluation script for Natural Questions.
https://ai.google.com/research/NaturalQuestions
------------------------------------------------------------------------------
Example usage:
nq_eval --gold_path=<path-to-gold-files> --predictions_path=<path_to_json>
This will compute both the official F1 scores as well as recall@precision
tables for both long and short answers. Note that R@P are only meaningful
if your model populates the score fields of the prediction JSON format.
gold_path should point to the five-way annotated dev data in the
original download format (gzipped jsonlines).
predictions_path should point to a json file containing the predictions in
the format given below.
------------------------------------------------------------------------------
Prediction format:
{'predictions': [
{
'example_id': -2226525965842375672,
'long_answer': {
'start_byte': 62657, 'end_byte': 64776,
'start_token': 391, 'end_token': 604
},
'long_answer_score': 13.5,
'short_answers': [
{'start_byte': 64206, 'end_byte': 64280,
'start_token': 555, 'end_token': 560}, ...],
'short_answers_score': 26.4,
'yes_no_answer': 'NONE'
}, ... ]
}
The prediction format mirrors the annotation format in defining each long or
short answer span both in terms of byte offsets and token offsets. We do not
expect participants to supply both.
The order of preference is:
if start_byte >= 0 and end_byte >= 0, use byte offsets,
else if start_token >= 0 and end_token >= 0, use token offsets,
else no span is defined (null answer).
The short answer metric takes both the short answer spans and the yes/no
answer into account. If the 'short_answers' list contains any non-null spans,
then 'yes_no_answer' should be set to 'NONE'.
-----------------------------------------------------------------------------
Metrics:
If >= 2 of the annotators marked a non-null long answer, then the prediction
must match any one of the non-null long answers to be considered correct.
If >= 2 of the annotators marked a non-null set of short answers, or a yes/no
answer, then the short answers prediction must match any one of the non-null
sets of short answers *or* the yes/no prediction must match one of the
non-null yes/no answer labels.
All span comparisons are exact and each individual prediction can be fully
correct, or incorrect.
Each prediction should be provided with a long answer score, and a short
answers score. At evaluation time, the evaluation script will find a score
threshold at which F1 is maximized. All predictions with scores below this
threshold are ignored (assumed to be null). If the score is not provided,
the evaluation script considers all predictions to be valid. The script
will also output the maximum recall at precision points of >= 0.5, >= 0.75,
and >= 0.9.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import json
import os
import pickle
from absl import app
from absl import flags
from absl import logging
import eval_utils as util
import six
# Command-line flags for the evaluation script.
flags.DEFINE_string(
    "gold_path",
    None,
    "Path to the gzip JSON data. For "
    "multiple files, should be a glob "
    'pattern (e.g. "/path/to/files-*")',
)
flags.DEFINE_string("predictions_path", None, "Path to prediction JSON.")
flags.DEFINE_bool(
    "cache_gold_data",
    False,
    "Whether to cache gold data in Pickle format to speed up "
    "multiple evaluations.",
)
flags.DEFINE_integer("num_threads", 10, "Number of threads for reading.")
flags.DEFINE_bool("pretty_print", False, "Whether to pretty print output.")
FLAGS = flags.FLAGS
def safe_divide(x, y):
    """Returns x / y, falling back to 0 when the denominator y is zero."""
    return 0 if y == 0 else x / y
def score_long_answer(gold_label_list, pred_label):
    """Scores a long answer as correct or not.

    1) First decide if there is a gold long answer with LONG_NO_NULL_THRESHOLD
       (delegated to util.gold_has_long_answer).
    2) The prediction gets a match if:
       a. there is a gold long answer, and
       b. the predicted span exactly matches *one* of the non-null gold long
          answer spans.

    Args:
      gold_label_list: A list of NQLabel, could be None.
      pred_label: A single NQLabel, could be None. A missing (None/falsy)
        prediction is treated as a null answer with score -inf, so it ranks
        below every real prediction when thresholds are computed.

    Returns:
      gold_has_answer, pred_has_answer, is_correct, score
    """
    gold_has_answer = util.gold_has_long_answer(gold_label_list)
    if not pred_label:
        # No prediction: it cannot have an answer, cannot be correct, and has
        # no score of its own to read.
        return gold_has_answer, False, False, float("-inf")

    pred_has_answer = not pred_label.long_answer_span.is_null_span()
    is_correct = False
    # Both sides are non-null spans.
    if gold_has_answer and pred_has_answer:
        # Even when the vote says there is a long answer, an individual
        # annotator may still have marked no long answer; skip those.
        is_correct = any(
            not gold_label.long_answer_span.is_null_span()
            and util.nonnull_span_equal(
                gold_label.long_answer_span, pred_label.long_answer_span
            )
            for gold_label in gold_label_list
        )
    return gold_has_answer, pred_has_answer, is_correct, pred_label.long_score
def score_short_answer(gold_label_list, pred_label):
    """Scores a short answer as correct or not.

    1) First decide if there is a gold short answer with
       SHORT_NO_NULL_THRESHOLD (delegated to util.gold_has_short_answer).
    2) The prediction gets a match if:
       a. there is a gold short answer, and
       b. the predicted span *set* exactly matches *one* of the non-null gold
          short answer span *sets*, or the predicted yes/no answer equals one
          annotator's yes/no label.

    Args:
      gold_label_list: A list of NQLabel.
      pred_label: A single NQLabel, could be None. A missing (None/falsy)
        prediction is treated as a null answer with score -inf, so it ranks
        below every real prediction when thresholds are computed.

    Returns:
      gold_has_answer, pred_has_answer, is_correct, score
    """
    # There is a gold short answer if gold_label_list is not empty and the
    # number of non-null answers (summed over annotators) is over the threshold.
    gold_has_answer = util.gold_has_short_answer(gold_label_list)
    if not pred_label:
        # No prediction: it cannot have an answer, cannot be correct, and has
        # no score of its own to read.
        return gold_has_answer, False, False, float("-inf")

    # There is a predicted short answer if the short answer span set is not
    # empty, or a yes/no answer was given.
    pred_has_answer = (
        not util.is_null_span_list(pred_label.short_answer_span_list)
    ) or pred_label.yes_no_answer != "none"
    is_correct = False
    # Both sides have short answers (which include yes/no answers).
    if gold_has_answer and pred_has_answer:
        if pred_label.yes_no_answer != "none":
            # The system thinks this is a yes/no question.
            is_correct = any(
                pred_label.yes_no_answer == gold_label.yes_no_answer
                for gold_label in gold_label_list
            )
        else:
            is_correct = any(
                util.span_set_equal(
                    gold_label.short_answer_span_list,
                    pred_label.short_answer_span_list,
                )
                for gold_label in gold_label_list
            )
    return gold_has_answer, pred_has_answer, is_correct, pred_label.short_score
def score_answers(gold_annotation_dict, pred_dict):
    """Scores every example and ranks the results by prediction score.

    Args:
      gold_annotation_dict: Mapping from example id to a list of NQLabels.
      pred_dict: Mapping from example id to the prediction NQLabel passed to
        score_long_answer / score_short_answer.

    Returns:
      A pair (long_answer_stats, short_answer_stats). Each is a list of
      (gold_has_answer, pred_has_answer, is_correct, score) tuples, sorted by
      score from highest to lowest.

    Raises:
      ValueError: if the two mappings do not cover exactly the same ids.
    """
    gold_ids = set(gold_annotation_dict.keys())
    if gold_ids != set(pred_dict.keys()):
        raise ValueError(
            "ERROR: the example ids in gold annotations and example "
            "ids in the prediction are not equal."
        )

    long_answer_stats, short_answer_stats = [], []
    for example_id in gold_ids:
        gold_labels = gold_annotation_dict[example_id]
        prediction = pred_dict[example_id]
        long_answer_stats.append(score_long_answer(gold_labels, prediction))
        short_answer_stats.append(score_short_answer(gold_labels, prediction))

    def _by_score(stat):
        # The score is the last field of every stats tuple.
        return stat[-1]

    long_answer_stats.sort(key=_by_score, reverse=True)
    short_answer_stats.sort(key=_by_score, reverse=True)
    return long_answer_stats, short_answer_stats
def compute_f1(answer_stats, prefix=""):
    """Computes F1, precision, recall for a list of answer scores.

    The score field is ignored: every prediction in `answer_stats` counts.

    Args:
      answer_stats: Iterable of per-example (has_gold, has_pred, is_correct,
        score) tuples. May be empty, in which case every metric is 0.
      prefix (''): Prefix to prepend to score dictionary keys.

    Returns:
      OrderedDict mapping prefix + "n", "f1", "precision", "recall" to values.
    """
    # Materialize once so generators work and len() is available; an empty
    # input used to crash when unpacking zip(*[]) into four names.
    answer_stats = list(answer_stats)
    num_gold = sum(stat[0] for stat in answer_stats)
    num_pred = sum(stat[1] for stat in answer_stats)
    num_correct = sum(stat[2] for stat in answer_stats)

    precision = safe_divide(num_correct, num_pred)
    recall = safe_divide(num_correct, num_gold)
    f1 = safe_divide(2 * precision * recall, precision + recall)

    scores = OrderedDict()
    scores[prefix + "n"] = len(answer_stats)
    scores[prefix + "f1"] = f1
    scores[prefix + "precision"] = precision
    scores[prefix + "recall"] = recall
    return scores
def compute_final_f1(long_answer_stats, short_answer_stats):
    """Combines score-agnostic F1/precision/recall for long and short answers.

    Assumes the stats are already thresholded; the score field is not used.

    Arguments:
      long_answer_stats: List of long answer score tuples.
      short_answer_stats: List of short answer score tuples.

    Returns:
      Dictionary of metric name (string) -> value, with keys prefixed by
      "long-answer-" and "short-answer-".
    """
    combined = compute_f1(long_answer_stats, prefix="long-answer-")
    short_scores = compute_f1(short_answer_stats, prefix="short-answer-")
    combined.update(short_scores)
    return combined
def compute_pr_curves(answer_stats, targets=None):
    """Computes the PR curve and returns R@P for specific precision targets.

    Each distinct score in `answer_stats` is a candidate threshold. Walking the
    stats from highest to lowest score, cumulative precision and recall are
    computed at every threshold. For each target, the point with maximum
    recall whose precision is >= target is kept. The threshold with the
    highest F1 is tracked as well.

    Arguments:
      answer_stats: List of (has_gold, has_pred, is_correct, score) tuples,
        which must be sorted by score in descending order (score_answers
        returns them that way).
      targets (None): List of precision thresholds to target. None means no
        targets, which yields an empty R@P table.

    Returns:
      A pair (best, table):
        best: (best_f1, best_precision, best_recall, best_threshold).
        table: list of (target, recall, precision, score_threshold) rows, one
          per target. If no threshold reached the target precision with
          non-zero recall, the row keeps recall 0, precision 0 and a
          score_threshold of None.
    """
    # Both inputs are iterated more than once, so materialize them.
    targets = [] if targets is None else list(targets)
    answer_stats = list(answer_stats)

    total_correct = 0
    total_has_pred = 0
    # Recall is measured against all gold answers, so count them up front.
    total_has_gold = sum(has_gold for has_gold, _, _, _ in answer_stats)

    # Keep track of the point of maximum recall for each target.
    max_recall = [0 for _ in targets]
    max_precision = [0 for _ in targets]
    max_scores = [None for _ in targets]

    # Cumulative (precision, recall) for each unique threshold.
    scores_to_stats = OrderedDict()
    for _, has_pred, is_correct, score in answer_stats:
        total_correct += is_correct
        total_has_pred += has_pred
        precision = safe_divide(total_correct, total_has_pred)
        recall = safe_divide(total_correct, total_has_gold)
        # If there are ties, this entry is overwritten until every tied
        # example has been counted.
        scores_to_stats[score] = [precision, recall]

    best_f1 = 0.0
    best_precision = 0.0
    best_recall = 0.0
    best_threshold = 0.0
    for threshold, (precision, recall) in six.iteritems(scores_to_stats):
        # Record the highest-recall point that meets each precision target.
        for t, target in enumerate(targets):
            if precision >= target and recall > max_recall[t]:
                max_recall[t] = recall
                max_precision[t] = precision
                max_scores[t] = threshold
        # Track the F1-optimal threshold.
        f1 = safe_divide(2 * precision * recall, precision + recall)
        if f1 > best_f1:
            best_f1 = f1
            best_precision = precision
            best_recall = recall
            best_threshold = threshold

    return (
        (best_f1, best_precision, best_recall, best_threshold),
        list(zip(targets, max_recall, max_precision, max_scores)),
    )
def print_r_at_p_table(answer_stats):
    """Pretty prints the optimal-threshold F1 and the R@P table.

    Uses the default precision targets 0.5, 0.75 and 0.9.

    Args:
      answer_stats: List of (has_gold, has_pred, is_correct, score) tuples
        sorted by score in descending order (as returned by score_answers).
    """

    def _format_threshold(value, digits):
        # compute_pr_curves leaves a target's threshold as None when that
        # precision is never reached; "{:.4}".format(None) raises TypeError.
        if value is None:
            return "None"
        return "{:.{}}".format(value, digits)

    opt_result, pr_table = compute_pr_curves(answer_stats, targets=[0.5, 0.75, 0.9])
    best_f1, best_precision, best_recall, best_threshold = opt_result
    print("Optimal threshold: {}".format(_format_threshold(best_threshold, 5)))
    print(" F1 / P / R")
    print(
        "{: >7.2%} / {: >7.2%} / {: >7.2%}".format(
            best_f1, best_precision, best_recall
        )
    )
    for target, recall, precision, threshold in pr_table:
        print(
            "R@P={}: {:.2%} (actual p={:.2%}, score threshold={})".format(
                target, recall, precision, _format_threshold(threshold, 4)
            )
        )
def get_metrics_as_dict(gold_path, prediction_path, num_threads=10):
    """Runs the whole evaluation as a library call and returns its metrics.

    Arguments:
      gold_path: Path to the gzip JSON data. For multiple files, should be a
        glob pattern (e.g. "/path/to/files-*").
      prediction_path: Path to the JSON prediction data.
      num_threads (10): Number of threads to use when parsing multiple files.

    Returns:
      Dictionary mapping metric names (prefixed "long-" / "short-") to scores.
    """
    gold_annotations = util.read_annotation(gold_path, n_threads=num_threads)
    predictions = util.read_prediction_json(prediction_path)
    long_stats, short_stats = score_answers(gold_annotations, predictions)
    return get_metrics_with_answer_stats(long_stats, short_stats)
def get_metrics_with_answer_stats(long_answer_stats, short_answer_stats):
    """Builds one flat metrics dict from long and short answer statistics.

    For each answer type this includes the best-threshold F1/precision/recall,
    the best threshold itself, and recall/precision at precision targets
    0.5, 0.75 and 0.9. Keys are prefixed with "long-" or "short-".
    """

    def _get_metric_dict(answer_stats, prefix=""):
        """Returns prefixed threshold and R@P metrics for one answer type."""
        (best_f1, best_p, best_r, best_t), table = compute_pr_curves(
            answer_stats, targets=[0.5, 0.75, 0.9]
        )
        named = OrderedDict()
        named["best-threshold-f1"] = best_f1
        named["best-threshold-precision"] = best_p
        named["best-threshold-recall"] = best_r
        named["best-threshold"] = best_t
        for target, target_recall, target_precision, _ in table:
            label = "{:.2}".format(target)
            named["recall-at-precision>=" + label] = target_recall
            named["precision-at-precision>=" + label] = target_precision
        return {prefix + key: value for key, value in named.items()}

    metrics = _get_metric_dict(long_answer_stats, "long-")
    metrics.update(_get_metric_dict(short_answer_stats, "short-"))
    return metrics
def main(_):
    """Evaluates predictions against gold data as configured by the flags.

    Loads gold annotations, optionally through a pickle cache named "cache" in
    the gold files' directory. Then loads predictions, scores every example,
    and prints either a human-readable report (--pretty_print) or the metrics
    dict as JSON.
    """
    # NOTE(review): the cache location depends only on the gold directory, not
    # on the gold_path pattern, so a different glob in the same directory will
    # reuse a stale cache.
    cache_path = os.path.join(os.path.dirname(FLAGS.gold_path), "cache")
    if FLAGS.cache_gold_data and os.path.exists(cache_path):
        logging.info("Reading from cache: %s", cache_path)
        # Pickles are binary; text mode ("r") fails on Python 3. Only load a
        # cache this script wrote: unpickling untrusted files can run code.
        with open(cache_path, "rb") as cache_file:
            nq_gold_dict = pickle.load(cache_file)
    else:
        nq_gold_dict = util.read_annotation(
            FLAGS.gold_path, n_threads=FLAGS.num_threads
        )
        if FLAGS.cache_gold_data:
            logging.info("Caching gold data for next time to: %s", cache_path)
            with open(cache_path, "wb") as cache_file:
                pickle.dump(nq_gold_dict, cache_file)

    nq_pred_dict = util.read_prediction_json(FLAGS.predictions_path)
    long_answer_stats, short_answer_stats = score_answers(nq_gold_dict, nq_pred_dict)

    if FLAGS.pretty_print:
        print("*" * 20)
        print("LONG ANSWER R@P TABLE:")
        print_r_at_p_table(long_answer_stats)
        print("*" * 20)
        print("SHORT ANSWER R@P TABLE:")
        print_r_at_p_table(short_answer_stats)
        scores = compute_final_f1(long_answer_stats, short_answer_stats)
        print("*" * 20)
        print("METRICS IGNORING SCORES (n={}):".format(scores["long-answer-n"]))
        print(" F1 / P / R")
        print(
            "Long answer {: >7.2%} / {: >7.2%} / {: >7.2%}".format(
                scores["long-answer-f1"],
                scores["long-answer-precision"],
                scores["long-answer-recall"],
            )
        )
        print(
            "Short answer {: >7.2%} / {: >7.2%} / {: >7.2%}".format(
                scores["short-answer-f1"],
                scores["short-answer-precision"],
                scores["short-answer-recall"],
            )
        )
    else:
        metrics = get_metrics_with_answer_stats(long_answer_stats, short_answer_stats)
        print(json.dumps(metrics))
if __name__ == "__main__":
    # Both input paths must be supplied on the command line.
    for required_flag in ("gold_path", "predictions_path"):
        flags.mark_flag_as_required(required_flag)
    app.run(main)