From 7aa827cedacf1363dccb83d1a273d433419599ba Mon Sep 17 00:00:00 2001 From: JacksonCakes Date: Sat, 13 Apr 2024 17:38:15 +0800 Subject: [PATCH 1/9] Add CachedGISTEmbedLoss to __init__ to allow for easier import --- sentence_transformers/losses/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sentence_transformers/losses/__init__.py b/sentence_transformers/losses/__init__.py index a48976d00..00a64e2cb 100644 --- a/sentence_transformers/losses/__init__.py +++ b/sentence_transformers/losses/__init__.py @@ -21,6 +21,7 @@ from .MegaBatchMarginLoss import MegaBatchMarginLoss from .DenoisingAutoEncoderLoss import DenoisingAutoEncoderLoss from .GISTEmbedLoss import GISTEmbedLoss +from .CachedGISTEmbedLoss import CachedGISTEmbedLoss # Triplet losses from .BatchHardTripletLoss import BatchHardTripletLoss, BatchHardTripletLossDistanceFunction @@ -42,6 +43,7 @@ "MSELoss", "ContrastiveLoss", "SiameseDistanceMetric", + "CachedGISTEmbedLoss", "CachedMultipleNegativesRankingLoss", "ContrastiveTensionLoss", "ContrastiveTensionLossInBatchNegatives", From 5e050452a5dc06b541f2157eb20023003233b1f5 Mon Sep 17 00:00:00 2001 From: JacksonCakes Date: Sat, 13 Apr 2024 22:58:35 +0800 Subject: [PATCH 2/9] Add intiial implementation --- .../losses/CachedGISTEmbedLoss.py | 322 ++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 sentence_transformers/losses/CachedGISTEmbedLoss.py diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py new file mode 100644 index 000000000..a791fe736 --- /dev/null +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -0,0 +1,322 @@ +from __future__ import annotations +from contextlib import nullcontext +from functools import partial +import torch +from torch import nn, Tensor +from torch.utils.checkpoint import get_device_states, set_device_states +from typing import Iterable, Dict, Iterator, List, Optional, Tuple +from sentence_transformers import SentenceTransformer +from sentence_transformers import util +import tqdm +from sentence_transformers.models import Transformer + + +class RandContext: + """ + Random-state context manager class. Reference: https://github.com/luyug/GradCache. + + This class will back up the pytorch's random state during initialization. Then when the context is activated, + the class will set up the random state with the backed-up one. 
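A minimal, self-contained sketch of that behaviour, assuming `RandContext` is imported from the module this patch adds (the CPU-only dropout call is purely illustrative):

```python
import torch
import torch.nn.functional as F

from sentence_transformers.losses.CachedGISTEmbedLoss import RandContext  # module added in this patch

x = torch.randn(4, 8)
state = RandContext(x)                       # backs up the CPU (and any relevant GPU) RNG states
first = F.dropout(x, p=0.5, training=True)   # consumes random numbers
with state:                                  # restores the backed-up states
    second = F.dropout(x, p=0.5, training=True)
assert torch.equal(first, second)            # identical dropout masks on both passes
```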
+ """ + + def __init__(self, *tensors): + self.fwd_cpu_state = torch.get_rng_state() + self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors) + + def __enter__(self): + self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True) + self._fork.__enter__() + torch.set_rng_state(self.fwd_cpu_state) + set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._fork.__exit__(exc_type, exc_val, exc_tb) + self._fork = None + + +def _backward_hook( + grad_output: Tensor, + sentence_features: Iterable[Dict[str, Tensor]], + loss_obj: CachedGISTEmbedLoss, +): + """A backward hook to backpropagate the cached gradients mini-batch by mini-batch.""" + assert loss_obj.cache is not None + assert loss_obj.random_states is not None + with torch.enable_grad(): + for sentence_feature, grad, random_states in zip(sentence_features, loss_obj.cache, loss_obj.random_states): + for (reps_mb, _, _), grad_mb in zip( + loss_obj.embed_minibatch_iter( + sentence_feature=sentence_feature, + with_grad=True, + copy_random_state=False, + random_states=random_states, + ), + grad, + ): + surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output + surrogate.backward() + + +class CachedGISTEmbedLoss(nn.Module): + def __init__( + self, + model: SentenceTransformer, + guide: SentenceTransformer, + temperature: float = 0.01, + similarity_fct: callable[[Tensor, Tensor], Tensor] = util.cos_sim, + mini_batch_size: int = 32, + show_progress_bar: bool = False, + ): + """ + Boosted version of MultipleNegativesRankingLoss (https://arxiv.org/pdf/1705.00652.pdf) by GradCache (https://arxiv.org/pdf/2101.06983.pdf). + + Constrastive learning (here our MNRL loss) with in-batch negatives is usually hard to work with large batch sizes due to (GPU) memory limitation. + Even with batch-scaling methods like gradient-scaling, it cannot work either. This is because the in-batch negatives make the data points within + the same batch non-independent and thus the batch cannot be broke down into mini-batches. GradCache is a smart way to solve this problem. + It achieves the goal by dividing the computation into two stages of embedding and loss calculation, which both can be scaled by mini-batches. + As a result, memory of constant size (e.g. that works with batch size = 32) can now process much larger batches (e.g. 65536). + + In detail: + + (1) It first does a quick embedding step without gradients/computation graphs to get all the embeddings; + (2) Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings; + (3) A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain. + + Notes: All steps are done with mini-batches. In the original implementation of GradCache, (2) is not done in mini-batches and + requires a lot memory when batch size large. One drawback is about the speed. GradCache will sacrifice around 20% computation time according to the paper. + + :param model: SentenceTransformer model + :param scale: Output of similarity function is multiplied by scale value + :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. 
Can also be set to dot product (and then set scale to 1) + + References: + - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf + - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf + + Requirements: + 1. (anchor, positive) pairs or (anchor, positive, negative pairs) + 2. Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss` + + Relations: + - Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes + (and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than + :class:`MultipleNegativesRankingLoss`. + + Inputs: + +---------------------------------------+--------+ + | Texts | Labels | + +=======================================+========+ + | (anchor, positive) pairs | none | + +---------------------------------------+--------+ + | (anchor, positive, negative) triplets | none | + +---------------------------------------+--------+ + + Example: + :: + + from sentence_transformers import SentenceTransformer, losses, InputExample + from torch.utils.data import DataLoader + + model = SentenceTransformer('distilbert-base-uncased') + train_examples = [ + InputExample(texts=['Anchor 1', 'Positive 1']), + InputExample(texts=['Anchor 2', 'Positive 2']), + ] + train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024) # Here we can try much larger batch sizes! + train_loss = losses.CachedGISTEmbedLoss(model=model, mini_batch_size = 32) + model.fit( + [(train_dataloader, train_loss)], + epochs=10, + ) + """ + super(CachedGISTEmbedLoss, self).__init__() + self.model = model + # self.scale = scale + self.guide = guide + self.temperature = temperature + self.similarity_fct = nn.CosineSimilarity(dim=-1) + if not isinstance(model[0], Transformer) or not isinstance(guide[0], Transformer): + raise ValueError( + "Both the training model and the guiding model must be based on the `transformers` architecture." 
+ ) + self.cross_entropy_loss = nn.CrossEntropyLoss() + self.mini_batch_size = mini_batch_size + self.cache: Optional[List[List[Tensor]]] = None + self.random_states: Optional[List[List[RandContext]]] = None + self.show_progress_bar = show_progress_bar + self.must_retokenize = ( + model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length + ) + + def sim_matrix(self, embed1, embed2): + return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0)) + + def embed_minibatch( + self, + sentence_feature: Dict[str, Tensor], + begin: int, + end: int, + with_grad: bool, + copy_random_state: bool, + random_state: Optional[RandContext] = None, + ) -> Tuple[Tensor, Optional[RandContext]]: + """Do forward pass on a minibatch of the input features and return corresponding embeddings.""" + grad_context = nullcontext if with_grad else torch.no_grad + random_state_context = nullcontext() if random_state is None else random_state + sentence_feature_minibatch = {k: v[begin:end] for k, v in sentence_feature.items()} + with random_state_context: + with grad_context(): + random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None + reps = self.model(sentence_feature_minibatch)["sentence_embedding"] # (mbsz, hdim) + # TODO: Compute guide embeddings + with torch.no_grad(): + if self.must_retokenize: + decoded = self.model.tokenizer.batch_decode( + sentence_feature_minibatch["input_ids"], skip_special_tokens=True + ) + sentence_feature_minibatch = self.guide.tokenize(decoded) + sentence_feature_minibatch = { + key: value.to(self.guide.device) for key, value in sentence_feature_minibatch.items() + } + guide_reps = self.guide(sentence_feature_minibatch)["sentence_embedding"] + + return reps, guide_reps, random_state + + def embed_minibatch_iter( + self, + sentence_feature: Dict[str, Tensor], + with_grad: bool, + copy_random_state: bool, + random_states: Optional[List[RandContext]] = None, + ) -> Iterator[Tuple[Tensor, Optional[RandContext]]]: + """Do forward pass on all the minibatches of the input features and yield corresponding embeddings.""" + input_ids: Tensor = sentence_feature["input_ids"] + bsz, _ = input_ids.shape + for i, b in enumerate( + tqdm.trange( + 0, + bsz, + self.mini_batch_size, + desc="Embed mini-batches", + disable=not self.show_progress_bar, + ) + ): + e = b + self.mini_batch_size + reps, guide_reps, random_state = self.embed_minibatch( + sentence_feature=sentence_feature, + begin=b, + end=e, + with_grad=with_grad, + copy_random_state=copy_random_state, + random_state=None if random_states is None else random_states[i], + ) + yield reps, guide_reps, random_state # reps: (mbsz, hdim) + + def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor: + """Calculate the cross-entropy loss and cache the gradients wrt. the embeddings.""" + if len(reps) == 2: + anchor, positive = reps + anchor_guide, positive_guide = reps_guided + negative = None + negative_guide = None + elif len(reps) == 3: + anchor, positive, negative = reps + anchor_guide, positive_guide, negative_guide = reps_guided + else: + raise ValueError("Expected 2 or 3 embeddings, got {}".format(len(reps))) + + # Concatenate the lists into single tensors. 
+ anchor = torch.cat(anchor, dim=0) + positive = torch.cat(positive, dim=0) + anchor_guide = torch.cat(anchor_guide, dim=0) + positive_guide = torch.cat(positive_guide, dim=0) + if negative: + negative = torch.cat(negative, dim=0) + negative_guide = torch.cat(negative_guide, dim=0) + guided_an_sim = self.sim_matrix(anchor_guide, negative_guide) + + # Let's compute the similarity matrices for the combinations of anchor and positive samples. + guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide) + guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide) + guided_pp_sim = self.sim_matrix(positive_guide, positive_guide) + # Define the anchor threshold + guided_sim = guided_ap_sim.diagonal().view(-1, 1) + + labels = torch.arange(anchor.size(0)).long().to(anchor.device) + batch_size = anchor.shape[0] + + losses: List[torch.Tensor] = [] + for b in tqdm.trange( + 0, + batch_size, + self.mini_batch_size, + desc="Preparing caches", + disable=not self.show_progress_bar, + ): + e = b + self.mini_batch_size + # Compute similarity scores for current mini-batch. + ap_sim = self.sim_matrix(anchor[b:e], positive) + aa_sim = self.sim_matrix(anchor[b:e], anchor) + pp_sim = self.sim_matrix(positive[b:e], positive) + + # Find which samples cannot be used as negatives because they are + # more similar to the query than the assigned positive as deemed by the guide model. + # For these samples, we mask them with -inf to basically ignore their contribution to + # the loss. + ap_sim[guided_ap_sim[b:e] > guided_sim[b:e]] = -torch.inf + aa_sim[guided_aa_sim[b:e] > guided_sim[b:e]] = -torch.inf + pp_sim[guided_pp_sim[b:e] > guided_sim[b:e]] = -torch.inf + + scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1) + + # Handle the case where we have a negative sample + if negative is not None: + an_sim = self.sim_matrix(anchor[b:e], negative) + an_sim[guided_an_sim[b:e] > guided_sim[b:e]] = -torch.inf + scores = torch.cat([scores, an_sim], dim=1) + scores = scores / self.temperature + loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size + loss_mbatch.backward() + losses.append(loss_mbatch.detach()) + + loss = sum(losses).requires_grad_() + + self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim) + + return loss + + def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor: + # Step (1): A quick embedding step without gradients/computation graphs to get all the embeddings + reps = [] + reps_guided = [] + self.random_states = [] # Copy random states to guarantee exact reproduction of the embeddings during the second forward pass, i.e. step (3) + for sentence_feature in sentence_features: + reps_mbs = [] + reps_guided_mbs = [] + random_state_mbs = [] + for reps_mb, reps_guided_mb, random_state in self.embed_minibatch_iter( + sentence_feature=sentence_feature, + with_grad=False, + copy_random_state=True, + ): + # TODO: reps contains reps_mbs contains reps_mb, reps contains each feature + # anchor + pos + neg, then for each of these contains minibatch + reps_mbs.append(reps_mb.detach().requires_grad_()) + reps_guided_mbs.append(reps_guided_mb.detach()) # does not requires gradient + random_state_mbs.append(random_state) + reps.append(reps_mbs) + reps_guided.append(reps_guided_mbs) + self.random_states.append(random_state_mbs) + + # Step (2): Calculate the loss, backward up to the embeddings and cache the gradients wrt. 
to the embeddings + loss = self.calculate_loss_and_cache_gradients(reps, reps_guided) + print(loss) + # Step (3): A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain + loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self)) + return loss + + def get_config_dict(self): + return { + "guide": self.guide, + "temperature": self.temperature, + } From 98b6ecf4c82e00dda36ab0fef5073df48d60ab81 Mon Sep 17 00:00:00 2001 From: JacksonCakes Date: Sat, 13 Apr 2024 22:58:35 +0800 Subject: [PATCH 3/9] Add initial implementation --- .../losses/CachedGISTEmbedLoss.py | 322 ++++++++++++++++++ 1 file changed, 322 insertions(+) create mode 100644 sentence_transformers/losses/CachedGISTEmbedLoss.py diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py new file mode 100644 index 000000000..a791fe736 --- /dev/null +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -0,0 +1,322 @@ +from __future__ import annotations +from contextlib import nullcontext +from functools import partial +import torch +from torch import nn, Tensor +from torch.utils.checkpoint import get_device_states, set_device_states +from typing import Iterable, Dict, Iterator, List, Optional, Tuple +from sentence_transformers import SentenceTransformer +from sentence_transformers import util +import tqdm +from sentence_transformers.models import Transformer + + +class RandContext: + """ + Random-state context manager class. Reference: https://github.com/luyug/GradCache. + + This class will back up the pytorch's random state during initialization. Then when the context is activated, + the class will set up the random state with the backed-up one. + """ + + def __init__(self, *tensors): + self.fwd_cpu_state = torch.get_rng_state() + self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors) + + def __enter__(self): + self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True) + self._fork.__enter__() + torch.set_rng_state(self.fwd_cpu_state) + set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states) + + def __exit__(self, exc_type, exc_val, exc_tb): + self._fork.__exit__(exc_type, exc_val, exc_tb) + self._fork = None + + +def _backward_hook( + grad_output: Tensor, + sentence_features: Iterable[Dict[str, Tensor]], + loss_obj: CachedGISTEmbedLoss, +): + """A backward hook to backpropagate the cached gradients mini-batch by mini-batch.""" + assert loss_obj.cache is not None + assert loss_obj.random_states is not None + with torch.enable_grad(): + for sentence_feature, grad, random_states in zip(sentence_features, loss_obj.cache, loss_obj.random_states): + for (reps_mb, _, _), grad_mb in zip( + loss_obj.embed_minibatch_iter( + sentence_feature=sentence_feature, + with_grad=True, + copy_random_state=False, + random_states=random_states, + ), + grad, + ): + surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output + surrogate.backward() + + +class CachedGISTEmbedLoss(nn.Module): + def __init__( + self, + model: SentenceTransformer, + guide: SentenceTransformer, + temperature: float = 0.01, + similarity_fct: callable[[Tensor, Tensor], Tensor] = util.cos_sim, + mini_batch_size: int = 32, + show_progress_bar: bool = False, + ): + """ + Boosted version of MultipleNegativesRankingLoss (https://arxiv.org/pdf/1705.00652.pdf) by GradCache (https://arxiv.org/pdf/2101.06983.pdf). 
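For reference, the in-batch-negatives objective that both cited papers build on can be sketched in a few lines; the toy embeddings are random and the `0.01` divisor mirrors this class's default `temperature`:

```python
import torch
import torch.nn.functional as F

# Each anchor must score its own positive (the diagonal) above every other in-batch positive.
anchors = F.normalize(torch.randn(8, 32), dim=-1)
positives = F.normalize(torch.randn(8, 32), dim=-1)
scores = anchors @ positives.T / 0.01   # cosine similarities scaled by the temperature
labels = torch.arange(8)                # the i-th positive belongs to the i-th anchor
loss = F.cross_entropy(scores, labels)
```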
+ + Constrastive learning (here our MNRL loss) with in-batch negatives is usually hard to work with large batch sizes due to (GPU) memory limitation. + Even with batch-scaling methods like gradient-scaling, it cannot work either. This is because the in-batch negatives make the data points within + the same batch non-independent and thus the batch cannot be broke down into mini-batches. GradCache is a smart way to solve this problem. + It achieves the goal by dividing the computation into two stages of embedding and loss calculation, which both can be scaled by mini-batches. + As a result, memory of constant size (e.g. that works with batch size = 32) can now process much larger batches (e.g. 65536). + + In detail: + + (1) It first does a quick embedding step without gradients/computation graphs to get all the embeddings; + (2) Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings; + (3) A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain. + + Notes: All steps are done with mini-batches. In the original implementation of GradCache, (2) is not done in mini-batches and + requires a lot memory when batch size large. One drawback is about the speed. GradCache will sacrifice around 20% computation time according to the paper. + + :param model: SentenceTransformer model + :param scale: Output of similarity function is multiplied by scale value + :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1) + + References: + - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf + - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf + + Requirements: + 1. (anchor, positive) pairs or (anchor, positive, negative pairs) + 2. Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss` + + Relations: + - Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes + (and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than + :class:`MultipleNegativesRankingLoss`. + + Inputs: + +---------------------------------------+--------+ + | Texts | Labels | + +=======================================+========+ + | (anchor, positive) pairs | none | + +---------------------------------------+--------+ + | (anchor, positive, negative) triplets | none | + +---------------------------------------+--------+ + + Example: + :: + + from sentence_transformers import SentenceTransformer, losses, InputExample + from torch.utils.data import DataLoader + + model = SentenceTransformer('distilbert-base-uncased') + train_examples = [ + InputExample(texts=['Anchor 1', 'Positive 1']), + InputExample(texts=['Anchor 2', 'Positive 2']), + ] + train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024) # Here we can try much larger batch sizes! 
+ train_loss = losses.CachedGISTEmbedLoss(model=model, mini_batch_size = 32) + model.fit( + [(train_dataloader, train_loss)], + epochs=10, + ) + """ + super(CachedGISTEmbedLoss, self).__init__() + self.model = model + # self.scale = scale + self.guide = guide + self.temperature = temperature + self.similarity_fct = nn.CosineSimilarity(dim=-1) + if not isinstance(model[0], Transformer) or not isinstance(guide[0], Transformer): + raise ValueError( + "Both the training model and the guiding model must be based on the `transformers` architecture." + ) + self.cross_entropy_loss = nn.CrossEntropyLoss() + self.mini_batch_size = mini_batch_size + self.cache: Optional[List[List[Tensor]]] = None + self.random_states: Optional[List[List[RandContext]]] = None + self.show_progress_bar = show_progress_bar + self.must_retokenize = ( + model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length + ) + + def sim_matrix(self, embed1, embed2): + return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0)) + + def embed_minibatch( + self, + sentence_feature: Dict[str, Tensor], + begin: int, + end: int, + with_grad: bool, + copy_random_state: bool, + random_state: Optional[RandContext] = None, + ) -> Tuple[Tensor, Optional[RandContext]]: + """Do forward pass on a minibatch of the input features and return corresponding embeddings.""" + grad_context = nullcontext if with_grad else torch.no_grad + random_state_context = nullcontext() if random_state is None else random_state + sentence_feature_minibatch = {k: v[begin:end] for k, v in sentence_feature.items()} + with random_state_context: + with grad_context(): + random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None + reps = self.model(sentence_feature_minibatch)["sentence_embedding"] # (mbsz, hdim) + # TODO: Compute guide embeddings + with torch.no_grad(): + if self.must_retokenize: + decoded = self.model.tokenizer.batch_decode( + sentence_feature_minibatch["input_ids"], skip_special_tokens=True + ) + sentence_feature_minibatch = self.guide.tokenize(decoded) + sentence_feature_minibatch = { + key: value.to(self.guide.device) for key, value in sentence_feature_minibatch.items() + } + guide_reps = self.guide(sentence_feature_minibatch)["sentence_embedding"] + + return reps, guide_reps, random_state + + def embed_minibatch_iter( + self, + sentence_feature: Dict[str, Tensor], + with_grad: bool, + copy_random_state: bool, + random_states: Optional[List[RandContext]] = None, + ) -> Iterator[Tuple[Tensor, Optional[RandContext]]]: + """Do forward pass on all the minibatches of the input features and yield corresponding embeddings.""" + input_ids: Tensor = sentence_feature["input_ids"] + bsz, _ = input_ids.shape + for i, b in enumerate( + tqdm.trange( + 0, + bsz, + self.mini_batch_size, + desc="Embed mini-batches", + disable=not self.show_progress_bar, + ) + ): + e = b + self.mini_batch_size + reps, guide_reps, random_state = self.embed_minibatch( + sentence_feature=sentence_feature, + begin=b, + end=e, + with_grad=with_grad, + copy_random_state=copy_random_state, + random_state=None if random_states is None else random_states[i], + ) + yield reps, guide_reps, random_state # reps: (mbsz, hdim) + + def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guided: List[List[Tensor]]) -> Tensor: + """Calculate the cross-entropy loss and cache the gradients wrt. 
the embeddings.""" + if len(reps) == 2: + anchor, positive = reps + anchor_guide, positive_guide = reps_guided + negative = None + negative_guide = None + elif len(reps) == 3: + anchor, positive, negative = reps + anchor_guide, positive_guide, negative_guide = reps_guided + else: + raise ValueError("Expected 2 or 3 embeddings, got {}".format(len(reps))) + + # Concatenate the lists into single tensors. + anchor = torch.cat(anchor, dim=0) + positive = torch.cat(positive, dim=0) + anchor_guide = torch.cat(anchor_guide, dim=0) + positive_guide = torch.cat(positive_guide, dim=0) + if negative: + negative = torch.cat(negative, dim=0) + negative_guide = torch.cat(negative_guide, dim=0) + guided_an_sim = self.sim_matrix(anchor_guide, negative_guide) + + # Let's compute the similarity matrices for the combinations of anchor and positive samples. + guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide) + guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide) + guided_pp_sim = self.sim_matrix(positive_guide, positive_guide) + # Define the anchor threshold + guided_sim = guided_ap_sim.diagonal().view(-1, 1) + + labels = torch.arange(anchor.size(0)).long().to(anchor.device) + batch_size = anchor.shape[0] + + losses: List[torch.Tensor] = [] + for b in tqdm.trange( + 0, + batch_size, + self.mini_batch_size, + desc="Preparing caches", + disable=not self.show_progress_bar, + ): + e = b + self.mini_batch_size + # Compute similarity scores for current mini-batch. + ap_sim = self.sim_matrix(anchor[b:e], positive) + aa_sim = self.sim_matrix(anchor[b:e], anchor) + pp_sim = self.sim_matrix(positive[b:e], positive) + + # Find which samples cannot be used as negatives because they are + # more similar to the query than the assigned positive as deemed by the guide model. + # For these samples, we mask them with -inf to basically ignore their contribution to + # the loss. + ap_sim[guided_ap_sim[b:e] > guided_sim[b:e]] = -torch.inf + aa_sim[guided_aa_sim[b:e] > guided_sim[b:e]] = -torch.inf + pp_sim[guided_pp_sim[b:e] > guided_sim[b:e]] = -torch.inf + + scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1) + + # Handle the case where we have a negative sample + if negative is not None: + an_sim = self.sim_matrix(anchor[b:e], negative) + an_sim[guided_an_sim[b:e] > guided_sim[b:e]] = -torch.inf + scores = torch.cat([scores, an_sim], dim=1) + scores = scores / self.temperature + loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size + loss_mbatch.backward() + losses.append(loss_mbatch.detach()) + + loss = sum(losses).requires_grad_() + + self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim) + + return loss + + def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor: + # Step (1): A quick embedding step without gradients/computation graphs to get all the embeddings + reps = [] + reps_guided = [] + self.random_states = [] # Copy random states to guarantee exact reproduction of the embeddings during the second forward pass, i.e. 
step (3) + for sentence_feature in sentence_features: + reps_mbs = [] + reps_guided_mbs = [] + random_state_mbs = [] + for reps_mb, reps_guided_mb, random_state in self.embed_minibatch_iter( + sentence_feature=sentence_feature, + with_grad=False, + copy_random_state=True, + ): + # TODO: reps contains reps_mbs contains reps_mb, reps contains each feature + # anchor + pos + neg, then for each of these contains minibatch + reps_mbs.append(reps_mb.detach().requires_grad_()) + reps_guided_mbs.append(reps_guided_mb.detach()) # does not requires gradient + random_state_mbs.append(random_state) + reps.append(reps_mbs) + reps_guided.append(reps_guided_mbs) + self.random_states.append(random_state_mbs) + + # Step (2): Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings + loss = self.calculate_loss_and_cache_gradients(reps, reps_guided) + print(loss) + # Step (3): A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain + loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self)) + return loss + + def get_config_dict(self): + return { + "guide": self.guide, + "temperature": self.temperature, + } From f4d7f38c2a167027906c9b705aa6050ae5e8e30f Mon Sep 17 00:00:00 2001 From: JacksonCakes Date: Sun, 14 Apr 2024 00:32:22 +0800 Subject: [PATCH 4/9] Add docstring --- .../losses/CachedGISTEmbedLoss.py | 44 ++++++------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index a791fe736..ccc79f4d7 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -6,7 +6,6 @@ from torch.utils.checkpoint import get_device_states, set_device_states from typing import Iterable, Dict, Iterator, List, Optional, Tuple from sentence_transformers import SentenceTransformer -from sentence_transformers import util import tqdm from sentence_transformers.models import Transformer @@ -63,44 +62,29 @@ def __init__( model: SentenceTransformer, guide: SentenceTransformer, temperature: float = 0.01, - similarity_fct: callable[[Tensor, Tensor], Tensor] = util.cos_sim, mini_batch_size: int = 32, show_progress_bar: bool = False, ): """ - Boosted version of MultipleNegativesRankingLoss (https://arxiv.org/pdf/1705.00652.pdf) by GradCache (https://arxiv.org/pdf/2101.06983.pdf). - - Constrastive learning (here our MNRL loss) with in-batch negatives is usually hard to work with large batch sizes due to (GPU) memory limitation. - Even with batch-scaling methods like gradient-scaling, it cannot work either. This is because the in-batch negatives make the data points within - the same batch non-independent and thus the batch cannot be broke down into mini-batches. GradCache is a smart way to solve this problem. - It achieves the goal by dividing the computation into two stages of embedding and loss calculation, which both can be scaled by mini-batches. - As a result, memory of constant size (e.g. that works with batch size = 32) can now process much larger batches (e.g. 65536). - - In detail: - - (1) It first does a quick embedding step without gradients/computation graphs to get all the embeddings; - (2) Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings; - (3) A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain. 
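A compact sketch of steps (1)–(3) with a toy linear encoder and a stand-in objective (not the classes from this patch):

```python
import torch
import torch.nn.functional as F

encoder = torch.nn.Linear(16, 8)   # stand-in for the SentenceTransformer
batch = torch.randn(64, 16)
mini = 16

# (1) embed every mini-batch without building a computation graph
with torch.no_grad():
    reps = [encoder(batch[i:i + mini]) for i in range(0, 64, mini)]
reps = [r.detach().requires_grad_() for r in reps]

# (2) compute the full-batch loss on the detached embeddings and cache d(loss)/d(reps)
all_reps = torch.cat(reps)
scores = all_reps @ all_reps.T          # stand-in for the in-batch similarity matrix
loss = F.cross_entropy(scores, torch.arange(64))
loss.backward()
cached_grads = [r.grad for r in reps]

# (3) re-embed each mini-batch with gradients and chain in the cached gradients
for i, grad in zip(range(0, 64, mini), cached_grads):
    reps_mb = encoder(batch[i:i + mini])
    surrogate = torch.dot(reps_mb.flatten(), grad.flatten())
    surrogate.backward()                # accumulates gradients into the encoder parameters
```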
- - Notes: All steps are done with mini-batches. In the original implementation of GradCache, (2) is not done in mini-batches and - requires a lot memory when batch size large. One drawback is about the speed. GradCache will sacrifice around 20% computation time according to the paper. + This loss is a combination of GISTEmbedLoss and CachedMultipleNegativeRankingLoss. + Typically, MNR Loss requires a larger batch size for better performance. + GISTEmbedLoss yields stronger training signals than MNR Loss due to the use of a guide model for in-batch negative sample selection. Meanwhile, CachedMNR Loss allows for scaling of the batch size by dividing the computation into two stages of embedding and loss calculation, which both can be scaled by mini-batches(https://arxiv.org/pdf/2101.06983.pdf). By combining the guided selection from GISTEmbedLoss and Gradient Cache by CachedMNRLoss, it is possible to reduce memory usage while maintaining performance levels comparable to those of GISTEmbedLoss. :param model: SentenceTransformer model - :param scale: Output of similarity function is multiplied by scale value - :param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1) + :param guide: SentenceTransformer model to guide the in-batch negative sample selection. + :param temperature: Temperature parameter to scale the cosine similarities. References: - Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf - Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf + - GISTEmbed: Guided In-sample Selection of Training Negatives for Text Embedding Fine-tuning https://arxiv.org/abs/2402.16829 Requirements: 1. (anchor, positive) pairs or (anchor, positive, negative pairs) 2. Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss` Relations: - - Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes - (and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than - :class:`MultipleNegativesRankingLoss`. + - Equivalent to :class:`GISTEmbedLoss`, but with caching that allows for much higher batch sizes Inputs: +---------------------------------------+--------+ @@ -118,12 +102,14 @@ def __init__( from torch.utils.data import DataLoader model = SentenceTransformer('distilbert-base-uncased') + guide = SentenceTransformer('avsolatorio/GIST-small-Embedding-v0') + train_examples = [ InputExample(texts=['Anchor 1', 'Positive 1']), InputExample(texts=['Anchor 2', 'Positive 2']), ] train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024) # Here we can try much larger batch sizes! 
- train_loss = losses.CachedGISTEmbedLoss(model=model, mini_batch_size = 32) + train_loss = losses.CachedGISTEmbedLoss(model=model, mini_batch_size=32, guide=guide) model.fit( [(train_dataloader, train_loss)], epochs=10, @@ -131,7 +117,6 @@ def __init__( """ super(CachedGISTEmbedLoss, self).__init__() self.model = model - # self.scale = scale self.guide = guide self.temperature = temperature self.similarity_fct = nn.CosineSimilarity(dim=-1) @@ -168,7 +153,6 @@ def embed_minibatch( with grad_context(): random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None reps = self.model(sentence_feature_minibatch)["sentence_embedding"] # (mbsz, hdim) - # TODO: Compute guide embeddings with torch.no_grad(): if self.must_retokenize: decoded = self.model.tokenizer.batch_decode( @@ -225,11 +209,11 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid else: raise ValueError("Expected 2 or 3 embeddings, got {}".format(len(reps))) - # Concatenate the lists into single tensors. anchor = torch.cat(anchor, dim=0) positive = torch.cat(positive, dim=0) anchor_guide = torch.cat(anchor_guide, dim=0) positive_guide = torch.cat(positive_guide, dim=0) + # Handle the case where we have a negative sample if negative: negative = torch.cat(negative, dim=0) negative_guide = torch.cat(negative_guide, dim=0) @@ -255,7 +239,8 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid ): e = b + self.mini_batch_size # Compute similarity scores for current mini-batch. - ap_sim = self.sim_matrix(anchor[b:e], positive) + # anchor (mbsz,hdim), positive (bsz,hdim) + ap_sim = self.sim_matrix(anchor[b:e], positive) # (mbsz,bsz) aa_sim = self.sim_matrix(anchor[b:e], anchor) pp_sim = self.sim_matrix(positive[b:e], positive) @@ -299,8 +284,6 @@ def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor with_grad=False, copy_random_state=True, ): - # TODO: reps contains reps_mbs contains reps_mb, reps contains each feature - # anchor + pos + neg, then for each of these contains minibatch reps_mbs.append(reps_mb.detach().requires_grad_()) reps_guided_mbs.append(reps_guided_mb.detach()) # does not requires gradient random_state_mbs.append(random_state) @@ -310,7 +293,6 @@ def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor # Step (2): Calculate the loss, backward up to the embeddings and cache the gradients wrt. 
to the embeddings loss = self.calculate_loss_and_cache_gradients(reps, reps_guided) - print(loss) # Step (3): A 2nd embedding step with gradients/computation graphs and connect the cached gradients into the backward chain loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self)) return loss From 3208e611247f920ffa18262e1d1839a8f85af340 Mon Sep 17 00:00:00 2001 From: JacksonCakes Date: Mon, 15 Apr 2024 16:15:42 +0800 Subject: [PATCH 5/9] Update guided similarity computation in mini-batch --- .../losses/CachedGISTEmbedLoss.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index ccc79f4d7..9e0e1e170 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -217,14 +217,6 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid if negative: negative = torch.cat(negative, dim=0) negative_guide = torch.cat(negative_guide, dim=0) - guided_an_sim = self.sim_matrix(anchor_guide, negative_guide) - - # Let's compute the similarity matrices for the combinations of anchor and positive samples. - guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide) - guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide) - guided_pp_sim = self.sim_matrix(positive_guide, positive_guide) - # Define the anchor threshold - guided_sim = guided_ap_sim.diagonal().view(-1, 1) labels = torch.arange(anchor.size(0)).long().to(anchor.device) batch_size = anchor.shape[0] @@ -238,6 +230,13 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid disable=not self.show_progress_bar, ): e = b + self.mini_batch_size + # Let's compute the similarity matrices for the combinations of anchor and positive samples. + guided_ap_sim = self.sim_matrix(anchor_guide[b:e], positive_guide) + guided_aa_sim = self.sim_matrix(anchor_guide[b:e], anchor_guide) + guided_pp_sim = self.sim_matrix(positive_guide[b:e], positive_guide) + # Define the anchor threshold + guided_sim = guided_ap_sim.diagonal().view(-1, 1) + # Compute similarity scores for current mini-batch. # anchor (mbsz,hdim), positive (bsz,hdim) ap_sim = self.sim_matrix(anchor[b:e], positive) # (mbsz,bsz) @@ -248,16 +247,17 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid # more similar to the query than the assigned positive as deemed by the guide model. # For these samples, we mask them with -inf to basically ignore their contribution to # the loss. 
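A toy illustration of the guided filtering described in the comment above (made-up similarity values; in the patch the anchor rows form a mini-batch starting at index `b`, which is why a later commit reads the true-pair scores with `diagonal(offset=b)`):

```python
import torch

# Guide-model similarities of 2 anchors against 3 in-batch positives.
guided_ap_sim = torch.tensor([[0.90, 0.95, 0.20],
                              [0.30, 0.80, 0.85]])
guided_sim = guided_ap_sim.diagonal().view(-1, 1)  # guide score of each true (anchor, positive) pair

# Trained-model similarities for the same pairs.
ap_sim = torch.tensor([[10.0, 12.0, 1.0],
                       [2.0, 9.0, 11.0]])

# Candidates the guide rates above the true pair are likely false negatives: mask them out.
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
# Row 0 masks candidate 1 (0.95 > 0.90); row 1 masks candidate 2 (0.85 > 0.80).
```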
- ap_sim[guided_ap_sim[b:e] > guided_sim[b:e]] = -torch.inf - aa_sim[guided_aa_sim[b:e] > guided_sim[b:e]] = -torch.inf - pp_sim[guided_pp_sim[b:e] > guided_sim[b:e]] = -torch.inf + ap_sim[guided_ap_sim > guided_sim] = -torch.inf + aa_sim[guided_aa_sim > guided_sim] = -torch.inf + pp_sim[guided_pp_sim > guided_sim] = -torch.inf scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1) # Handle the case where we have a negative sample if negative is not None: + guided_an_sim = self.sim_matrix(anchor_guide[b:e], negative_guide) an_sim = self.sim_matrix(anchor[b:e], negative) - an_sim[guided_an_sim[b:e] > guided_sim[b:e]] = -torch.inf + an_sim[guided_an_sim > guided_sim] = -torch.inf scores = torch.cat([scores, an_sim], dim=1) scores = scores / self.temperature loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size From 3215c06f96bd7ff4f6e06cbc7d7478e218e29bcd Mon Sep 17 00:00:00 2001 From: JacksonCakes Date: Mon, 15 Apr 2024 18:10:10 +0800 Subject: [PATCH 6/9] Fix guiding mask by adding offset --- sentence_transformers/losses/CachedGISTEmbedLoss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index 9e0e1e170..c1175baed 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -235,7 +235,7 @@ def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]], reps_guid guided_aa_sim = self.sim_matrix(anchor_guide[b:e], anchor_guide) guided_pp_sim = self.sim_matrix(positive_guide[b:e], positive_guide) # Define the anchor threshold - guided_sim = guided_ap_sim.diagonal().view(-1, 1) + guided_sim = guided_ap_sim.diagonal(offset=b).view(-1, 1) # Compute similarity scores for current mini-batch. # anchor (mbsz,hdim), positive (bsz,hdim) From 6ccf388180f9899f0003836cdfc17f4bf3af5101 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 16 Apr 2024 10:53:39 +0200 Subject: [PATCH 7/9] Write docstring on multiple lines --- sentence_transformers/losses/CachedGISTEmbedLoss.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index c1175baed..f8623e325 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -66,9 +66,16 @@ def __init__( show_progress_bar: bool = False, ): """ - This loss is a combination of GISTEmbedLoss and CachedMultipleNegativeRankingLoss. - Typically, MNR Loss requires a larger batch size for better performance. - GISTEmbedLoss yields stronger training signals than MNR Loss due to the use of a guide model for in-batch negative sample selection. Meanwhile, CachedMNR Loss allows for scaling of the batch size by dividing the computation into two stages of embedding and loss calculation, which both can be scaled by mini-batches(https://arxiv.org/pdf/2101.06983.pdf). By combining the guided selection from GISTEmbedLoss and Gradient Cache by CachedMNRLoss, it is possible to reduce memory usage while maintaining performance levels comparable to those of GISTEmbedLoss. + This loss is a combination of :class:`GISTEmbedLoss` and :class:`CachedMultipleNegativesRankingLoss`. + Typically, :class:`MultipleNegativesRankingLoss` requires a larger batch size for better performance. 
+ :class:`GISTEmbedLoss` yields stronger training signals than :class:`MultipleNegativesRankingLoss` due to the + use of a guide model for in-batch negative sample selection. Meanwhile, :class:`CachedMultipleNegativesRankingLoss` + allows for scaling of the batch size by dividing the computation into two stages of embedding and loss + calculation, which both can be scaled by mini-batches (https://arxiv.org/pdf/2101.06983.pdf). + + By combining the guided selection from :class:`GISTEmbedLoss` and Gradient Cache from + :class:`CachedMultipleNegativesRankingLoss`, it is possible to reduce memory usage while maintaining performance + levels comparable to those of :class:`GISTEmbedLoss`. :param model: SentenceTransformer model :param guide: SentenceTransformer model to guide the in-batch negative sample selection. From 6a6455d8c6705ad63be03754ebbb87005d0dcc8a Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 16 Apr 2024 10:53:48 +0200 Subject: [PATCH 8/9] Add CachedGIST to loss API --- docs/package_reference/losses.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/package_reference/losses.md b/docs/package_reference/losses.md index 4da18c5ad..65475427d 100644 --- a/docs/package_reference/losses.md +++ b/docs/package_reference/losses.md @@ -80,6 +80,11 @@ This allows our network to be fine-tuned to recognize the similarity of sentence .. autoclass:: sentence_transformers.losses.GISTEmbedLoss ``` +## CachedGISTEmbedLoss +```eval_rst +.. autoclass:: sentence_transformers.losses.CachedGISTEmbedLoss +``` + ## MSELoss ```eval_rst .. autoclass:: sentence_transformers.losses.MSELoss From fccc92a28018041c19487ac412364ec1510be3b9 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Tue, 16 Apr 2024 10:53:57 +0200 Subject: [PATCH 9/9] Add CachedGIST to Loss Overview --- docs/training/loss_overview.md | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/training/loss_overview.md b/docs/training/loss_overview.md index 9a7951e9d..4a756a818 100644 --- a/docs/training/loss_overview.md +++ b/docs/training/loss_overview.md @@ -4,17 +4,17 @@ Loss functions play a critical role in the performance of your fine-tuned model. **Note**: you can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, `(sentence_A, sentence_B) pairs` with `class` labels can be converted into `(anchor, positive, negative) triplets` by sampling sentences with the same or different classes. -| Texts | Labels | Appropriate Loss Functions | -|-----------------------------------------------|--------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `single sentences` | `class` | `BatchAllTripletLoss`
`BatchHardSoftMarginTripletLoss`<br>`BatchHardTripletLoss`<br>`BatchSemiHardTripletLoss` |
| `single sentences` | `none` | `ContrastiveTensionLoss`<br>`DenoisingAutoEncoderLoss` |
| `(anchor, anchor) pairs` | `none` | `ContrastiveTensionLossInBatchNegatives` |
| `(damaged_sentence, original_sentence) pairs` | `none` | `DenoisingAutoEncoderLoss` |
| `(sentence_A, sentence_B) pairs` | `class` | `SoftmaxLoss` |
| `(anchor, positive) pairs` | `none` | `CachedMultipleNegativesRankingLoss`<br>`MultipleNegativesRankingLoss`<br>`MultipleNegativesSymmetricRankingLoss`<br>`MegaBatchMarginLoss`<br>`GISTEmbedLoss` |
| `(anchor, positive/negative) pairs` | `1 if positive, 0 if negative` | `ContrastiveLoss`<br>`OnlineContrastiveLoss` |
| `(sentence_A, sentence_B) pairs` | `float similarity score` | `CoSENTLoss`<br>`AnglELoss`<br>`CosineSimilarityLoss` |
| `(anchor, positive, negative) triplets` | `none` | `CachedMultipleNegativesRankingLoss`<br>`MultipleNegativesRankingLoss`<br>`TripletLoss`<br>`GISTEmbedLoss` |
| Texts | Labels | Appropriate Loss Functions |
|-----------------------------------------------|--------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `single sentences` | `class` | `BatchAllTripletLoss`<br>`BatchHardSoftMarginTripletLoss`<br>`BatchHardTripletLoss`<br>`BatchSemiHardTripletLoss` |
| `single sentences` | `none` | `ContrastiveTensionLoss`<br>`DenoisingAutoEncoderLoss` |
| `(anchor, anchor) pairs` | `none` | `ContrastiveTensionLossInBatchNegatives` |
| `(damaged_sentence, original_sentence) pairs` | `none` | `DenoisingAutoEncoderLoss` |
| `(sentence_A, sentence_B) pairs` | `class` | `SoftmaxLoss` |
| `(anchor, positive) pairs` | `none` | `CachedMultipleNegativesRankingLoss`<br>`MultipleNegativesRankingLoss`<br>`MultipleNegativesSymmetricRankingLoss`<br>`MegaBatchMarginLoss`<br>`CachedGISTEmbedLoss`<br>`GISTEmbedLoss` |
| `(anchor, positive/negative) pairs` | `1 if positive, 0 if negative` | `ContrastiveLoss`<br>`OnlineContrastiveLoss` |
| `(sentence_A, sentence_B) pairs` | `float similarity score` | `CoSENTLoss`<br>`AnglELoss`<br>`CosineSimilarityLoss` |
| `(anchor, positive, negative) triplets` | `none` | `CachedMultipleNegativesRankingLoss`<br>`MultipleNegativesRankingLoss`<br>`TripletLoss`<br>`CachedGISTEmbedLoss`<br>`GISTEmbedLoss` |

## Loss modifiers
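To close out, a minimal end-to-end sketch of the new loss, assembled from the usage example in this PR's docstring (the guide checkpoint name is the one used there):

```python
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader

model = SentenceTransformer("distilbert-base-uncased")
guide = SentenceTransformer("avsolatorio/GIST-small-Embedding-v0")

train_examples = [
    InputExample(texts=["Anchor 1", "Positive 1"]),
    InputExample(texts=["Anchor 2", "Positive 2"]),
]
# The DataLoader batch size controls how many in-batch negatives each example sees;
# mini_batch_size only bounds peak memory during the cached forward/backward passes.
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024)
train_loss = losses.CachedGISTEmbedLoss(model=model, guide=guide, mini_batch_size=32)

model.fit([(train_dataloader, train_loss)], epochs=10)
```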