From a1b356d45c5fd5d28743f3fb40427f37f3397b56 Mon Sep 17 00:00:00 2001 From: expnoob Date: Thu, 9 Mar 2023 05:51:58 +0530 Subject: [PATCH] added circle loss --- .gitignore | 1 + docs/source/api/index.rst | 1 + quaterion/loss/__init__.py | 1 + quaterion/loss/circle_loss.py | 58 +++++++++++++++++++++++++++ tests/eval/losses/test_circle_loss.py | 25 ++++++++++++ 5 files changed, 86 insertions(+) create mode 100644 quaterion/loss/circle_loss.py create mode 100644 tests/eval/losses/test_circle_loss.py diff --git a/.gitignore b/.gitignore index b4fcd0cf..c040efc9 100644 --- a/.gitignore +++ b/.gitignore @@ -146,3 +146,4 @@ lightning_logs/ # dataset and model downloads torchvision/ cache_dir/ +.vscode/ \ No newline at end of file diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 1e9a9542..b038aff1 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -127,6 +127,7 @@ Implementations ~multiple_negatives_ranking_loss.MultipleNegativesRankingLoss ~softmax_loss.SoftmaxLoss ~triplet_loss.TripletLoss + ~circle_loss.CircleLoss Extras ++++++ diff --git a/quaterion/loss/__init__.py b/quaterion/loss/__init__.py index 1f9ef757..9541af6e 100644 --- a/quaterion/loss/__init__.py +++ b/quaterion/loss/__init__.py @@ -7,3 +7,4 @@ from quaterion.loss.similarity_loss import SimilarityLoss from quaterion.loss.softmax_loss import SoftmaxLoss from quaterion.loss.triplet_loss import TripletLoss +from quaterion.loss.circle_loss import CircleLoss diff --git a/quaterion/loss/circle_loss.py b/quaterion/loss/circle_loss.py new file mode 100644 index 00000000..255f2527 --- /dev/null +++ b/quaterion/loss/circle_loss.py @@ -0,0 +1,58 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import LongTensor, Tensor + +from quaterion.loss.group_loss import GroupLoss + + +class CircleLoss(GroupLoss): + """Implements Circle Loss as defined in https://arxiv.org/abs/2002.10857. + + Args: + margin: Margin value to push negative examples. + scale_factor: scale factor γ determines the largest scale of each similarity score. + """ + + def __init__(self, margin: Optional[float], scale_factor: Optional[float], distance_metric_name: Optional[Distance] = Distance.COSINE): + super(GroupLoss, self).__init__() + self.margin = margin + self.scale_factor = scale_factor + self.op = 1 + self._margin + self.on = -self._margin + self.delta_positive = 1 - self._margin + self.delta_negative = self._margin + + def forward( + self, + embeddings: Tensor, + groups: LongTensor, + ) -> Tensor: + """Compute loss value. + + Args: + embeddings: shape: (batch_size, vector_length) - Batch of embeddings. + groups: shape: (batch_size,) - Batch of labels associated with `embeddings` + + Returns: + Tensor: Scalar loss value. + """ + # Shape: (batch_size, batch_size) + dists = self.distance_metric.distance_matrix(embeddings) + # Calculate loss for all possible triplets first, then filter by group mask + # Shape: (batch_size, batch_size, 1) + sp = dists.unsqueeze(2) + # Shape: (batch_size, 1, batch_size) + sn = dists.unsqueeze(1) + # get alpha-positive and alpha-negative weights as described in https://arxiv.org/abs/2002.10857. + ap = torch.clamp_min(self.op + sp.detach(), min=0) + an = torch.clamp_min(self.on + sn.detach(), min=0) + + exp_p = - ap * self.scale_factor * (sp - self.delta_positive) + exp_n = an * self.scale_factor * (sn-self.delta_negative) + + circle_loss = F.softplus(torch.logsumexp(exp_n, dim=0) + torch.logsumexp(exp_p, dim=0)) + + return circle_loss diff --git a/tests/eval/losses/test_circle_loss.py b/tests/eval/losses/test_circle_loss.py new file mode 100644 index 00000000..7593fb26 --- /dev/null +++ b/tests/eval/losses/test_circle_loss.py @@ -0,0 +1,25 @@ +import torch + +from quaterion.loss import CircleLoss + + +class TestCircleLoss: + embeddings = torch.Tensor( + [ + [0.0, -1.0, 0.5], + [0.1, 2.0, 0.5], + [0.0, 0.3, 0.2], + [1.0, 0.0, 0.9], + [1.2, -1.2, 0.01], + [-0.7, 0.0, 1.5], + ] + ) + + groups = torch.LongTensor([1, 2, 3, 3, 2, 1]) + + def test_batch_all(self): + loss = CircleLoss(margin=0.5, scale_factor = 2) + + loss_res = loss.forward(embeddings=self.embeddings, groups=self.groups) + + assert loss_res.shape == torch.Size([])