forked from XavierWww/Chinese-Medical-Entity-Recognition
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
47 lines (40 loc) · 1.54 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# -*- coding: utf-8 -*-
'''
@Author: Xavier WU
@Date: 2021-11-30
@LastEditTime: 2022-1-6
@Description: This file is for building model.
@All Right Reserve
'''
import torch
import torch.nn as nn
from transformers import BertModel
from torchcrf import CRF
class Bert_BiLSTM_CRF(nn.Module):
    """BERT encoder -> 2-layer BiLSTM -> linear emission layer -> CRF tagger.

    Args:
        tag_to_ix: mapping of tags to indices; only its length is used here,
            and it sets the number of CRF tag states.
        embedding_dim: width of the BERT hidden states fed into the LSTM. It
            must equal the pretrained model's hidden size (the default 768 is
            presumably chosen to match ``bert-base-chinese``).
        hidden_dim: total BiLSTM output width. Each direction gets
            ``hidden_dim // 2`` units.
        bert_model: name or path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, tag_to_ix, embedding_dim=768, hidden_dim=256,
                 bert_model='bert-base-chinese'):
        super().__init__()
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.bert = BertModel.from_pretrained(bert_model)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim // 2,
                            num_layers=2, bidirectional=True, batch_first=True)
        self.dropout = nn.Dropout(p=0.1)
        # The BiLSTM concatenates both directions, giving 2 * (hidden_dim // 2)
        # features. When hidden_dim is odd that is hidden_dim - 1, so size the
        # projection from the real LSTM output width rather than hidden_dim.
        self.linear = nn.Linear(2 * (hidden_dim // 2), self.tagset_size)
        self.crf = CRF(self.tagset_size, batch_first=True)

    def _get_features(self, sentence, mask=None):
        """Compute per-token emission scores for the CRF.

        Args:
            sentence: token-id tensor passed to BERT as its input ids
                (presumably (batch, seq_len), since batch_first is used
                throughout -- confirm against the data loader).
            mask: optional padding mask in which nonzero/True marks a real
                token. When given, it is forwarded to BERT as
                ``attention_mask`` so that padded positions do not influence
                the representations of real tokens.

        Returns:
            Emission tensor whose last dimension has size ``tagset_size``.
        """
        attention_mask = None if mask is None else mask.long()
        # BERT runs without gradient tracking, so this loss produces no
        # gradients for its weights. Only the LSTM, linear and CRF layers learn.
        with torch.no_grad():
            outputs = self.bert(sentence, attention_mask=attention_mask)
        # Index the output instead of tuple-unpacking it. Newer transformers
        # versions return a ModelOutput (a dict subclass), and unpacking one
        # yields its key *names*. Element 0 is the last hidden state for both
        # the tuple and the ModelOutput return styles.
        embeds = outputs[0]
        enc, _ = self.lstm(embeds)
        enc = self.dropout(enc)
        return self.linear(enc)

    def forward(self, sentence, tags, mask, is_test=False):
        """Return the training loss, or the decoded tag paths at test time.

        Args:
            sentence: token ids fed to BERT.
            tags: gold tag indices. Only used when ``is_test`` is False.
            mask: padding mask, shared by the BERT attention and the CRF.
            is_test: if True, decode with the CRF instead of computing a loss.

        Returns:
            When training, the negated mean CRF score, which is used as the
            loss. Otherwise, whatever ``self.crf.decode`` returns; in
            pytorch-crf that is one list of tag indices per sequence.
        """
        emissions = self._get_features(sentence, mask)
        if is_test:
            return self.crf.decode(emissions, mask)
        # The CRF returns a log-likelihood (higher is better). Negate it so the
        # optimizer minimises it as a loss.
        return -self.crf.forward(emissions, tags, mask, reduction='mean')