-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable one-hot-encoded labels in MixUp and CutMix #8427
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/8427
Note: Links to docs will display an error until the docs builds have been completed. ❌ 12 New FailuresAs of commit 218fc58 with merge base 778ce48 (): NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Thanks for the PR @mahdilamb . Supporting labels that are already one-hot-encoded sounds OK to me, but instead of adding a new parameter, it seems that we could instead just check the shape of the labels and only call We would also need to add a few tests here vision/test/test_transforms_v2.py Line 2170 in 8b6c5e7
|
Hi @NicolasHug makes sense to me. I'll get that moving |
Hi @NicolasHug , that's updates as requested... it breaks |
Hi @mahdilamb - I've made a few changes to the PR locally but when I cannot push to update the PR, because you created from your Would you mind closing this one and re-opening a new PR from a dev branch (i.e. do Alternatively you could also apply this diff to the current PR: diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 190b590c89..07235333af 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -2169,29 +2169,30 @@ class TestAdjustBrightness:
class TestCutMixMixUp:
class DummyDataset:
- def __init__(self, size, num_classes, encode_labels:bool):
+ def __init__(self, size, num_classes, one_hot_labels):
self.size = size
self.num_classes = num_classes
- self.encode_labels = encode_labels
+ self.one_hot_labels = one_hot_labels
assert size < num_classes
def __getitem__(self, idx):
img = torch.rand(3, 100, 100)
- label = torch.tensor(idx) # This ensures all labels in a batch are unique and makes testing easier
- if self.encode_labels:
- label = torch.nn.functional.one_hot(label, num_classes=self.num_classes)
+ label = idx # This ensures all labels in a batch are unique and makes testing easier
+ if self.one_hot_labels:
+ label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes)
return img, label
def __len__(self):
return self.size
- @pytest.mark.parametrize(["T", "encode_labels"], [[transforms.CutMix, False], [transforms.MixUp, False], [transforms.CutMix, True], [transforms.MixUp, True]])
- def test_supported_input_structure(self, T, encode_labels: bool):
+ @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp])
+ @pytest.mark.parametrize("one_hot_labels", (True, False))
+ def test_supported_input_structure(self, T, one_hot_labels):
batch_size = 32
num_classes = 100
- dataset = self.DummyDataset(size=batch_size, num_classes=num_classes,encode_labels=encode_labels)
+ dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels)
cutmix_mixup = T(num_classes=num_classes)
@@ -2201,10 +2202,7 @@ class TestCutMixMixUp:
img, target = next(iter(dl))
input_img_size = img.shape[-3:]
assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor)
- if encode_labels:
- assert target.shape == (batch_size, num_classes)
- else:
- assert target.shape == (batch_size,)
+ assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
def check_output(img, target):
assert img.shape == (batch_size, *input_img_size)
@@ -2215,10 +2213,7 @@ class TestCutMixMixUp:
# After Dataloader, as unpacked input
img, target = next(iter(dl))
- if encode_labels:
- assert target.shape == (batch_size, num_classes)
- else:
- assert target.shape == (batch_size,)
+ assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,)
img, target = cutmix_mixup(img, target)
check_output(img, target)
@@ -2273,7 +2268,7 @@ class TestCutMixMixUp:
with pytest.raises(ValueError, match="Could not infer where the labels are"):
cutmix_mixup({"img": imgs, "Nothing_else": 3})
- with pytest.raises(ValueError, match="labels tensor should be of shape"):
+ with pytest.raises(ValueError, match="labels should be index based"):
# Note: the error message isn't ideal, but that's because the label heuristic found the img as the label
# It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently
cutmix_mixup(imgs)
@@ -2281,22 +2276,21 @@ class TestCutMixMixUp:
with pytest.raises(ValueError, match="When using the default labels_getter"):
cutmix_mixup(imgs, "not_a_tensor")
- with pytest.raises(ValueError, match="labels tensor should be of shape"):
- cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3)))
-
with pytest.raises(ValueError, match="Expected a batched input with 4 dims"):
cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,)))
with pytest.raises(ValueError, match="does not match the batch size of the labels"):
cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,)))
- with pytest.raises(ValueError, match="labels tensor should be of shape"):
- # The purpose of this check is more about documenting the current
- # behaviour of what happens on a Compose(), rather than actually
- # asserting the expected behaviour. We may support Compose() in the
- # future, e.g. for 2 consecutive CutMix?
- labels = torch.randint(0, num_classes, size=(batch_size,))
- transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels)
+ with pytest.raises(ValueError, match="When passing 2D labels"):
+ wrong_num_classes = num_classes + 1
+ T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes)))
+
+ with pytest.raises(ValueError, match="but got a tensor of shape"):
+ cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4)))
+
+ with pytest.raises(ValueError, match="num_classes must be passed"):
+ T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,)))
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index 48daa271ea..1d01012654 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -1,7 +1,7 @@
import math
import numbers
import warnings
-from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
@@ -142,7 +142,7 @@ class RandomErasing(_RandomApplyTransform):
class _BaseMixUpCutMix(Transform):
- def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default", labels_encoded: bool = False) -> None:
+ def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None:
super().__init__()
self.alpha = float(alpha)
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))
@@ -150,7 +150,6 @@ class _BaseMixUpCutMix(Transform):
self.num_classes = num_classes
self._labels_getter = _parse_labels_getter(labels_getter)
- self._labels_encoded = labels_encoded
def forward(self, *inputs):
inputs = inputs if len(inputs) > 1 else inputs[0]
@@ -163,10 +162,21 @@ class _BaseMixUpCutMix(Transform):
labels = self._labels_getter(inputs)
if not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.")
- elif not 0 < labels.ndim <= 2 or (labels.ndim == 2 and labels.shape[1] != self.num_classes):
+ if labels.ndim not in (1, 2):
raise ValueError(
- f"labels tensor should be of shape (batch_size,) or (batch_size,num_classes) " f"but got shape {labels.shape} instead."
+ f"labels should be index based with shape (batch_size,) "
+ f"or probability based with shape (batch_size, num_classes), "
+ f"but got a tensor of shape {labels.shape} instead."
)
+ if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes:
+ raise ValueError(
+ f"When passing 2D labels, "
+ f"the number of elements in last dimension must match num_classes: "
+ f"{labels.shape[-1]} != {self.num_classes}. "
+ f"You can Leave num_classes to None."
+ )
+ if labels.ndim == 1 and self.num_classes is None:
+ raise ValueError("num_classes must be passed if the labels are index-based (1D)")
params = {
"labels": labels,
@@ -225,7 +235,8 @@ class MixUp(_BaseMixUpCutMix):
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
- num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+ num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+ Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``.
@@ -273,7 +284,8 @@ class CutMix(_BaseMixUpCutMix):
Args:
alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1.
- num_classes (int): number of classes in the batch. Used for one-hot-encoding.
+ num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding.
+ Can be None only if the labels are already one-hot-encoded.
labels_getter (callable or "default", optional): indicates how to identify the labels in the input.
By default, this will pick the second parameter as the labels if it's a tensor. This covers the most
common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``. |
Hi @NicolasHug, that's diff applied! Hope you have a great weekend. Mahdi |
Thank you @mahdilamb Before I can merge, do you mind fixing this one linting issue:
I think adding a simple Thanks! |
@NicolasHug, made the change, but if it fails will look into it properly. Also added you as a collaborator on the fork so you can mess about! |
Thank you @mahdilamb ! |
Summary: Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com> Reviewed By: vmoens Differential Revision: D58283866 fbshipit-source-id: 32b0b2ade02b3a81d167f64a3743c2bf62049308
Todo:
cc @vfdev-5