import csv
import json
import pprint
import sys
import yaml
from copy import deepcopy
import argparse
# Instantiate the parser
parser = argparse.ArgumentParser(description='Evaluate a submitted mistake list (CSV) against the GSML')
# Create the pretty printer
pp = pprint.PrettyPrinter(indent=4)
"""
It is possible for a metric to present multiple annotations which map to one annotation in the GSML.
or vice verse.
For example, "Miami Heat" could be a 2-token error "Miami Heat".
or two 1-token errors "Miami" (wring city) "Heat" (wrong name).
We award correct recall when at least one submitted mistake matches a GSML mistake.
We award correct precision when a submitted mistake matches at least one GSML mistake.
Mistakes are said to match when their ranges of token ids overlap.
Once a submitted mistake has recalled a GSML mistake, the submitted mistake is consumed
- It cannot recall a subsequent GSML mistake
"""
''' Returns the category labels described in the paper '''
def all_categories():
    return ['NAME', 'NUMBER', 'WORD', 'CONTEXT', 'NOT_CHECKABLE', 'OTHER']
''' Returns an int or None '''
def csv_int(x):
    return int(x) if x else None
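# e.g. csv_int('7') -> 7, csv_int('') -> None (empty CSV cells mean "not given")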
''' Helper that checks that either DOC or SENT based tokenization is used throughout '''
def consistent_tokenization(tokenization_mode, current_line_mode):
    if tokenization_mode not in {None, current_line_mode}:
        raise Exception('You must consistently use document-based token ids, sentence-based token ids, or both throughout the file.')
    return current_line_mode
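# e.g. consistent_tokenization(None, 'DOC') -> 'DOC', while a later call of
# consistent_tokenization('DOC', 'SENT') raises, because the file mixed schemes.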
"""
Creates and returns a dictionary representation of the mistake list (GSML or Submission)
The dictiory is structured as:
- TEXT_ID, TEXT_DATA
- START_IDX, MISTAKE_DATA
The function returns a tuple where the first element is the dict, and the second is num_mistakes
"""
def create_mistake_dict(filename, categories, token_lookup):
    mistake_dict = {}
    tokens_used = {}
    matches = 0
    with open(filename, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='"')
        next(reader, None)
        num_mistakes = 0
        tokenization_mode = None
        for i, row in enumerate(reader):
            # Columns from the CSV
            text_id = row[0].replace('.txt', '')
            sentence_id = csv_int(row[1])
            annotation_id = csv_int(row[2])
            tokens = row[3]
            sent_start_idx = csv_int(row[4])
            sent_end_idx = csv_int(row[5])
            doc_start_idx = csv_int(row[6])
            doc_end_idx = csv_int(row[7])
            category = row[8]
            # Check the sanity of the token submissions
            sent_given = (sent_start_idx is not None and sent_end_idx is not None and sentence_id is not None)
            doc_given = (doc_start_idx is not None and doc_end_idx is not None)
            if sent_given and doc_given:
                tokenization_mode = consistent_tokenization(tokenization_mode, 'BOTH')
                # Check mapping from sent to doc tokenization matches our token_lookup
                assert doc_start_idx == token_lookup['sent_to_doc'][text_id][sentence_id][sent_start_idx]
                assert doc_end_idx == token_lookup['sent_to_doc'][text_id][sentence_id][sent_end_idx]
                # And doc to sent
                assert sentence_id == token_lookup['doc_to_sent'][text_id][doc_start_idx]['sentence_id']
                assert sent_start_idx == token_lookup['doc_to_sent'][text_id][doc_start_idx]['token_id']
                assert sent_end_idx == token_lookup['doc_to_sent'][text_id][doc_end_idx]['token_id']
            elif sent_given:
                tokenization_mode = consistent_tokenization(tokenization_mode, 'SENT')
                doc_start_idx = token_lookup['sent_to_doc'][text_id][sentence_id][sent_start_idx]
                doc_end_idx = token_lookup['sent_to_doc'][text_id][sentence_id][sent_end_idx]
            elif doc_given:
                tokenization_mode = consistent_tokenization(tokenization_mode, 'DOC')
                sentence_id = token_lookup['doc_to_sent'][text_id][doc_start_idx]['sentence_id']
                sent_start_idx = token_lookup['doc_to_sent'][text_id][doc_start_idx]['token_id']
                sent_end_idx = token_lookup['doc_to_sent'][text_id][doc_end_idx]['token_id']
            else:
                err_str = f'You must provide either document or sentence based token ids on {filename} row {i}'
                raise Exception(err_str)
            if category not in categories:
                continue
            # For detecting overlapping spans
            if text_id not in tokens_used:
                tokens_used[text_id] = set([])
            for x in range(doc_start_idx, doc_end_idx+1):
                if x in tokens_used[text_id]:
                    err_str = f'Token {x} already used, duplicate on {text_id}:{i}'
                    raise Exception(err_str)
                tokens_used[text_id].add(x)
            # The mistake data structure
            if text_id not in mistake_dict:
                mistake_dict[text_id] = {}
            mistake_dict[text_id][doc_start_idx] = {
                'set': set(range(doc_start_idx, doc_end_idx+1)),
                'category': category,
                'sent_start_idx': sent_start_idx,
                'sent_end_idx': sent_end_idx,
                'doc_start_idx': doc_start_idx,
                'doc_end_idx': doc_end_idx,
                'sentence_id': sentence_id,
                'annotation_id': annotation_id,
                'tokens': tokens,
            }
            num_mistakes += 1
    return mistake_dict, num_mistakes
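# Sketch of the returned structure (values shown are hypothetical):
# mistake_dict = {
#     'some_text_id': {
#         112: {'set': {112, 113}, 'category': 'NAME', 'sent_start_idx': 7,
#               'sent_end_idx': 8, 'doc_start_idx': 112, 'doc_end_idx': 113,
#               'sentence_id': 5, 'annotation_id': 1, 'tokens': 'Miami Heat'},
#     },
# }
# num_mistakes counts only the rows whose category was in `categories`.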
"""
Recall is when at least one submitted mistake overlaps the GSML mistake
- once a submitted mistake has been used for correct recall, it cannot be used again (it is consumed).
Precision is when a submitted mistake overlaps any GSML mistake.
"""
def match_mistake_dicts(gsml, submitted):
    per_category_matches = {k: {} for k in all_categories()}
    # Copy this because the algorithm consumes elements
    # - this can break the token level calcs if done in wrong order
    # - so copy for least surprise
    copy_submitted = deepcopy(submitted)
    for text_id, gsml_text_data in gsml.items():
        # mistake level - match each submission to at most one gold mistake
        used_submissions = set([])
        for doc_start_idx, gsml_error_data in gsml_text_data.items():
            doc_end_idx = gsml_error_data['doc_end_idx']
            category = gsml_error_data['category']
            assert category in per_category_matches
            pop_key = None
            if text_id in copy_submitted:
                for submitted_doc_start_idx, submitted_error_data in copy_submitted[text_id].items():
                    # TODO - this is pretty brute force, it loops needlessly, but it works so ...
                    # Check if the submitted error intersects the GSML error
                    if submitted_error_data['set'].intersection(gsml_error_data['set']):
                        # Only use a submission once, it cannot recall multiple gold mistakes
                        pop_key = submitted_doc_start_idx
                        break
            match = pop_key is not None
            if match:
                # Remove the submission so it will not be used again
                copy_submitted[text_id].pop(pop_key, None)
            per_category_matches[category][f'{text_id}_{doc_start_idx}'] = match
    return per_category_matches
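# Sketch of the returned structure (keys are '<text_id>_<doc_start_idx>', the value
# records whether that GSML mistake was recalled; ids here are hypothetical):
# {'NAME': {'some_text_id_112': True}, 'NUMBER': {}, ...}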
'''Returns the correct and incorrect recall totals'''
def get_recall(matches):
    correct = {k: 0 for k in all_categories()}
    incorrect = {k: 0 for k in all_categories()}
    for category, h in matches.items():
        for mistake_id_str, v in h.items():
            if v:
                correct[category] += 1
            else:
                incorrect[category] += 1
    return correct, incorrect
def get_document_tokens(token_lookup):
    document_tokens = {}
    for text_id, token_data in token_lookup['doc_to_sent'].items():
        document_tokens[text_id] = {}
        for doc_token_id in token_data.keys():
            document_tokens[text_id][doc_token_id] = {
                'gsml': False,
                'submitted': False
            }
    return document_tokens
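# Shape of the returned dict: {text_id: {doc_token_id: {'gsml': False, 'submitted': False}}}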
def match_tokens(data, document_tokens, mode):
    for text_id, text_data in data.items():
        for start_idx, error_data in text_data.items():
            for x in range(error_data['doc_start_idx'], error_data['doc_end_idx']+1):
                document_tokens[text_id][x][mode] = True
def get_token_level_result(gsml, submitted, token_lookup):
    document_tokens = get_document_tokens(token_lookup)
    match_tokens(gsml, document_tokens, 'gsml')
    match_tokens(submitted, document_tokens, 'submitted')
    recall = 0
    recall_denominator = 0
    precision_denominator = 0
    for text_id, data in document_tokens.items():
        for token_id, v in data.items():
            if v['gsml'] and v['submitted']:
                recall += 1
            if v['gsml']:
                recall_denominator += 1
            if v['submitted']:
                precision_denominator += 1
    return {
        'recall': recall,
        'recall_denominator': recall_denominator,
        'precision_denominator': precision_denominator,
    }
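# The caller divides these counts to get the token-level scores, e.g. (hypothetical
# numbers) 40 overlapping tokens out of 50 GSML tokens and 80 submitted tokens gives
# token recall 40/50 = 0.8 and token precision 40/80 = 0.5.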
def safe_divide(x, y):
    if y > 0:
        return x / y
    return None
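# e.g. safe_divide(3, 4) -> 0.75, safe_divide(3, 0) -> None (avoids ZeroDivisionError)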
"""
checks that the token text in the submssion matches that which is retrieved by DOCUMENT level IDs
"""
def check_token_ids(mistake_dict, text_dir):
    for text_id, text_errors in mistake_dict.items():
        with open(f'{text_dir}/{text_id}.txt', 'r') as fh:
            raw_text = fh.read()
            raw_tokens = raw_text.split()
            for doc_start_idx, h in text_errors.items():
                assert doc_start_idx == h['doc_start_idx']
                tokens = h['tokens'].split()
                for i, t in enumerate(tokens):
                    # Check the token reported matches the one in the raw text
                    # - Token IDs in submission start at 1 (because of WebAnno)
                    x = doc_start_idx + i - 1
                    # print(f'{text_id}:{x} => {raw_tokens[doc_start_idx+i-1]} == {t}')
                    assert raw_tokens[x] == t
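# Note the off-by-one above: document token ids are 1-based (WebAnno), while
# raw_tokens is a 0-based Python list, so doc token id N maps to raw_tokens[N-1].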
"""
Returns a dict containing sub-dicts of recall, precision and overlaps between a GSML and a submission
Takes as input dicts created with match_mistake_dicts(), plus a list of categories
Only the categories given will be checked.
"""
def calculate_recall_and_precision(gsml_filename, submitted_filename, token_lookup, text_dir, categories=[]):
    gsml, gsml_num_lines = create_mistake_dict(gsml_filename, categories, token_lookup)
    submitted, submitted_num_lines = create_mistake_dict(submitted_filename, categories, token_lookup)
    if text_dir is not None:
        print('\tChecking GSML for token match against raw texts:')
        check_token_ids(gsml, text_dir)
        print('\tChecking Submitted for token match against raw texts:')
        check_token_ids(submitted, text_dir)
    # Mistake level
    per_category_matches = match_mistake_dicts(gsml, submitted)
    correct_recall_h, incorrect_recall_h = get_recall(per_category_matches)
    correct_recall = sum(correct_recall_h.values())
    incorrect_recall = sum(incorrect_recall_h.values())
    assert (correct_recall + incorrect_recall) == gsml_num_lines
    recall = safe_divide(correct_recall, gsml_num_lines)
    precision = safe_divide(correct_recall, submitted_num_lines)
    # Token level
    token_result = get_token_level_result(gsml, submitted, token_lookup)
    token_recall = safe_divide(token_result['recall'], token_result['recall_denominator'])
    token_precision = safe_divide(token_result['recall'], token_result['precision_denominator'])
    # Values to display
    return {
        'recall': {
            'value': recall,
            'correct': correct_recall,
            'of_total': gsml_num_lines
        },
        'precision': {
            'value': precision,
            'correct': correct_recall,
            'of_total': submitted_num_lines
        },
        'token_recall': {
            'value': token_recall,
            'correct': token_result['recall'],
            'of_total': token_result['recall_denominator']
        },
        'token_precision': {
            'value': token_precision,
            'correct': token_result['recall'],
            'of_total': token_result['precision_denominator']
        },
        'correct_recall_debug': correct_recall_h,
        'incorrect_recall_debug': incorrect_recall_h
    }
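# Example of reading the result (field names are from the dict above; the numbers
# are hypothetical):
#   result['recall']['value']       -> e.g. 0.62  (correct / of_total at mistake level)
#   result['token_recall']['value'] -> e.g. 0.71  (correct / of_total at token level)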
def format_result_value(value, dcp=3):
    # A score of 0 is valid, so only a missing value (None) is passed through unrounded
    if value is not None:
        return round(value, dcp)
    return None
# CLI args
parser.add_argument('--gsml', type=str,
                    help='The GSML file path (CSV)')
parser.add_argument('--submitted', type=str, nargs='?',
                    help='The submitted file path (CSV)')
parser.add_argument('--token_lookup', type=str,
                    help='The tokenization file (YAML)')
parser.add_argument('--text_dir', type=str,
                    help='The directory where the raw texts are')
parser.add_argument('--csv_out', type=str,
                    help='Path to an output CSV file for stats (optional)')
args = parser.parse_args()
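# Example invocation (file paths are hypothetical; the flag names match the
# arguments defined above):
#   python evaluate.py --gsml gsml.csv --submitted submitted.csv \
#       --token_lookup token_lookup.yaml --text_dir texts --csv_out results.csv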
gsml_filename = args.gsml
submitted_filename = args.submitted
token_lookup_filename = args.token_lookup
text_dir = args.text_dir
csv_out = args.csv_out
with open(token_lookup_filename, 'r') as fh:
    token_lookup = yaml.full_load(fh)
print('\n\n')
print('-' * 80)
print('GSML: EVALUATE')
print(f'comparing GSML => "{gsml_filename}" to submission => "{submitted_filename}"')
# Check all categories combined, as well as each category individually
categories_list = [all_categories()] + [[x] for x in all_categories()]
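# i.e. categories_list = [['NAME', 'NUMBER', 'WORD', 'CONTEXT', 'NOT_CHECKABLE', 'OTHER'],
#                         ['NAME'], ['NUMBER'], ['WORD'], ['CONTEXT'], ['NOT_CHECKABLE'], ['OTHER']]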
csv_lines = [
    [
        'categories',
        'recall',
        'precision',
        'token_recall',
        'token_precision',
        'submitted_filename',
        'gsml_filename',
        'token_lookup_filename',
        'text_dir',
    ]
]
for categories in categories_list:
    category_display_str = ', '.join(categories)
    print('\n\n--------------------------------------------')
    print(f'-- GSML for categories: [{category_display_str}]')
    result = calculate_recall_and_precision(gsml_filename, submitted_filename, token_lookup, text_dir, categories)
    recall = format_result_value(result['recall']['value'])
    precision = format_result_value(result['precision']['value'])
    token_recall = format_result_value(result['token_recall']['value'])
    token_precision = format_result_value(result['token_precision']['value'])
    csv_lines.append(
        [
            '|'.join(categories),
            str(recall),
            str(precision),
            str(token_recall),
            str(token_precision),
            submitted_filename,
            gsml_filename,
            token_lookup_filename,
            text_dir,
        ]
    )
    print(f'\tsummary: recall => {recall}, precision => {precision}, token_recall => {token_recall}, token_precision => {token_precision}')
    print('\tbreakdown:')
    for k, v in result.items():
        print(f'\t\t{k}')
        for sub_k, sub_v in v.items():
            print(f'\t\t\t{sub_k} => {sub_v}')
if csv_out is not None:
    with open(csv_out, 'w') as fh:
        s = '\n'.join([','.join(arr) for arr in csv_lines])
        fh.write(f'{s}\n')