Commit

adding option to print eval failures
chanind committed May 9, 2022
1 parent d08877e commit d23bf6f
Showing 1 changed file with 6 additions and 1 deletion.
frame_semantic_transformer/evaluate.py: 6 additions & 1 deletion
@@ -36,6 +36,7 @@ def evaluate(
     tokenizer: T5Tokenizer,
     samples: Sequence[TaskSample],
     batch_size: int = 10,
+    print_failures: bool = False,
 ) -> dict[str, list[int]]:
     results: dict[str, list[int]] = defaultdict(lambda: [0, 0, 0])
     for samples_chunk in tqdm(
@@ -45,8 +46,12 @@

         predictions = batch_predict(model, tokenizer, inputs)
         for sample, prediction in zip(samples_chunk, predictions):
-            true_pos, false_pos, false_neg = sample.evaluate_prediction(prediction)
+            score = sample.evaluate_prediction(prediction)
+            true_pos, false_pos, false_neg = score
             results[sample.get_task_name()][0] += true_pos
             results[sample.get_task_name()][1] += false_pos
             results[sample.get_task_name()][2] += false_neg
+            if print_failures and (false_neg > 0 or false_pos > 0):
+                print(score, sample.get_target(), prediction)

     return results
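
A hypothetical usage sketch of the new flag (not part of the commit): it assumes a fine-tuned T5 model, its tokenizer, and a sequence of TaskSample objects are already available under the placeholder names shown; only evaluate and its new print_failures argument come from frame_semantic_transformer/evaluate.py.

    from frame_semantic_transformer.evaluate import evaluate

    # Assumed inputs, stand-ins for objects created elsewhere:
    model = ...        # a fine-tuned T5 model compatible with batch_predict (assumed)
    tokenizer = ...    # the matching T5Tokenizer (assumed)
    val_samples = ...  # Sequence[TaskSample] built from validation data (assumed)

    # With print_failures=True, any sample with false positives or false negatives
    # is printed as (score, target, prediction) while evaluation runs.
    results = evaluate(
        model,
        tokenizer,
        val_samples,
        batch_size=10,
        print_failures=True,
    )

    # results maps each task name to [true_pos, false_pos, false_neg] counts.
    for task_name, (tp, fp, fn) in results.items():
        print(f"{task_name}: tp={tp} fp={fp} fn={fn}")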
