added emb similarity (#3349)
williamFalcon committed Sep 4, 2020
1 parent 7bd2f94 commit 227959b
Showing 3 changed files with 58 additions and 0 deletions.
6 changes: 6 additions & 0 deletions docs/source/metrics.rst
@@ -315,6 +315,12 @@ dice_score (F)
.. autofunction:: pytorch_lightning.metrics.functional.dice_score
    :noindex:

embedding_similarity (F)
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity
    :noindex:

f1_score (F)
^^^^^^^^^^^^

3 changes: 3 additions & 0 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -29,3 +29,6 @@
    rmsle,
    ssim
)
from pytorch_lightning.metrics.functional.self_supervised import (
    embedding_similarity
)
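
The re-export above makes the new metric importable straight from the functional namespace. A minimal usage sketch of the intended call site (the random tensor is only a placeholder):

    import torch
    from pytorch_lightning.metrics.functional import embedding_similarity

    embeddings = torch.rand(8, 128)         # placeholder (batch, dim) embeddings
    sim = embedding_similarity(embeddings)  # (8, 8) pairwise cosine similarities, diagonal zeroed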
49 changes: 49 additions & 0 deletions pytorch_lightning/metrics/functional/self_supervised.py
@@ -0,0 +1,49 @@
import torch


def embedding_similarity(
        batch: torch.Tensor,
        similarity: str = 'cosine',
        reduction: str = 'none',
        zero_diagonal: bool = True
) -> torch.Tensor:
    """
    Computes representation similarity

    Example:

        >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
        >>> embedding_similarity(embeddings)
        tensor([[0.0000, 1.0000, 0.9759],
                [1.0000, 0.0000, 0.9759],
                [0.9759, 0.9759, 0.0000]])

    Args:
        batch: (batch, dim)
        similarity: 'dot' or 'cosine'
        reduction: 'none', 'sum', 'mean' (all along dim -1)
        zero_diagonal: if True, the diagonals are set to zero

    Return:
        A square matrix (batch, batch) with the similarity scores between all elements
        If sum or mean is used, returns (batch,) with the reduced value for each row
    """
    if similarity == 'cosine':
        # L2-normalize each row so the matrix product below yields cosine similarities
        norm = torch.norm(batch, p=2, dim=1)
        batch = batch / norm.unsqueeze(1)

    # pairwise similarity matrix of shape (batch, batch)
    sqr_mtx = batch.mm(batch.transpose(1, 0))

    if zero_diagonal:
        sqr_mtx = sqr_mtx.fill_diagonal_(0)

    # apply the documented reductions along the last dimension
    if reduction == 'sum':
        sqr_mtx = sqr_mtx.sum(dim=-1)
    elif reduction == 'mean':
        sqr_mtx = sqr_mtx.mean(dim=-1)

    return sqr_mtx


if __name__ == '__main__':
    a = torch.rand(3, 5)

    print(embedding_similarity(a, 'cosine'))
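
As a usage sketch for the options documented above (random data only for illustration):

    import torch
    from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity

    emb = torch.rand(4, 16)                                 # placeholder (batch, dim) embeddings
    cos = embedding_similarity(emb, similarity='cosine')    # (4, 4) cosine similarities, diagonal zeroed
    dot = embedding_similarity(emb, similarity='dot')       # (4, 4) raw dot products, no normalization
    row_mean = embedding_similarity(emb, reduction='mean')  # (4,) row-wise mean (includes the zeroed diagonal)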
