-
Notifications
You must be signed in to change notification settings - Fork 6
/
model.py
36 lines (32 loc) · 1.63 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch import nn
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
BertForTokenClassification, BertTokenizer,
WarmupLinearSchedule)
class Ner(BertForTokenClassification):
    """BERT token classifier that scores only the positions flagged in ``valid_ids``.

    The hidden states of flagged positions are packed to the front of each
    sequence (zero-padded at the end) before dropout and the classification
    layer are applied.
    """

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                valid_ids=None, attention_mask_label=None):
        """Run the encoder, keep the flagged positions, and classify them.

        Args:
            input_ids: Token ids forwarded to ``self.bert``.
            token_type_ids: Optional segment ids forwarded to ``self.bert``.
            attention_mask: Optional attention mask forwarded to ``self.bert``.
            labels: Optional target label ids. When given, the loss is returned
                instead of the logits. Targets equal to 0 are ignored by the
                loss (presumably 0 is the padding label -- confirm with caller).
            valid_ids: Optional ``(batch, max_len)`` tensor; positions equal
                to 1 are kept and left-packed (presumably the first sub-token
                of each word -- confirm with caller). When ``None``, every
                position is kept unchanged.
            attention_mask_label: Optional mask over label positions; only
                entries equal to 1 contribute to the loss.

        Returns:
            The cross-entropy loss when ``labels`` is given, otherwise the
            logits whose last dimension is ``self.num_labels``.
        """
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]

        if valid_ids is None:
            # Nothing to filter: classify every position as-is.
            valid_output = sequence_output
        else:
            batch_size = sequence_output.shape[0]
            # Allocate on the encoder output's own device/dtype so the model
            # also works on CPU, on non-default GPUs and in half precision.
            valid_output = torch.zeros_like(sequence_output)
            keep = valid_ids.to(sequence_output.device) == 1
            for row in range(batch_size):
                # Left-pack the kept vectors of this sequence; the tail stays zero.
                kept = sequence_output[row][keep[row]]
                valid_output[row, :kept.size(0)] = kept

        sequence_output = self.dropout(valid_output)
        logits = self.classifier(sequence_output)

        if labels is None:
            return logits

        loss_fct = nn.CrossEntropyLoss(ignore_index=0)
        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask_label is not None:
            # Only keep the active parts of the loss.
            active_loss = attention_mask_label.view(-1) == 1
            flat_logits = flat_logits[active_loss]
            flat_labels = flat_labels[active_loss]
        return loss_fct(flat_logits, flat_labels)