[feat] Implementing SmoothAP loss #721

Open · wants to merge 8 commits into base: dev

Binary file added docs/imgs/smooth_ap_approx_equation.png

Binary file added docs/imgs/smooth_ap_loss_equation.png

Binary file added docs/imgs/smooth_ap_sigmoid_equation.png

31 changes: 31 additions & 0 deletions docs/losses.md
@@ -1087,6 +1087,37 @@ losses.SignalToNoiseRatioContrastiveLoss(pos_margin=0, neg_margin=1, **kwargs):
* **pos_loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
* **neg_loss**: The loss per negative pair in the batch. Reduction type is ```"neg_pair"```.

## SmoothAPLoss
[Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163){target=_blank}

```python
losses.SmoothAPLoss(
temperature=0.01,
**kwargs
)
```

**Equations**:

![smooth_ap_loss_equation1](imgs/smooth_ap_sigmoid_equation.png){: style="height:100px"}
![smooth_ap_loss_equation2](imgs/smooth_ap_approx_equation.png){: style="height:100px"}
![smooth_ap_loss_equation3](imgs/smooth_ap_loss_equation.png){: style="height:100px"}
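
In text form, the three equations above correspond roughly to the following (notation paraphrased from the paper, and matching the implementation below): the temperature-controlled sigmoid used in place of the indicator function,

$$\mathcal{G}(x;\tau)=\frac{1}{1+e^{-x/\tau}}$$

the smoothed ranking of instance $i$ within a set $S$ of similarity scores $s_j$,

$$\mathcal{R}(i,S)\approx 1+\sum_{j\in S,\,j\neq i}\mathcal{G}(s_j-s_i;\tau)$$

and the resulting Smooth-AP loss, averaged over queries $q$ with positive set $\mathcal{P}_q$ and retrieval set $\Omega_q$,

$$\mathcal{L}_{AP}=\frac{1}{m}\sum_{q=1}^{m}\left(1-\frac{1}{|\mathcal{P}_q|}\sum_{i\in\mathcal{P}_q}\frac{\mathcal{R}(i,\mathcal{P}_q)}{\mathcal{R}(i,\Omega_q)}\right)$$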


**Parameters**:

* **temperature**: The temperature used to scale the inputs to the sigmoid function. This is denoted by $\tau$ in the first and second equations.


**Other info**:

* The loss requires the same number of elements for each class in the batch labels, and elements of the same class must be grouped together. An example of valid labels is: `[1, 1, 2, 2, 3, 3]`. An example of invalid labels is `[1, 1, 1, 2, 2, 3, 3]`, because class `1` has `3` elements while the other classes have `2`. Such batches can be produced with `samplers.MPerClassSampler` by setting its `m` and `batch_size` arguments.

**Default distance**:

- [```CosineSimilarity()```](distances.md#cosinesimilarity)
- This is the only compatible distance.
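
**Example usage** (a minimal sketch; the embedding size is arbitrary and the label layout follows the requirement described above):

```python
import torch
from pytorch_metric_learning import losses

loss_func = losses.SmoothAPLoss(temperature=0.01)

# Toy batch: 3 classes, 2 embeddings per class, same-class embeddings grouped together
embeddings = torch.randn(6, 128)
labels = torch.tensor([1, 1, 2, 2, 3, 3])
loss = loss_func(embeddings, labels)
```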

## SoftTripleLoss
[SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf){target=_blank}
```python
@@ -9,7 +9,7 @@ def __init__(self, **kwargs):
assert self.is_inverted

def compute_mat(self, query_emb, ref_emb):
- return torch.matmul(query_emb, ref_emb.t())
+ return torch.matmul(query_emb, ref_emb.transpose(-1, -2))

def pairwise_distance(self, query_emb, ref_emb):
return torch.sum(query_emb * ref_emb, dim=1)
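
For illustration (a sketch, not part of the diff): `transpose(-1, -2)` behaves like `.t()` for 2-D inputs but also accepts the batched, 3-D tensors that `SmoothAPLoss` passes to `compute_mat`, where `.t()` would raise an error.

```python
import torch

emb_2d = torch.randn(4, 8)     # (batch, dim)
emb_3d = torch.randn(3, 4, 8)  # (classes, samples_per_class, dim)

# Same result as emb_2d @ emb_2d.t()
sims_2d = torch.matmul(emb_2d, emb_2d.transpose(-1, -2))  # shape (4, 4)
# Batched similarity matrices; .t() only supports tensors with <= 2 dimensions
sims_3d = torch.matmul(emb_3d, emb_3d.transpose(-1, -2))  # shape (3, 4, 4)
```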
1 change: 1 addition & 0 deletions src/pytorch_metric_learning/losses/__init__.py
@@ -30,6 +30,7 @@
from .ranked_list_loss import RankedListLoss
from .self_supervised_loss import SelfSupervisedLoss
from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
from .smooth_ap import SmoothAPLoss
from .soft_triple_loss import SoftTripleLoss
from .sphereface_loss import SphereFaceLoss
from .subcenter_arcface_loss import SubCenterArcFaceLoss
103 changes: 103 additions & 0 deletions src/pytorch_metric_learning/losses/smooth_ap.py
@@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F

from ..distances import CosineSimilarity
from ..utils import common_functions as c_f
from ..utils import loss_and_miner_utils as lmu
from .base_metric_loss_function import BaseMetricLossFunction


class SmoothAPLoss(BaseMetricLossFunction):
"""
Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163
"""

def __init__(self, temperature=0.01, **kwargs):
super().__init__(**kwargs)
c_f.assert_distance_type(self, CosineSimilarity)
self.temperature = temperature

def get_default_distance(self):
return CosineSimilarity()

# Implementation is based on the original repository:
# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
# The loss expects every class in the batch to have the same number of elements,
# with the elements of each class grouped together. The number of classes and
# their order do not matter, e.g.
#
# The following labels are valid:
# [ A,A,A, B,B,B, C,C,C ]
# The following labels are NOT valid:
# [ B,B,B, A,A,A,A, C,C,C ]  (class A has one element more than the others)
#
c_f.labels_required(labels)
c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)

counts = torch.bincount(labels)
nonzero_indices = torch.nonzero(counts, as_tuple=True)[0]
nonzero_counts = counts[nonzero_indices]
if nonzero_counts.unique().size(0) != 1:
raise ValueError(
"All classes must have the same number of elements in the labels.\n"
"The given labels have the following number of elements: {}.\n"
"You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
nonzero_counts.cpu().tolist()
)
)

batch_size = embeddings.size(0)
# Number of distinct classes in the batch; each contributes
# batch_size // num_classes_batch embeddings.
num_classes_batch = torch.unique(labels).size(0)

mask = 1.0 - torch.eye(batch_size)
mask = mask.unsqueeze(dim=0).repeat(batch_size, 1, 1)

sims = self.distance(embeddings)

# Approximate rank of every element within the whole batch:
# sigmoid-smoothed pairwise score differences, summed over the retrieval set (self excluded).
sims_repeat = sims.unsqueeze(dim=1).repeat(1, batch_size, 1)
sims_diff = sims_repeat - sims_repeat.permute(0, 2, 1)
sims_sigm = F.sigmoid(sims_diff / self.temperature) * mask.to(sims_diff.device)
sims_ranks = torch.sum(sims_sigm, dim=-1) + 1

# Rank each element within its own class ("positive set") only: reshape to
# (num_classes, samples_per_class, dim) so that every row holds one class's embeddings.
xs = embeddings.view(
num_classes_batch, batch_size // num_classes_batch, embeddings.size(-1)
)
pos_mask = 1.0 - torch.eye(batch_size // num_classes_batch)
pos_mask = (
pos_mask.unsqueeze(dim=0)
.unsqueeze(dim=0)
.repeat(num_classes_batch, batch_size // num_classes_batch, 1, 1)
)

# Circumvent the shape check in forward method
xs_norm = self.distance.maybe_normalize(xs, dim=-1)
sims_pos = self.distance.compute_mat(xs_norm, xs_norm)

sims_pos_repeat = sims_pos.unsqueeze(dim=2).repeat(
1, 1, batch_size // num_classes_batch, 1
)
sims_pos_diff = sims_pos_repeat - sims_pos_repeat.permute(0, 1, 3, 2)

sims_pos_sigm = F.sigmoid(sims_pos_diff / self.temperature) * pos_mask.to(
sims_diff.device
)
sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1

# Smooth-AP for each query: for every element of its class, divide that element's
# rank within the positive set by its rank within the whole batch, then average.
g = batch_size // num_classes_batch
ap = torch.zeros(batch_size).to(embeddings.device)
for i in range(num_classes_batch):
for j in range(g):
pos_rank = sims_pos_ranks[i, j]
all_rank = sims_ranks[i * g + j, i * g : (i + 1) * g]
ap[i * g + j] = torch.sum(pos_rank / all_rank) / g

miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
loss = (1 - ap) * miner_weights

return {
"ap_loss": {
"losses": loss,
"indices": c_f.torch_arange_from_size(loss),
"reduction_type": "element",
}
}
191 changes: 191 additions & 0 deletions tests/losses/test_smooth_ap_loss.py
@@ -0,0 +1,191 @@
import unittest

import torch
import torch.nn.functional as F

from pytorch_metric_learning.losses import SmoothAPLoss

from .. import TEST_DEVICE, TEST_DTYPES

HYPERPARAMETERS = {
"temp": 0.01,
"batch_size": 60,
"num_id": 6,
"feat_dims": 256,
}
TEST_SEEDS = [42, 1234, 5642, 9999, 3459]


# Original implementation of the SmoothAP loss taken from:
# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py
def sigmoid(tensor, temp=1.0):
"""temperature controlled sigmoid

takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
"""
exponent = -tensor / temp
# clamp the input tensor for stability
exponent = torch.clamp(exponent, min=-50, max=50)
y = 1.0 / (1.0 + torch.exp(exponent))
return y


def compute_aff(x):
"""computes the affinity matrix between an input vector and itself"""
return torch.mm(x, x.t())


class SmoothAP(torch.nn.Module):
"""PyTorch implementation of the Smooth-AP loss.

implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns
the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
have the same number of instances represented in the mini-batch and must be ordered sequentially by class.

e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:

labels = ( A, A, A, B, B, B, C, C, C)

(the order of the classes however does not matter)

For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
same class. The loss returns the average Smooth-AP across all instances in the mini-batch.

Args:
anneal : float
the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
batch_size : int
the batch size being used during training.
num_id : int
the number of different classes that are represented in the batch.
feat_dims : int
the dimension of the input feature embeddings

Shape:
- Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
- Output: scalar

Examples::

>>> loss = SmoothAP(0.01, 60, 6, 256)
>>> input = torch.randn(60, 256, requires_grad=True).to("cuda:0")
>>> output = loss(input)
>>> output.backward()
"""

def __init__(self, anneal, batch_size, num_id, feat_dims):
"""
Parameters
----------
anneal : float
the temperature of the sigmoid that is used to smooth the ranking function
batch_size : int
the batch size being used
num_id : int
the number of different classes that are represented in the batch
feat_dims : int
the dimension of the input feature embeddings
"""
super(SmoothAP, self).__init__()

assert batch_size % num_id == 0

self.anneal = anneal
self.batch_size = batch_size
self.num_id = num_id
self.feat_dims = feat_dims

def forward(self, preds):
"""Forward pass for all input predictions: preds - (batch_size x feat_dims)"""

# ------ differentiable ranking of all retrieval set ------
# compute the mask which ignores the relevance score of the query to itself
mask = 1.0 - torch.eye(self.batch_size)
mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
# compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
sim_all = compute_aff(preds)
sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
# compute the difference matrix
sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
# pass through the sigmoid
sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.to(TEST_DEVICE)
# compute the rankings
sim_all_rk = torch.sum(sim_sg, dim=-1) + 1

# ------ differentiable ranking of only positive set in retrieval set ------
# compute the mask which only gives non-zero weights to the positive set
xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
pos_mask = (
pos_mask.unsqueeze(dim=0)
.unsqueeze(dim=0)
.repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
)

# compute the relevance scores
sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(
1, 1, int(self.batch_size / self.num_id), 1
)
# compute the difference matrix
sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
# pass through the sigmoid
sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.to(TEST_DEVICE)
# compute the rankings of the positive set
sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1

# sum the values of the Smooth-AP for all instances in the mini-batch
ap = torch.zeros(1).to(TEST_DEVICE)
group = int(self.batch_size / self.num_id)
for ind in range(self.num_id):
pos_divide = torch.sum(
sim_pos_rk[ind]
/ (
sim_all_rk[
(ind * group) : ((ind + 1) * group),
(ind * group) : ((ind + 1) * group),
]
)
)
ap = ap + ((pos_divide / group) / self.batch_size)

return 1 - ap


class TestSmoothAPLoss(unittest.TestCase):
def test_smooth_ap_loss(self):
for dtype in TEST_DTYPES:
for seed in TEST_SEEDS:
torch.manual_seed(seed)
loss = SmoothAP(
HYPERPARAMETERS["temp"],
HYPERPARAMETERS["batch_size"],
HYPERPARAMETERS["num_id"],
HYPERPARAMETERS["feat_dims"],
)
rand_tensor = (
torch.randn(
HYPERPARAMETERS["batch_size"],
HYPERPARAMETERS["feat_dims"],
requires_grad=True,
)
.to(TEST_DEVICE)
.to(dtype)
)
# The original code uses a model that normalizes the output vector
input_ = F.normalize(rand_tensor, p=2.0, dim=-1)
output = loss(input_)

loss2 = SmoothAPLoss(temperature=HYPERPARAMETERS["temp"])
# The original code assumes the batch is ordered sequentially by class:
# num_id classes, each contributing batch_size // num_id consecutive elements.
labels = []
samples_per_class = (
HYPERPARAMETERS["batch_size"] // HYPERPARAMETERS["num_id"]
)
for i in range(HYPERPARAMETERS["num_id"]):
labels.extend([i] * samples_per_class)

labels = torch.tensor(labels)
output2 = loss2.forward(rand_tensor, labels)
self.assertTrue(torch.isclose(output, output2))