-
Notifications
You must be signed in to change notification settings - Fork 13
/
modeling_dense.py
136 lines (120 loc) · 5.36 KB
/
modeling_dense.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from ast import Pass
import sys
import os
import torch
from torch import nn
import torch.nn.functional as F
from typing import Union
from transformers import AutoConfig, BertConfig, BertModel, RobertaConfig, RobertaModel, AutoModel
from transformers.models.bert.modeling_bert import BertPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from transformers.models.distilbert.modeling_distilbert import DistilBertPreTrainedModel, DistilBertModel, DistilBertConfig
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings over the sequence, ignoring padding.

    Padded positions (attention_mask == 0) contribute nothing to the sum,
    and the divisor is clamped to avoid division by zero for all-padding rows.
    """
    hidden = model_output.last_hidden_state
    mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
class BertDense(BertPreTrainedModel):
    """BERT backbone that encodes text into a single dense embedding.

    Pooling is selected by ``config.pooling``: ``"cls"`` (the default when
    unset) takes the first-token hidden state, ``"mean"`` averages over
    non-padding tokens; any other value raises ``NotImplementedError``.
    When ``config.similarity_metric == "METRIC_COS"`` the embedding is
    L2-normalized; otherwise it is returned as-is.
    """

    def __init__(self, config: BertConfig):
        BertPreTrainedModel.__init__(self, config)
        # Pooling is done in forward(), so the built-in pooler is dropped.
        self.bert = BertModel(config, add_pooling_layer=False)

    def forward(self, input_ids, attention_mask, return_dict=False):
        outputs = self.bert(input_ids, attention_mask, return_dict=True)

        pooling = getattr(self.config, "pooling", "cls")
        if pooling == "cls":
            text_embeds = outputs.last_hidden_state[:, 0]
        elif pooling == "mean":
            text_embeds = mean_pooling(outputs, attention_mask)
        else:
            raise NotImplementedError()

        # Only cosine similarity needs normalized vectors; inner product
        # (or an unset metric) keeps the raw embedding.
        if getattr(self.config, "similarity_metric", None) == "METRIC_COS":
            text_embeds = F.normalize(text_embeds, p=2, dim=-1)

        if not return_dict:
            return text_embeds
        outputs.embedding = text_embeds
        return outputs

    @property
    def language_model(self):
        # The underlying transformer encoder.
        return self.bert
class RobertaDense(RobertaPreTrainedModel):
    """RoBERTa backbone that encodes text into a single dense embedding.

    Pooling is selected by ``config.pooling``: ``"cls"`` (the default when
    unset) takes the first-token hidden state, ``"mean"`` averages over
    non-padding tokens; any other value raises ``NotImplementedError``.
    When ``config.similarity_metric == "METRIC_COS"`` the embedding is
    L2-normalized; otherwise it is returned as-is.
    """

    def __init__(self, config: RobertaConfig):
        RobertaPreTrainedModel.__init__(self, config)
        # Pooling is done in forward(), so the built-in pooler is dropped.
        self.roberta = RobertaModel(config, add_pooling_layer=False)

    def forward(self, input_ids, attention_mask, return_dict=False):
        outputs = self.roberta(input_ids, attention_mask, return_dict=True)
        # NOTE: a redundant pre-assignment of text_embeds was removed here;
        # every branch below sets it, matching BertDense/DistilBertDense.
        if hasattr(self.config, "pooling"):
            if self.config.pooling == "cls":
                text_embeds = outputs.last_hidden_state[:, 0]
            elif self.config.pooling == "mean":
                text_embeds = mean_pooling(outputs, attention_mask)
            else:
                raise NotImplementedError()
        else:  # default: use cls token embedding
            text_embeds = outputs.last_hidden_state[:, 0]
        if hasattr(self.config, "similarity_metric"):
            if self.config.similarity_metric == "METRIC_IP":
                pass  # inner product: raw embeddings
            elif self.config.similarity_metric == "METRIC_COS":
                text_embeds = F.normalize(text_embeds, p=2, dim=-1)
            else:  # default: use the original embedding
                pass
        if return_dict:
            outputs.embedding = text_embeds
            return outputs
        else:
            return text_embeds

    @property
    def language_model(self):
        # The underlying transformer encoder.
        return self.roberta
class DistilBertDense(DistilBertPreTrainedModel):
    """DistilBERT backbone that encodes text into a single dense embedding.

    Pooling is selected by ``config.pooling``: ``"cls"`` (the default when
    unset) takes the first-token hidden state, ``"mean"`` averages over
    non-padding tokens; any other value raises ``NotImplementedError``.
    When ``config.similarity_metric == "METRIC_COS"`` the embedding is
    L2-normalized; otherwise it is returned as-is.
    """

    def __init__(self, config: DistilBertConfig):
        DistilBertPreTrainedModel.__init__(self, config)
        # DistilBertModel has no pooler, so nothing to disable.
        self.distilbert = DistilBertModel(config)

    def forward(self, input_ids, attention_mask, return_dict=False):
        outputs = self.distilbert(input_ids, attention_mask, return_dict=True)

        pooling = getattr(self.config, "pooling", "cls")
        if pooling == "cls":
            text_embeds = outputs.last_hidden_state[:, 0]
        elif pooling == "mean":
            text_embeds = mean_pooling(outputs, attention_mask)
        else:
            raise NotImplementedError()

        # Only cosine similarity needs normalized vectors; inner product
        # (or an unset metric) keeps the raw embedding.
        if getattr(self.config, "similarity_metric", None) == "METRIC_COS":
            text_embeds = F.normalize(text_embeds, p=2, dim=-1)

        if not return_dict:
            return text_embeds
        outputs.embedding = text_embeds
        return outputs

    @property
    def language_model(self):
        # The underlying transformer encoder.
        return self.distilbert
class AutoDense:
    """Factory that instantiates the matching *Dense wrapper for a checkpoint.

    Dispatches on ``config.model_type`` ("bert", "roberta", "distilbert");
    any other model type raises ``NotImplementedError``.
    """

    @staticmethod
    def from_pretrained(model_name_or_path: str, config = None) -> Union[BertDense, RobertaDense, DistilBertDense]:
        # Load the config from the checkpoint unless the caller supplied one.
        if config is None:
            config = AutoConfig.from_pretrained(model_name_or_path)
        model_type = config.model_type
        if model_type == "bert":
            return BertDense.from_pretrained(model_name_or_path, config=config)
        if model_type == "roberta":
            return RobertaDense.from_pretrained(model_name_or_path, config=config)
        if model_type == "distilbert":
            return DistilBertDense.from_pretrained(model_name_or_path, config=config)
        raise NotImplementedError()