Skip to content

Commit

Permalink
added circle loss
Browse files Browse the repository at this point in the history
  • Loading branch information
SudeepRed committed Mar 9, 2023
1 parent 6256088 commit a1b356d
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ lightning_logs/
# dataset and model downloads
torchvision/
cache_dir/
.vscode/
1 change: 1 addition & 0 deletions docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ Implementations
~multiple_negatives_ranking_loss.MultipleNegativesRankingLoss
~softmax_loss.SoftmaxLoss
~triplet_loss.TripletLoss
~circle_loss.CircleLoss

Extras
++++++
Expand Down
1 change: 1 addition & 0 deletions quaterion/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from quaterion.loss.similarity_loss import SimilarityLoss
from quaterion.loss.softmax_loss import SoftmaxLoss
from quaterion.loss.triplet_loss import TripletLoss
from quaterion.loss.circle_loss import CircleLoss
58 changes: 58 additions & 0 deletions quaterion/loss/circle_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.loss.group_loss import GroupLoss


class CircleLoss(GroupLoss):
"""Implements Circle Loss as defined in https://arxiv.org/abs/2002.10857.
Args:
margin: Margin value to push negative examples.
scale_factor: scale factor γ determines the largest scale of each similarity score.
"""

def __init__(self, margin: Optional[float], scale_factor: Optional[float], distance_metric_name: Optional[Distance] = Distance.COSINE):
super(GroupLoss, self).__init__()
self.margin = margin
self.scale_factor = scale_factor
self.op = 1 + self._margin
self.on = -self._margin
self.delta_positive = 1 - self._margin
self.delta_negative = self._margin

def forward(
self,
embeddings: Tensor,
groups: LongTensor,
) -> Tensor:
"""Compute loss value.
Args:
embeddings: shape: (batch_size, vector_length) - Batch of embeddings.
groups: shape: (batch_size,) - Batch of labels associated with `embeddings`
Returns:
Tensor: Scalar loss value.
"""
# Shape: (batch_size, batch_size)
dists = self.distance_metric.distance_matrix(embeddings)
# Calculate loss for all possible triplets first, then filter by group mask
# Shape: (batch_size, batch_size, 1)
sp = dists.unsqueeze(2)
# Shape: (batch_size, 1, batch_size)
sn = dists.unsqueeze(1)
# get alpha-positive and alpha-negative weights as described in https://arxiv.org/abs/2002.10857.
ap = torch.clamp_min(self.op + sp.detach(), min=0)
an = torch.clamp_min(self.on + sn.detach(), min=0)

exp_p = - ap * self.scale_factor * (sp - self.delta_positive)
exp_n = an * self.scale_factor * (sn-self.delta_negative)

circle_loss = F.softplus(torch.logsumexp(exp_n, dim=0) + torch.logsumexp(exp_p, dim=0))

return circle_loss
25 changes: 25 additions & 0 deletions tests/eval/losses/test_circle_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch

from quaterion.loss import CircleLoss


class TestCircleLoss:
embeddings = torch.Tensor(
[
[0.0, -1.0, 0.5],
[0.1, 2.0, 0.5],
[0.0, 0.3, 0.2],
[1.0, 0.0, 0.9],
[1.2, -1.2, 0.01],
[-0.7, 0.0, 1.5],
]
)

groups = torch.LongTensor([1, 2, 3, 3, 2, 1])

def test_batch_all(self):
loss = CircleLoss(margin=0.5, scale_factor = 2)

loss_res = loss.forward(embeddings=self.embeddings, groups=self.groups)

assert loss_res.shape == torch.Size([])

0 comments on commit a1b356d

Please sign in to comment.