-
Notifications
You must be signed in to change notification settings - Fork 0
/
modeling.py
95 lines (73 loc) · 3.29 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import math
import torch
import logging
from torch import nn
import numpy as np
from transformers.models.bert.modeling_bert import (BertModel, BertPreTrainedModel)
from transformers.models.deberta.modeling_deberta import (DebertaModel, DebertaPreTrainedModel)
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer, AdamW
# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
class RankingBERT_Train(DebertaPreTrainedModel):
    """DeBERTa encoder with a scalar scoring head trained by a pairwise margin loss.

    The first-token hidden state of each sequence goes through dropout and a
    ``Linear(hidden_size, 1)`` layer, producing one score per sequence. When
    ``labels`` are supplied, sequences with ``label > 0`` are positives and the
    rest are negatives. The i-th positive (in batch order) is paired with the
    i-th negative for ``nn.MarginRankingLoss(margin=1.0, reduction='mean')``.
    """

    def __init__(self, config):
        super(RankingBERT_Train, self).__init__(config)
        # NOTE(review): the attribute is still called ``bert`` so state-dict keys
        # ("bert.*") are unchanged. DebertaPreTrainedModel's base_model_prefix is
        # "deberta", though. Confirm that from_pretrained() really loads the
        # pretrained encoder into this module (watch for a "weights not
        # initialized" warning listing bert.* parameters).
        self.bert = DebertaModel(config)
        self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
        self.out = nn.Linear(config.hidden_size, 1)
        # Initialise only after every submodule exists. Previously this ran before
        # ``out`` was created, so the scoring head kept PyTorch's default init.
        self.init_weights()

    def forward(self, input_ids, token_type_ids,
                position_ids, labels=None):
        """Score a batch of sequences and optionally compute the ranking loss.

        Args:
            input_ids: token ids; positions equal to 0 are treated as padding.
            token_type_ids: segment ids, forwarded unchanged to the encoder.
            position_ids: position ids, forwarded unchanged to the encoder.
            labels: optional per-sequence labels (tensor or sequence); a value
                ``> 0`` marks a positive, anything else a negative.

        Returns:
            Without ``labels``: the score tensor of shape ``[B, 1]``.
            With ``labels``: ``(loss, *scores)``, i.e. the scalar loss followed by
            the score tensor unpacked along the batch dimension (one ``[1]``
            tensor per sequence). This matches the original return layout.

        Raises:
            ValueError: if ``labels`` do not yield the same, non-zero number of
                positives and negatives.
        """
        # Padding id 0 is assumed (DeBERTa's default pad_token_id). TODO: confirm
        # against the data collator.
        attention_mask = (input_ids != 0).long()
        sequence_output = self.bert(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids,
                                    position_ids=position_ids)[0]
        # DebertaModel has no pooler, so index [1] of its output is not a pooled
        # vector (by default it does not exist at all). Use the first-token
        # hidden state instead.
        first_token = sequence_output[:, 0]
        output = self.out(self.dropout(first_token))
        # shape = [B, 1]
        if labels is not None:
            loss = self._pairwise_margin_loss(output, labels)
            output = loss, *output
        return output

    @staticmethod
    def _pairwise_margin_loss(scores, labels, margin=1.0):
        """Return the margin ranking loss of positives vs. negatives paired in batch order."""
        flat_scores = scores.view(-1)
        is_pos = torch.as_tensor(labels, device=flat_scores.device).view(-1) > 0
        # Boolean-mask indexing keeps order of appearance, so the i-th positive
        # is paired with the i-th negative.
        y_pos = flat_scores[is_pos]
        y_neg = flat_scores[~is_pos]
        if y_pos.numel() == 0 or y_pos.numel() != y_neg.numel():
            raise ValueError(
                "expected equal, non-zero numbers of positive and negative "
                f"examples, got {y_pos.numel()} positives and "
                f"{y_neg.numel()} negatives"
            )
        y_true = torch.ones_like(y_pos)
        loss_fct = nn.MarginRankingLoss(margin=margin, reduction='mean')
        return loss_fct(y_pos, y_neg, y_true)
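# NOTE: earlier BertModel-based variant of RankingBERT_Train, disabled and kept only for reference.
# class RankingBERT_Train(BertPreTrainedModel):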
# def __init__(self, config):
# super(RankingBERT_Train, self).__init__(config)
# self.bert = BertModel(config)
# self.init_weights()
# self.dropout = nn.Dropout(p=config.hidden_dropout_prob)
# self.out = nn.Linear(config.hidden_size, 1)
# def forward(self, input_ids, token_type_ids,
# position_ids, labels=None):
# attention_mask = (input_ids != 0)
# bert_pooler_output = self.bert(input_ids,
# attention_mask=attention_mask,
# token_type_ids=token_type_ids,
# position_ids=position_ids)[1]
# output = self.out(self.dropout(bert_pooler_output))
# # shape = [B, 1]
# if labels is not None:
# loss_fct = nn.MarginRankingLoss(margin=1.0, reduction='mean')
# y_pos, y_neg = [], []
# for batch_index in range(len(labels)):
# label = labels[batch_index]
# if label > 0:
# y_pos.append(output[batch_index])
# else:
# y_neg.append(output[batch_index])
# y_pos = torch.cat(y_pos, dim=-1)
# y_neg = torch.cat(y_neg, dim=-1)
# y_true = torch.ones_like(y_pos)
# assert len(y_pos) == len(y_neg)
# loss = loss_fct(y_pos, y_neg, y_true)
# output = loss, *output
# return output