Skip to content

Commit

Permalink
Enable one-hot-encoded labels in MixUp and CutMix (#8427)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
mahdilamb and NicolasHug authored May 28, 2024
1 parent 778ce48 commit c585a51
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 23 deletions.
35 changes: 19 additions & 16 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2169,26 +2169,30 @@ def test_image_correctness(self, brightness_factor):

class TestCutMixMixUp:
class DummyDataset:
def __init__(self, size, num_classes):
def __init__(self, size, num_classes, one_hot_labels):
self.size = size
self.num_classes = num_classes
self.one_hot_labels = one_hot_labels
assert size < num_classes

def __getitem__(self, idx):
img = torch.rand(3, 100, 100)
label = idx # This ensures all labels in a batch are unique and makes testing easier
if self.one_hot_labels:
label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
return img, label

def __len__(self):
return self.size

@pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
def test_supported_input_structure(self, T):
@pytest.mark.parametrize("one_hot_labels", (True, False))
def test_supported_input_structure(self, T, one_hot_labels):

batch_size = 32
num_classes = 100

dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)
dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)

cutmix_mixup = T(num_classes=num_classes)

Expand All @@ -2198,7 +2202,7 @@ def test_supported_input_structure(self, T):
img, target = next(iter(dl))
input_img_size = img.shape[-3:]
assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
assert target.shape == (batch_size,)
assert target.shape == ((batch_size, num_classes) if one_hot_labels else (batch_size,))

def check_output(img, target):
assert img.shape == (batch_size, *input_img_size)
Expand All @@ -2209,7 +2213,7 @@ def check_output(img, target):

# After Dataloader, as unpacked input
img, target = next(iter(dl))
assert target.shape == (batch_size,)
assert target.shape == ((batch_size, num_classes) if one_hot_labels else (batch_size,))
img, target = cutmix_mixup(img, target)
check_output(img, target)

Expand Down Expand Up @@ -2264,30 +2268,29 @@ def test_error(self, T):
with pytest.raises(ValueError, match="Could not infer where the labels are"):
cutmix_mixup({"img": imgs, "Nothing_else": 3})

with pytest.raises(ValueError, match="labels tensor should be of shape"):
with pytest.raises(ValueError, match="labels should be index based"):
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
cutmix_mixup(imgs)

with pytest.raises(ValueError, match="When using the default labels_getter"):
cutmix_mixup(imgs, "not_a_tensor")

with pytest.raises(ValueError, match="labels tensor should be of shape"):
cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))

with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))

with pytest.raises(ValueError, match="does not match the batch size of the labels"):
cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))

with pytest.raises(ValueError, match="labels tensor should be of shape"):
# The purpose of this check is more about documenting the current
# behaviour of what happens on a Compose(), rather than actually
# asserting the expected behaviour. We may support Compose() in the
# future, e.g. for 2 consecutive CutMix?
labels = torch.randint(0, num_classes, size=(batch_size,))
transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
with pytest.raises(ValueError, match="When passing 2D labels"):
wrong_num_classes = num_classes + 1
T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))

with pytest.raises(ValueError, match="but got a tensor of shape"):
cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))

with pytest.raises(ValueError, match="num_classes must be passed"):
T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
Expand Down
28 changes: 21 additions & 7 deletions torchvision/transforms/v2/_augment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import numbers
import warnings
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import PIL.Image
import torch
Expand Down Expand Up @@ -142,7 +142,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class _BaseMixUpCutMix(Transform):
def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
super().__init__()
self.alpha = float(alpha)
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
Expand All @@ -162,10 +162,21 @@ def forward(self, *inputs):
labels = self._labels_getter(inputs)
if not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
elif labels.ndim != 1:
if labels.ndim not in (1, 2):
raise ValueError(
f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead."
f"labels should be index based with shape (batch_size,) "
f"or probability based with shape (batch_size, num_classes), "
f"but got a tensor of shape {labels.shape} instead."
)
if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
raise ValueError(
f"When passing 2D labels, "
f"the number of elements in last dimension must match num_classes: "
f"{labels.shape[-1]} != {self.num_classes}. "
f"You can Leave num_classes to None."
)
if labels.ndim == 1 and self.num_classes is None:
raise ValueError("num_classes must be passed if the labels are index-based (1D)")

params = {
"labels": labels,
Expand Down Expand Up @@ -198,7 +209,8 @@ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int):
)

def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:
label = one_hot(label, num_classes=self.num_classes)
if label.ndim == 1:
label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type]
if not label.dtype.is_floating_point:
label = label.float()
return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam))
Expand All @@ -223,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
Expand Down Expand Up @@ -271,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
num_classes (int): number of classes in the batch. Used for one-hot-encoding.
num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.
Expand Down

0 comments on commit c585a51

Please sign in to comment.