-
Notifications
You must be signed in to change notification settings - Fork 0
/
persian_ner.py
60 lines (50 loc) · 1.64 KB
/
persian_ner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import sqlite3
import sys
from tqdm import tqdm
from flair.data import Sentence
from flair.models import SequenceTagger
# Path of the SQLite database file that main() opens to read the news bodies.
db_file_path = './volume/fars_news.db'

# Running totals, one slot per entity tag, all starting at zero. main() bumps
# the slot named by each predicted span's tag; 'ERROR' is bumped instead when
# processing a sentence raises.
ner_count = dict.fromkeys(
    ('PER', 'LOC', 'ORG', 'DAT', 'TIM', 'PCT', 'MON', 'MISC', 'ERROR'),
    0,
)
def _load_news_bodies(path):
    """Return all ``body`` rows of the ``news`` table, ordered by publication time.

    Args:
        path: Filesystem path of the SQLite database to open.

    Returns:
        A list of 1-tuples ``(body,)`` as produced by ``cursor.fetchall()``.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    connection = sqlite3.connect(path)
    try:
        cursor = connection.cursor()
        cursor.execute("SELECT body FROM news ORDER BY published_datetime ASC;")
        # Materialise the rows so tqdm can show a total and the connection can
        # be released before the (slow) tagging pass starts.
        return cursor.fetchall()
    finally:
        # Close even when the query fails; a read-only SELECT needs no commit.
        connection.close()


def main():
    """Tag every stored news body and print how often each entity tag occurs.

    Each body is split on '.', every non-empty fragment is run through the
    Flair sequence tagger, and the module-level ``ner_count`` table is updated
    in place. Fragments whose tagging raises are counted under 'ERROR' and
    processing continues. The final totals are printed to stdout.

    On Ctrl-C the totals gathered so far are printed and the process exits
    with status 0.

    Raises:
        sqlite3.Error: If the news bodies cannot be read from the database.
    """
    rows = _load_news_bodies(db_file_path)

    # load tagger
    tagger = SequenceTagger.load("PooryaPiroozfar/Flair-Persian-NER")

    try:
        for record in tqdm(rows):
            body = record[0]
            if not body:
                # NULL or empty body: nothing to tag, and None has no .split().
                continue

            for fragment in body.split('.'):
                if not fragment:
                    continue

                try:
                    sentence = Sentence(fragment)
                    tagger.predict(sentence)
                    tags = [entity.tag for entity in sentence.get_spans('ner')]
                except Exception as exc:
                    # Best effort: one bad fragment must not abort the whole run.
                    ner_count['ERROR'] = ner_count.get('ERROR', 0) + 1
                    print(f"NER failed for a sentence: {exc!r}", file=sys.stderr)
                    # NOTE(review): the file's indentation was lost, so printing
                    # the running totals on each failure is an assumed placement
                    # of this print - confirm against the original script.
                    print(ner_count)
                    continue

                for tag in tags:
                    # Default to 0 so a tag missing from the table gets its own
                    # counter instead of raising on ``None + 1``.
                    ner_count[tag] = ner_count.get(tag, 0) + 1
    except KeyboardInterrupt:
        # Report partial totals wherever in the loop the interrupt arrives.
        print(ner_count)
        sys.exit(0)

    print(ner_count)
# Run the tagging pass only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()