# pylint: disable=C0103
from __future__ import division
import argparse
import collections
import json

# A named entity is represented by its surface text, its type and its start
# offset (in characters) within the source text.
Entity = collections.namedtuple('Entity', ['text', 'type', 'start'])


def has_overlap(x, y):
    """
    Determines whether the text of two entities overlaps. This function is symmetric.

    Returns
    -------
    bool
        True iff the text overlaps.
    """
    end_x = x.start + len(x.text)
    end_y = y.start + len(y.text)
    return x.start < end_y and y.start < end_x
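
# Illustrative doctest-style example (not part of the original module; the
# entity values are made up): two entities overlap when their character spans
# intersect, regardless of type.
#
#     >>> has_overlap(Entity('New York', 'LOC', 10), Entity('York', 'LOC', 14))
#     True
#     >>> has_overlap(Entity('New York', 'LOC', 10), Entity('City', 'LOC', 19))
#     False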


def correct_text(x, y):
    """
    Check whether the entity boundaries are correct, regardless of entity type.
    """
    return x.text == y.text and x.start == y.start


def correct_type(x, y):
    """
    Check whether the entity types match and the text of the two entities overlaps.
    """
    return x.type == y.type and has_overlap(x, y)
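
# Illustrative example (not part of the original module; the entity values are
# made up): the two checks are independent axes. A prediction with the exact
# boundaries but the wrong label passes correct_text only; an overlapping
# prediction with the right label passes correct_type only.
#
#     >>> correct_text(Entity('New York', 'LOC', 0), Entity('New York', 'ORG', 0))
#     True
#     >>> correct_type(Entity('New York', 'LOC', 0), Entity('York', 'LOC', 4))
#     True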


def count_correct(true, pred):
    """
    Computes the count of correctly predicted entities on two axes: type and text.

    Parameters
    ----------
    true: list of Entity
        The list of ground-truth entities.
    pred: list of Entity
        The list of predicted entities.

    Returns
    -------
    count_text: int
        The number of entities whose text matches exactly.
    count_type: int
        The number of entities whose type is correctly predicted and whose text overlaps.
    """
    count_text, count_type = 0, 0

    for x in true:
        for y in pred:
            text_match = correct_text(x, y)
            type_match = correct_type(x, y)

            if text_match:
                count_text += 1
            if type_match:
                count_type += 1

            if type_match or text_match:
                # Stop as soon as the entity has been recognized by the system.
                break

    return count_text, count_type
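
# Illustrative example (not part of the original module; the entity values are
# made up): a predicted entity that overlaps the ground truth and carries the
# right type, but whose boundaries are off, scores on the type axis only.
#
#     >>> count_correct([Entity('New York', 'LOC', 0)], [Entity('York', 'LOC', 4)])
#     (0, 1)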


def precision(correct, actual):
    if actual == 0:
        return 0

    return correct / actual


def recall(correct, possible):
    if possible == 0:
        return 0

    return correct / possible


def f1(p, r):
    if p + r == 0:
        return 0

    return 2 * (p * r) / (p + r)
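
# Worked example (illustrative): with 2 correct counts over 2 predicted slots
# and 4 possible slots, precision(2, 2) == 1.0, recall(2, 4) == 0.5 and
# f1(1.0, 0.5) == 2 * 0.5 / 1.5 == 0.6666..., matching the doctest of
# evaluate() below.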


def evaluate(y_true, y_pred):
    """
    Evaluate classification results for a whole dataset. Each row corresponds to one text in the
    dataset.

    Parameters
    ----------
    y_true: list of list
        For each text in the dataset, a list of ground-truth entities.
    y_pred: list of list
        For each text in the dataset, a list of predicted entities.

    Returns
    -------
    float:
        Micro-averaged F1 score of precision and recall.

    Example
    -------
    >>> from nereval import Entity, evaluate
    >>> y_true = [
    ...     [Entity('a', 'b', 0), Entity('b', 'b', 2)]
    ... ]
    >>> y_pred = [
    ...     [Entity('b', 'b', 2)]
    ... ]
    >>> evaluate(y_true, y_pred)
    0.6666666666666666
    """
    if len(y_true) != len(y_pred):
        raise ValueError('Bad input shape: y_true and y_pred should have the same length.')

    correct, actual, possible = 0, 0, 0

    for x, y in zip(y_true, y_pred):
        correct += sum(count_correct(x, y))
        # multiply by two to account for both type and text
        possible += len(x) * 2
        actual += len(y) * 2

    return f1(precision(correct, actual), recall(correct, possible))


def sign_test(truth, model_a, model_b):
    """
    Count, per text, how often model_b scores better and how often it scores worse than model_a.
    """
    better = 0
    worse = 0

    for true, a, b in zip(truth, model_a, model_b):
        score_a = evaluate([true], [a])
        score_b = evaluate([true], [b])

        if score_a - score_b > 0:
            worse += 1
        elif score_a - score_b < 0:
            better += 1

    return better, worse
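
# Illustrative usage (not part of the original module; the variable names and
# contents are hypothetical). Each argument is a list with one entity list per
# text, as accepted by evaluate().
#
#     >>> truth = [[Entity('New York', 'LOC', 0)]]
#     >>> model_a = [[]]
#     >>> model_b = [[Entity('New York', 'LOC', 0)]]
#     >>> sign_test(truth, model_a, model_b)
#     (1, 0)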


def _parse_json(file_name):
    data = None

    with open(file_name) as json_file:
        data = json.load(json_file)

    dict_to_entity = lambda e: Entity(e['text'], e['type'], e['start'])

    for instance in data:
        instance['true'] = [dict_to_entity(e) for e in instance['true']]
        instance['predicted'] = [dict_to_entity(e) for e in instance['predicted']]

    return data
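
# Expected input layout, inferred from the parser above (the values shown are
# illustrative): a JSON array of instances, each with a "true" and a
# "predicted" list of entities carrying "text", "type" and "start" keys.
#
#     [
#       {
#         "true": [{"text": "New York", "type": "LOC", "start": 0}],
#         "predicted": [{"text": "York", "type": "LOC", "start": 4}]
#       }
#     ]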


def evaluate_json(file_name):
    """
    Evaluate according to classification results stored in a JSON file.
    """
    y_true = []
    y_pred = []

    for instance in _parse_json(file_name):
        y_true.append(instance['true'])
        y_pred.append(instance['predicted'])

    return evaluate(y_true, y_pred)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Compute the F1 score for predictions in a JSON file.')
    parser.add_argument('file_name', help='The JSON file containing classification results')
    args = parser.parse_args()
    print('F1-score: %.2f' % evaluate_json(args.file_name))
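
# Example invocation (the input path is hypothetical; the printed score depends
# on the file's contents):
#
#     $ python nereval.py results.json
#     F1-score: 0.67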