Skip to content

Commit

Permalink
Add formatted output for error-counts
Browse files: browse the repository at this point in the history
  • Loading branch information
ConstantineLignos committed Aug 9, 2023
1 parent 35748fe commit 1aad6f1
Showing 1 changed file with 51 additions and 8 deletions.
59 changes: 51 additions & 8 deletions seqscore/conll.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
from collections import defaultdict
from itertools import chain
from statistics import mean, stdev
from typing import Any, DefaultDict, Iterable, List, Optional, Sequence, TextIO, Tuple
from typing import (
Any,
Counter,
DefaultDict,
Iterable,
List,
Optional,
Sequence,
TextIO,
Tuple,
)

from attr import attrib, attrs
from tabulate import tabulate
Expand Down Expand Up @@ -466,13 +476,46 @@ def score_conll_files(
all_acc_scores.append(class_scores)

if error_counts:
headers = ("False positives", "False negatives")
counters = (class_scores.false_pos_examples, class_scores.false_neg_examples)
for error_header, counter in zip(headers, counters):
print(error_header)
for item, count in counter.items():
print(" ".join(item.tokens), item.type, count, sep="\t")
print()
if multi_files:
raise ValueError(
"Outputting error counts is only available for a single prediction file"
)

if output_format == FORMAT_CONLL:
raise ValueError(
f"Format {repr(output_format)} is not supported with error counts"
)
elif output_format in (FORMAT_PRETTY, FORMAT_DELIM):
header = ["Count", "Error", "Type", "Tokens"]

# Combine counts across the two counters
combined_counts: Counter[Tuple[str, str, str]] = Counter()
for counter, error_type in zip(
(class_scores.false_pos_examples, class_scores.false_neg_examples),
("FP", "FN"),
):
for item, count in counter.items():
combined_counts[
(error_type, item.type, " ".join(item.tokens))
] = count

rows = [
[count, error_type, mention_type, token_str]
for (
error_type,
mention_type,
token_str,
), count in combined_counts.most_common()
]

if output_format == FORMAT_PRETTY:
print(tabulate(rows, header, tablefmt="github"))
else:
# Delimited output
score_summaries.append(delim.join(header))
score_summaries.extend(_join_delim(row, delim) for row in rows)
print("\n".join(score_summaries))

# Exit early since all the following logic is for printing scores
return

Expand Down

0 comments on commit 1aad6f1

Please sign in to comment.