from __future__ import annotations

from typing import Any, Iterable

import torch
from torch import Tensor, nn

from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer


class MultipleNegativesRankingLoss(nn.Module):
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim) -> None:
        """
        This loss expects as input a batch consisting of sentence pairs ``(a_1, p_1), (a_2, p_2), ..., (a_n, p_n)``
        where we assume that ``(a_i, p_i)`` is a positive pair and ``(a_i, p_j)`` for ``i != j`` is a negative pair.

        For each ``a_i``, it uses all other ``p_j`` as negative samples, i.e., for ``a_i``, we have 1 positive example
        (``p_i``) and ``n-1`` negative examples (``p_j``). It then minimizes the negative log-likelihood of the
        softmax-normalized scores, where the correct class for ``a_i`` is ``p_i``.

        This loss function works well for training embeddings for retrieval setups where you have positive pairs
        (e.g. (query, relevant_doc)), because the other ``n-1`` documents in each batch act as randomly sampled
        negatives. Performance usually improves with larger batch sizes, since each anchor is then contrasted
        with more negatives.

        You can also provide one or multiple hard negatives per anchor-positive pair by structuring the data like this:
        ``(a_1, p_1, n_1), (a_2, p_2, n_2)``. Then, ``n_1`` is a hard negative for ``(a_1, p_1)``. For the pair
        ``(a_i, p_i)``, the loss uses all ``p_j`` for ``j != i`` and all ``n_j`` as negatives.

        Args:
            model: SentenceTransformer model
            scale: Output of the similarity function is multiplied by this scale value
            similarity_fct: Similarity function between sentence embeddings. By default, ``cos_sim``. Can also be
                set to dot product (and then set ``scale`` to 1)

        References:
            - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf
            - `Training Examples > Natural Language Inference <../../examples/training/nli/README.html>`_
            - `Training Examples > Paraphrase Data <../../examples/training/paraphrases/README.html>`_
            - `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_
            - `Training Examples > MS MARCO <../../examples/training/ms_marco/README.html>`_
            - `Unsupervised Learning > SimCSE <../../examples/unsupervised_learning/SimCSE/README.html>`_
            - `Unsupervised Learning > GenQ <../../examples/unsupervised_learning/query_generation/README.html>`_

        Requirements:
            1. (anchor, positive) pairs or (anchor, positive, negative) triplets

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (anchor, positive) pairs              | none   |
            +---------------------------------------+--------+
            | (anchor, positive, negative) triplets | none   |
            +---------------------------------------+--------+

        Recommendations:
            - Use ``BatchSamplers.NO_DUPLICATES`` (:class:`docs <sentence_transformers.training_args.BatchSamplers>`) to
              ensure that no in-batch negatives are duplicates of the anchor or positive samples.

        Relations:
            - :class:`CachedMultipleNegativesRankingLoss` is equivalent to this loss, but it uses caching that allows for
              much higher batch sizes (and thus better performance) without extra memory usage. However, it is slightly
              slower.
            - :class:`MultipleNegativesSymmetricRankingLoss` is equivalent to this loss, but with an additional loss term.
            - :class:`GISTEmbedLoss` is equivalent to this loss, but uses a guide model to guide the in-batch negative
              sample selection. `GISTEmbedLoss` yields a stronger training signal at the cost of some training overhead.

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "anchor": ["It's nice weather outside today.", "He drove to work."],
                    "positive": ["It's so sunny.", "He took the car to the office."],
                })
                loss = losses.MultipleNegativesRankingLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.scale = scale
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor) -> Tensor:
        # ``labels`` is unused: the targets are derived from the order of the batch.
        # Embed the anchors, positives and (optional) hard negatives.
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        embeddings_a = reps[0]
        # Candidate pool for every anchor: all positives, followed by all hard negatives.
        embeddings_b = torch.cat(reps[1:])

        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
        # Example a[i] should match with b[i]
        range_labels = torch.arange(0, scores.size(0), device=scores.device)
        return self.cross_entropy_loss(scores, range_labels)

    def get_config_dict(self) -> dict[str, Any]:
        return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}

    @property
    def citation(self) -> str:
        return """
@misc{henderson2017efficient,
    title={Efficient Natural Language Response Suggestion for Smart Reply},
    author={Matthew Henderson and Rami Al-Rfou and Brian Strope and Yun-hsuan Sung and Laszlo Lukacs and Ruiqi Guo and Sanjiv Kumar and Balint Miklos and Ray Kurzweil},
    year={2017},
    eprint={1705.00652},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
"""