sentence_transformer.py

from enum import Enum
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor


class PoolingStrategy(Enum):
    MEAN = 'mean'
    LAST = 'last'
    CLS = 'cls'
    LAST_HIDDEN_STATE = 'last_hidden_state'


class SentenceTransformer(torch.nn.Module):
    def __init__(
        self,
        model_name: str,
        pooling_strategy: Union[PoolingStrategy, str] = 'mean',
    ) -> None:
        super().__init__()
        self.model_name = model_name
        self.pooling_strategy = PoolingStrategy(pooling_strategy)

        from transformers import AutoModel, AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

        # Decoder-only checkpoints often ship without a padding token;
        # fall back to the EOS token so batched tokenization works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)

        emb = out[0]  # First element contains all token embeddings.
        if self.pooling_strategy == PoolingStrategy.MEAN:
            emb = mean_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST:
            emb = last_pooling(emb, attention_mask)
        elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:
            emb = out.last_hidden_state
        else:
            assert self.pooling_strategy == PoolingStrategy.CLS
            emb = emb[:, 0, :]

        emb = F.normalize(emb, p=2, dim=1)
        return emb

    def get_input_ids(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tuple[Tensor, Tensor]:
        # Tokenize a dummy entry for empty input so the tokenizer call
        # succeeds; the result is sliced back to an empty tensor below.
        is_empty = len(text) == 0
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        input_ids: List[Tensor] = []
        attention_masks: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )
            input_ids.append(token.input_ids.to(self.device))
            attention_masks.append(token.attention_mask.to(self.device))

        def _out(x: List[Tensor]) -> Tensor:
            out = torch.cat(x, dim=0) if len(x) > 1 else x[0]
            out = out[:0] if is_empty else out
            return out.to(output_device)

        return _out(input_ids), _out(attention_masks)

    @property
    def device(self) -> torch.device:
        return next(iter(self.model.parameters())).device

    @torch.no_grad()
    def encode(
        self,
        text: List[str],
        batch_size: Optional[int] = None,
        output_device: Optional[Union[torch.device, str]] = None,
    ) -> Tensor:
        is_empty = len(text) == 0
        text = ['dummy'] if is_empty else text

        batch_size = len(text) if batch_size is None else batch_size

        embs: List[Tensor] = []
        for start in range(0, len(text), batch_size):
            token = self.tokenizer(
                text[start:start + batch_size],
                padding=True,
                truncation=True,
                return_tensors='pt',
            )

            emb = self(
                input_ids=token.input_ids.to(self.device),
                attention_mask=token.attention_mask.to(self.device),
            ).to(output_device)

            embs.append(emb)

        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]
        out = out[:0] if is_empty else out
        return out

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(model_name={self.model_name})'


def mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    # Average token embeddings over non-padded positions only.
    mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)
    return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
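
# Worked example (added for illustration, not part of the original module):
# with emb = [[[1., 0.], [3., 0.], [9., 9.]]] and attention_mask = [[1, 1, 0]],
# the padded third token is masked out and mean_pooling returns
# [[2., 0.]] = ([1., 0.] + [3., 0.]) / 2.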


def last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:
    # Check whether the language model uses left padding,
    # which is typically the case for decoder-only LLMs:
    left_padding = attention_mask[:, -1].sum() == attention_mask.size(0)
    if left_padding:
        return emb[:, -1]

    # Otherwise, with right padding, pick the last non-padded token of
    # each sequence:
    seq_indices = attention_mask.sum(dim=1) - 1
    return emb[torch.arange(emb.size(0), device=emb.device), seq_indices]
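

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original module): the checkpoint
# name below is a placeholder; any Hugging Face encoder checkpoint should work
# with the 'mean' or 'cls' strategies, while 'last' targets left-padded
# decoder-only models.
if __name__ == '__main__':
    model = SentenceTransformer(
        model_name='sentence-transformers/all-MiniLM-L6-v2',  # placeholder
        pooling_strategy='mean',
    )
    emb = model.encode(
        [
            'Graph neural networks operate on graph-structured data.',
            'Sentence embeddings map text to fixed-size vectors.',
        ],
        batch_size=2,
        output_device='cpu',
    )
    print(model)      # SentenceTransformer(model_name=...)
    print(emb.shape)  # [2, hidden_dim]; rows are L2-normalized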