-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_models.py
36 lines (29 loc) · 1.17 KB
/
eval_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from os import system
from typing import List

import torch
from flair.data import Sentence, Span
from flair.datasets import ColumnCorpus
from flair.embeddings import *
from flair.models import SequenceTagger

import util
# Script: load each trained SequenceTagger, tag the background corpus with it,
# then run the external evaluate.py script and save its report next to the model.

# NOTE(review): `device` is not referenced anywhere below; presumably flair's own
# device setting decides where inference runs — confirm whether this was meant
# to be applied (e.g. to flair's global device).
device = torch.device('cuda:0')

# (name, model path) pairs to evaluate. The path is passed to
# SequenceTagger.load; the name selects the per-model directories used by the
# evaluation step (system/<name>/ and resources/models/<name>/).
# TODO: placeholder entry — replace with the real model names and paths.
MODELS = [
    ('model_name', 'resources/models/model_name/final-model.pt'),
]

# Column-formatted input file whose sentences each model re-tags.
DATA_FILE = 'resources/data/background_processed.conll'
DATA_FOLDER, TRAIN_FILE = os.path.split(DATA_FILE)

for name, model in MODELS:
    util.print_flag('Loading Model')
    tagger: SequenceTagger = SequenceTagger.load(model)
    print(tagger)

    # Reloaded for every model: predict() annotates the sentences in place, so
    # a fresh corpus keeps one model's tags from carrying over to the next.
    util.print_flag('Loading Corpus')
    corpus: ColumnCorpus = ColumnCorpus(
        data_folder=DATA_FOLDER,
        train_file=TRAIN_FILE,
        column_format={0: 'text', 1: 'begin', 2: 'end', 3: 'ner'},
    )
    print(corpus)

    util.print_flag('Tagging')
    # Read spans from the label layer this tagger writes its predictions to.
    tag_type = tagger.tag_type
    # Materialize the sentences and iterate this list afterwards: depending on
    # the flair version, predict() may only annotate in place and return None.
    sentences: List[Sentence] = list(corpus.get_all_sentences())
    tagger.predict(sentences=sentences)

    results: List[Span] = []
    for sentence in sentences:
        for span in sentence.get_spans(tag_type):
            # Equality, not identity: `is not "O"` depends on string interning.
            if span.tag != 'O':
                results.append(span)
    # NOTE(review): `results` is never written out, yet evaluate.py reads the
    # predictions from system/<name>/. Confirm that directory is populated by
    # a separate export step.

    util.print_flag('Evaluating')
    system(f'python3 evaluate.py ner gold/test/subtrack1 system/{name}/ | '
           f'tee resources/models/{name}/eval_results.txt')