Skip to content

Commit

Permalink
Applying Center Loss (#213)
Browse files Browse the repository at this point in the history
* Applying Center Loss

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed minor typos and errors + Center loss implementation

* sphinx version change

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
parthkl021 and pre-commit-ci[bot] authored Dec 29, 2023
1 parent 2371ddf commit db4f455
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/generate_docs_netlify.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ poetry build -f wheel
pip install dist/$(ls -1 dist | grep .whl)
pip install pytorch-metric-learning==1.3.2

pip install sphinx>=5.0.1
pip install sphinx==6.1.3
pip install "git+https://github.com/qdrant/qdrant_sphinx_theme.git@master#egg=qdrant-sphinx-theme"

sphinx-apidoc --force --separate --no-toc -o docs/source quaterion
Expand Down
3 changes: 2 additions & 1 deletion docs/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ Implementations
~softmax_loss.SoftmaxLoss
~triplet_loss.TripletLoss
~circle_loss.CircleLoss
~fastap_loss.FastAPLoss
~fast_ap_loss.FastAPLoss
~cos_face_loss.CosFaceLoss
~center_loss.CenterLoss

Extras
++++++
Expand Down
7 changes: 7 additions & 0 deletions docs/source/quaterion.loss.center_loss.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quaterion.loss.center\_loss module
==================================

.. automodule:: quaterion.loss.center_loss
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/quaterion.loss.circle_loss.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quaterion.loss.circle\_loss module
==================================

.. automodule:: quaterion.loss.circle_loss
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/quaterion.loss.cos_face_loss.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quaterion.loss.cos\_face\_loss module
=====================================

.. automodule:: quaterion.loss.cos_face_loss
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/quaterion.loss.fast_ap_loss.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quaterion.loss.fast\_ap\_loss module
====================================

.. automodule:: quaterion.loss.fast_ap_loss
:members:
:undoc-members:
:show-inheritance:
4 changes: 4 additions & 0 deletions docs/source/quaterion.loss.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ Submodules
:maxdepth: 4

quaterion.loss.arcface_loss
quaterion.loss.center_loss
quaterion.loss.circle_loss
quaterion.loss.contrastive_loss
quaterion.loss.cos_face_loss
quaterion.loss.fast_ap_loss
quaterion.loss.group_loss
quaterion.loss.multiple_negatives_ranking_loss
quaterion.loss.online_contrastive_loss
Expand Down
2 changes: 1 addition & 1 deletion docs/source/tutorials/triplet_loss_trick.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Triplet Loss: Vector Collapse Prevention
============================
========================================

Triplet Loss is one of the most widely known loss functions in similarity learning.
If you want to deep-dive into the details of its implementations and advantages,
Expand Down
1 change: 1 addition & 0 deletions quaterion/loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from quaterion.loss.arcface_loss import ArcFaceLoss
from quaterion.loss.center_loss import CenterLoss
from quaterion.loss.circle_loss import CircleLoss
from quaterion.loss.contrastive_loss import ContrastiveLoss
from quaterion.loss.cos_face_loss import CosFaceLoss
Expand Down
57 changes: 57 additions & 0 deletions quaterion/loss/center_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import LongTensor, Tensor

from quaterion.loss.group_loss import GroupLoss
from quaterion.utils import l2_norm


class CenterLoss(GroupLoss):
"""
Center Loss as defined in the paper "A Discriminative Feature Learning Approach
for Deep Face Recognition" (http://ydwen.github.io/papers/WenECCV16.pdf)
It aims to minimize the intra-class variations while keeping the features of
different classes separable.
Args:
embedding_size: Output dimension of the encoder.
num_groups: Number of groups (classes) in the dataset.
lambda_c: A regularization parameter that controls the contribution of the center loss.
"""

def __init__(
self, embedding_size: int, num_groups: int, lambda_c: Optional[float] = 0.5
):
super(GroupLoss, self).__init__()
self.num_groups = num_groups
self.centers = nn.Parameter(torch.randn(num_groups, embedding_size))
self.lambda_c = lambda_c

nn.init.xavier_uniform_(self.centers)

def forward(self, embeddings: Tensor, groups: LongTensor) -> Tensor:
"""
Compute the Center Loss value.
Args:
embeddings: shape (batch_size, vector_length) - Output embeddings from the encoder.
groups: shape (batch_size,) - Group (class) ids associated with embeddings.
Returns:
Tensor: loss value.
"""
embeddings = l2_norm(embeddings, 1)

# Gather the center for each embedding's corresponding group
centers_batch = self.centers.index_select(0, groups)

# Calculate the distance between embeddings and their respective class centers
loss = F.mse_loss(embeddings, centers_batch)

# Scale the loss by the regularization parameter
loss *= self.lambda_c

return loss
32 changes: 32 additions & 0 deletions tests/eval/losses/test_center_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch

from quaterion.loss import CenterLoss


class TestCenterLoss:
embeddings = torch.Tensor(
[
[0.0, -1.0, 0.5],
[0.1, 2.0, 0.5],
[0.0, 0.3, 0.2],
[1.0, 0.0, 0.9],
[1.2, -1.2, 0.01],
[-0.7, 0.0, 1.5],
]
)
groups = torch.LongTensor([1, 2, 0, 0, 2, 1])

def test_batch_all(self):
# Initialize the CenterLoss
loss = CenterLoss(embedding_size=self.embeddings.size()[1], num_groups=3)

# Calculate the loss
loss_res = loss.forward(embeddings=self.embeddings, groups=self.groups)

# Assertions to check the output shape and type
assert isinstance(
loss_res, torch.Tensor
), "Loss result should be a torch.Tensor"
assert loss_res.shape == torch.Size(
[]
), "Loss result should be a scalar (0-dimension tensor)"

0 comments on commit db4f455

Please sign in to comment.