updated docs
KevinMusgrave committed Jan 29, 2023
1 parent 0099529 commit 4a94285
Showing 6 changed files with 44 additions and 61 deletions.
16 changes: 6 additions & 10 deletions README.md
@@ -106,22 +106,18 @@ This customized triplet loss has the following properties:

### Using loss functions for unsupervised / self-supervised learning

The TripletMarginLoss is an embedding-based or tuple-based loss. This means that internally, there is no real notion of "classes". Tuples (pairs or triplets) are formed at each iteration, based on the labels it receives. The labels don't have to represent classes. They simply need to indicate the positive and negative relationships between the embeddings. Thus, it is easy to use these loss functions for unsupervised or self-supervised learning.

For example, the code below is a simplified version of the augmentation strategy commonly used in self-supervision. The dataset does not come with any labels. Instead, the labels are created in the training loop, solely to indicate which embeddings are positive pairs.
A `SelfSupervisedLoss` wrapper is provided for self-supervised learning:

```python
from pytorch_metric_learning.losses import SelfSupervisedLoss
loss_func = SelfSupervisedLoss(TripletMarginLoss())

# your training for-loop
for i, data in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = your_model(data)
    augmented = your_model(your_augmentation(data))
    labels = torch.arange(embeddings.size(0))

    embeddings = torch.cat([embeddings, augmented], dim=0)
    labels = torch.cat([labels, labels], dim=0)

    loss = loss_func(embeddings, labels)
    loss = loss_func(embeddings, augmented)
    loss.backward()
    optimizer.step()
```
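
For reference, the wrapper form above is conceptually similar to the manual label-creation pattern it replaces. The following is a minimal, self-contained sketch with dummy tensors (not the library's internal implementation) comparing the two:

```python
import torch
from pytorch_metric_learning.losses import SelfSupervisedLoss, TripletMarginLoss

# Dummy stand-ins for your_model(data) and your_model(your_augmentation(data))
embeddings = torch.randn(32, 128)
augmented = torch.randn(32, 128)

# Wrapper form: positive pairs are inferred from position,
# i.e. embeddings[i] and augmented[i] are treated as a positive pair.
loss_func = SelfSupervisedLoss(TripletMarginLoss())
loss = loss_func(embeddings, augmented)

# Manual pattern shown above: build positional labels yourself.
# Conceptually similar, though the exact tuples formed may differ slightly.
labels = torch.arange(embeddings.size(0))
all_embeddings = torch.cat([embeddings, augmented], dim=0)
all_labels = torch.cat([labels, labels], dim=0)
manual_loss = TripletMarginLoss()(all_embeddings, all_labels)
```
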
@@ -229,7 +225,7 @@ Thanks to the contributors who made pull requests!
| Contributor | Highlights |
| -- | -- |
|[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets <br/> - Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
|[cwkeam](https://github.com/cwkeam) | - [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss) <br/> - Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) |
|[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss) <br/> - [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss) <br/> - Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) |
|[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer) <br/> - [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss) <br/> - [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester) <br/> - [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
| [chingisooinar](https://github.com/chingisooinar) | [SubCenterArcFaceLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#subcenterarcfaceloss) |
| [elias-ramzi](https://github.com/elias-ramzi) | [HierarchicalSampler](https://kevinmusgrave.github.io/pytorch-metric-learning/samplers/#hierarchicalsampler) |
9 changes: 0 additions & 9 deletions docs/accuracy_calculation.md
@@ -224,12 +224,3 @@ labels = torch.tensor([
])
```


### Warning for versions <= 0.9.97

The behavior of the ```k``` parameter described in the [Parameters](#parameters) section is for versions >= 0.9.98.

For versions <= 0.9.97, the behavior was:

* If ```k = None```, then ```k = min(1023, max(bincount(reference_labels)))```
* Otherwise ```k = k```
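
As a quick illustration of that old default, here is a small sketch with dummy reference labels (illustrative only, not the library's code):

```python
import torch

reference_labels = torch.tensor([0, 0, 0, 1, 1, 2])

# Old default when k=None: capped at 1023, otherwise the size of the
# largest class among the reference labels.
k = min(1023, int(torch.bincount(reference_labels).max()))
print(k)  # 3, since class 0 has three reference samples
```
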
14 changes: 5 additions & 9 deletions docs/index.md
@@ -74,22 +74,18 @@ This customized triplet loss has the following properties:

### Using loss functions for unsupervised / self-supervised learning

The TripletMarginLoss is an embedding-based or tuple-based loss. This means that internally, there is no real notion of "classes". Tuples (pairs or triplets) are formed at each iteration, based on the labels it receives. The labels don't have to represent classes. They simply need to indicate the positive and negative relationships between the embeddings. Thus, it is easy to use these loss functions for unsupervised or self-supervised learning.

For example, the code below is a simplified version of the augmentation strategy commonly used in self-supervision. The dataset does not come with any labels. Instead, the labels are created in the training loop, solely to indicate which embeddings are positive pairs.
A `SelfSupervisedLoss` wrapper is provided for self-supervised learning:

```python
from pytorch_metric_learning.losses import SelfSupervisedLoss
loss_func = SelfSupervisedLoss(TripletMarginLoss())

# your training for-loop
for i, data in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = your_model(data)
    augmented = your_model(your_augmentation(data))
    labels = torch.arange(embeddings.size(0))

    embeddings = torch.cat([embeddings, augmented], dim=0)
    labels = torch.cat([labels, labels], dim=0)

    loss = loss_func(embeddings, labels)
    loss = loss_func(embeddings, augmented)
    loss.backward()
    optimizer.step()
```
54 changes: 28 additions & 26 deletions docs/losses.md
@@ -321,19 +321,21 @@ As shown above, CrossBatchMemory comes with a 4th argument in its ```forward```
* **enqueue_mask**: A boolean tensor where `enqueue_mask[i]` is True if `embeddings[i]` should be added to the memory queue. This enables CrossBatchMemory to be used in self-supervision frameworks like [MoCo](https://arxiv.org/pdf/1911.05722.pdf). Check out the [MoCo on CIFAR100](https://github.com/KevinMusgrave/pytorch-metric-learning/tree/master/examples#simple-examples) notebook to see how this works.
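
A minimal MoCo-style sketch of this pattern (dummy tensors; `NTXentLoss` and the query/key split are assumptions for illustration, not prescribed by the docs):

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.CrossBatchMemory(losses.NTXentLoss(), embedding_size=128, memory_size=1024)

query_emb = torch.randn(32, 128)   # from the online encoder
key_emb = torch.randn(32, 128)     # from the momentum encoder
all_emb = torch.cat([query_emb, key_emb], dim=0)

# Matching labels make query_emb[i] and key_emb[i] a positive pair.
# In a real loop you would typically offset these per iteration so that
# entries already sitting in the queue remain distinct negatives.
labels = torch.arange(32)
all_labels = torch.cat([labels, labels], dim=0)

# True only for the key half: queries are compared against the queue
# but never stored in it.
enqueue_mask = torch.cat([torch.zeros(32, dtype=torch.bool),
                          torch.ones(32, dtype=torch.bool)])

loss = loss_fn(all_emb, all_labels, enqueue_mask=enqueue_mask)
```
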


**Supported Loss Functions**:
- [AngularLoss](losses.md#AngularLoss)
- [CircleLoss](losses.md#CircleLoss)
- [ContrastiveLoss](losses.md#ContrastiveLoss)
- [GeneralizedLiftedStructureLoss](losses.md#GeneralizedLiftedStructureLoss)
- [IntraPairVarianceLoss](losses.md#IntraPairVarianceLoss)
- [LiftedStructureLoss](losses.md#LiftedStructureLoss)
- [MultiSimilarityLoss](losses.md#MultiSimilarityLoss)
- [NTXentLoss](losses.md#NTXentLoss)
- [SignalToNoiseRatioContrastiveLoss](losses.md#SignalToNoiseRatioContrastiveLoss)
- [SupConLoss](losses.md#SupConLoss)
- [TripletMarginLoss](losses.md#TripletMarginLoss)
- [TupletMarginLoss](losses.md#TupletMarginLoss)
??? note "Supported Loss Functions"
- [AngularLoss](losses.md#angularloss)
- [CircleLoss](losses.md#circleloss)
- [ContrastiveLoss](losses.md#contrastiveloss)
- [GeneralizedLiftedStructureLoss](losses.md#generalizedliftedstructureloss)
- [IntraPairVarianceLoss](losses.md#intrapairvarianceloss)
- [LiftedStructureLoss](losses.md#liftedstructureloss)
- [MarginLoss](losses.md#marginloss)
- [MultiSimilarityLoss](losses.md#multisimilarityloss)
- [NCALoss](losses.md#ncaloss)
- [NTXentLoss](losses.md#ntxentloss)
- [SignalToNoiseRatioContrastiveLoss](losses.md#signaltonoiseratiocontrastiveloss)
- [SupConLoss](losses.md#supconloss)
- [TripletMarginLoss](losses.md#tripletmarginloss)
- [TupletMarginLoss](losses.md#tupletmarginloss)


**Reset queue**
@@ -839,25 +841,25 @@ loss_optimizer.step()

## SelfSupervisedLoss

A common use case is to have embeddings and ref_emb be augmented versions of each other. For most losses right now you have to create labels to indicate which embeddings correspond with which ref_emb. `SelfSupervisedLoss` automates this.
A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`. `SelfSupervisedLoss` automates this.

```python
loss_fn = losses.TripletMarginLoss()
loss_fn = SelfSupervisedLoss(loss_fn)
loss = loss_fn(embeddings, labels)
loss = loss_fn(embeddings, ref_emb)
```

**Supported Loss Functions**:
- [AngularLoss](losses.md#AngularLoss)
- [CircleLoss](losses.md#CircleLoss)
- [ContrastiveLoss](losses.md#ContrastiveLoss)
- [IntraPairVarianceLoss](losses.md#IntraPairVarianceLoss)
- [MultiSimilarityLoss](losses.md#MultiSimilarityLoss)
- [NTXentLoss](losses.md#NTXentLoss)
- [SignalToNoiseRatioContrastiveLoss](losses.md#SignalToNoiseRatioContrastiveLoss)
- [SupConLoss](losses.md#SupConLoss)
- [TripletMarginLoss](losses.md#TripletMarginLoss)
- [TupletMarginLoss](losses.md#TupletMarginLoss)
??? "Supported Loss Functions"
- [AngularLoss](losses.md#angularloss)
- [CircleLoss](losses.md#circleloss)
- [ContrastiveLoss](losses.md#contrastiveloss)
- [IntraPairVarianceLoss](losses.md#intrapairvarianceloss)
- [MultiSimilarityLoss](losses.md#multisimilarityloss)
- [NTXentLoss](losses.md#ntxentloss)
- [SignalToNoiseRatioContrastiveLoss](losses.md#signaltonoiseratiocontrastiveloss)
- [SupConLoss](losses.md#supconloss)
- [TripletMarginLoss](losses.md#tripletmarginloss)
- [TupletMarginLoss](losses.md#tupletmarginloss)



8 changes: 3 additions & 5 deletions docs/miners.md
@@ -5,9 +5,9 @@ Mining functions take a batch of ```n``` embeddings and return ```k``` pairs/tri
- Triplet miners output a tuple of size 3: (anchors, positives, negatives).
- Without a tuple miner, loss functions will by default use all possible pairs/triplets in the batch.

You might be familiar with the terminology: "online" and "offline" miners. Tuple miners are online. Offline miners should be implemented as a [PyTorch Sampler](samplers.md).
You might be familiar with the terminology: "online" and "offline" miners. These miners are online. Offline miners should be implemented as a [PyTorch Sampler](samplers.md).

Tuple miners are used with loss functions as follows:
Miners are used with loss functions as follows:

```python
from pytorch_metric_learning import miners, losses
Expand Down Expand Up @@ -39,13 +39,11 @@ All miners extend this class and therefore inherit its ```__init__``` parameters
miners.BaseMiner(collect_stats=False, distance=None)
```

It outputs a tuple of indices:
Every miner outputs a tuple of indices:

* Pair miners output a tuple of size 4: (anchors, positives, anchors, negatives)
* Triplet miners output a tuple of size 3: (anchors, positives, negatives)
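
For example, a rough sketch of what these tuples look like and how they are passed to a loss (dummy data; the specific miners are just examples):

```python
import torch
from pytorch_metric_learning import miners, losses

embeddings = torch.randn(16, 64)
labels = torch.randint(0, 4, (16,))

# Pair miner: size-4 tuple of index tensors (anchors, positives, anchors, negatives)
pair_miner = miners.MultiSimilarityMiner()
a1, p, a2, n = pair_miner(embeddings, labels)

# Triplet miner: size-3 tuple of index tensors (anchors, positives, negatives)
triplet_miner = miners.TripletMarginMiner()
a, pos, neg = triplet_miner(embeddings, labels)

# Either tuple can be passed to a loss as its indices_tuple argument:
loss = losses.TripletMarginLoss()(embeddings, labels, triplet_miner(embeddings, labels))
```
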

If you write your own miner, the ```mine``` function should work such that anchor indices correspond to ```embeddings``` and ```labels```, and all other indices correspond to ```ref_emb``` and ```ref_labels```. By default, ```embeddings == ref_emb``` and ```labels == ref_labels```, but separating the anchor source from the positive/negative source allows for interesting use cases. For example, see [CrossBatchMemory](losses.md#crossbatchmemory).

See [custom miners](extend/miners.md) for details on how to write your own miner.

**Parameters**:
4 changes: 2 additions & 2 deletions docs/trainers.md
@@ -16,8 +16,8 @@ trainers.BaseTrainer(models,
optimizers,
batch_size,
loss_funcs,
mining_funcs,
dataset,
mining_funcs=None,
iterations_per_epoch=None,
data_device=None,
dtype=None,
@@ -47,9 +47,9 @@ trainers.BaseTrainer(models,
* **batch_size**: The number of elements that are retrieved at each iteration.
* **loss_funcs**: A dictionary mapping strings to loss functions. The required keys depend on the training method, but all methods are likely to require at least:
* {"metric_loss": loss_func}.
* **dataset**: The dataset you want to train on. Note that training methods do not perform validation, so do not pass in your validation or test set.
* **mining_funcs**: A dictionary mapping strings to mining functions. Pass in an empty dictionary, or one or more of the following keys:
* {"subset_batch_miner": mining_func1, "tuple_miner": mining_func2}
* **dataset**: The dataset you want to train on. Note that training methods do not perform validation, so do not pass in your validation or test set.
* **data_device**: The device that you want to put batches of data on. If not specified, the trainer will put the data on any available GPUs.
* **dtype**: The type that the dataset output will be converted to, e.g. ```torch.float16```. If set to ```None```, then no type casting will be done.
* **iterations_per_epoch**: Optional.
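
As a rough illustration of the updated signature, here is a sketch that constructs a trainer with toy stand-in models and data; `MetricLossOnly` is used as one trainer built on `BaseTrainer`, and `mining_funcs` is simply omitted now that it defaults to `None`:

```python
import torch
from torch.utils.data import TensorDataset
from pytorch_metric_learning import losses, trainers

# Toy stand-ins for your real trunk, embedder, and training dataset.
trunk = torch.nn.Linear(10, 64)
embedder = torch.nn.Linear(64, 32)
dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 5, (100,)))

trainer = trainers.MetricLossOnly(
    models={"trunk": trunk, "embedder": embedder},
    optimizers={
        "trunk_optimizer": torch.optim.Adam(trunk.parameters(), lr=1e-4),
        "embedder_optimizer": torch.optim.Adam(embedder.parameters(), lr=1e-4),
    },
    batch_size=32,
    loss_funcs={"metric_loss": losses.TripletMarginLoss()},
    dataset=dataset,  # training data only; trainers do not perform validation
)
trainer.train(num_epochs=1)
```
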
