-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
port tests for container transforms #8012
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -396,6 +396,8 @@ def check_transform(transform, input, check_v1_compatibility=True, check_sample_ | |
if check_v1_compatibility: | ||
_check_transform_v1_compatibility(transform, input, **_to_tolerances(check_v1_compatibility)) | ||
|
||
return output | ||
|
||
|
||
def transform_cls_to_functional(transform_cls, **transform_specific_kwargs): | ||
def wrapper(input, *args, **kwargs): | ||
|
@@ -1773,7 +1775,7 @@ def test_transform_unknown_fill_error(self): | |
transforms.RandomAffine(degrees=0, fill="fill") | ||
|
||
|
||
class TestCompose: | ||
class TestContainerTransforms: | ||
class BuiltinTransform(transforms.Transform): | ||
def _transform(self, inpt, params): | ||
return inpt | ||
|
@@ -1788,7 +1790,10 @@ def forward(self, image, label): | |
return image, label | ||
|
||
@pytest.mark.parametrize( | ||
"transform_clss", | ||
"transform_cls", [transforms.Compose, functools.partial(transforms.RandomApply, p=1), transforms.RandomOrder] | ||
) | ||
@pytest.mark.parametrize( | ||
"wrapped_transform_clss", | ||
[ | ||
[BuiltinTransform], | ||
[PackedInputTransform], | ||
|
@@ -1803,12 +1808,12 @@ def forward(self, image, label): | |
], | ||
) | ||
@pytest.mark.parametrize("unpack", [True, False]) | ||
def test_packed_unpacked(self, transform_clss, unpack): | ||
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) | ||
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) | ||
def test_packed_unpacked(self, transform_cls, wrapped_transform_clss, unpack): | ||
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in wrapped_transform_clss) | ||
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in wrapped_transform_clss) | ||
assert not (needs_packed_inputs and needs_unpacked_inputs) | ||
|
||
transform = transforms.Compose([cls() for cls in transform_clss]) | ||
transform = transform_cls([cls() for cls in wrapped_transform_clss]) | ||
|
||
image = make_image() | ||
label = 3 | ||
|
@@ -1833,6 +1838,102 @@ def call_transform(): | |
assert output[0] is image | ||
assert output[1] is label | ||
|
||
def test_compose(self): | ||
transform = transforms.Compose( | ||
[ | ||
transforms.RandomHorizontalFlip(p=1), | ||
transforms.RandomVerticalFlip(p=1), | ||
] | ||
) | ||
|
||
input = make_image() | ||
|
||
actual = check_transform(transform, input) | ||
expected = F.vertical_flip(F.horizontal_flip(input)) | ||
|
||
assert_equal(actual, expected) | ||
|
||
@pytest.mark.parametrize("p", [0.0, 1.0]) | ||
@pytest.mark.parametrize("sequence_type", [list, nn.ModuleList]) | ||
def test_random_apply(self, p, sequence_type): | ||
transform = transforms.RandomApply( | ||
sequence_type( | ||
[ | ||
transforms.RandomHorizontalFlip(p=1), | ||
transforms.RandomVerticalFlip(p=1), | ||
] | ||
), | ||
p=p, | ||
) | ||
|
||
# This needs to be a pure tensor (or a PIL image), because otherwise check_transforms skips the v1 compatibility | ||
# check | ||
input = make_image_tensor() | ||
output = check_transform(transform, input, check_v1_compatibility=issubclass(sequence_type, nn.ModuleList)) | ||
|
||
if p == 1: | ||
assert_equal(output, F.vertical_flip(F.horizontal_flip(input))) | ||
else: | ||
assert output is input | ||
|
||
@pytest.mark.parametrize("p", [(0, 1), (1, 0)]) | ||
def test_random_choice(self, p): | ||
transform = transforms.RandomChoice( | ||
[ | ||
transforms.RandomHorizontalFlip(p=1), | ||
transforms.RandomVerticalFlip(p=1), | ||
], | ||
p=p, | ||
) | ||
|
||
input = make_image() | ||
output = check_transform(transform, input) | ||
|
||
p_horz, p_vert = p | ||
if p_horz: | ||
assert_equal(output, F.horizontal_flip(input)) | ||
else: | ||
assert_equal(output, F.vertical_flip(input)) | ||
|
||
def test_random_order(self): | ||
transform = transforms.Compose( | ||
[ | ||
transforms.RandomHorizontalFlip(p=1), | ||
transforms.RandomVerticalFlip(p=1), | ||
] | ||
) | ||
|
||
input = make_image() | ||
|
||
actual = check_transform(transform, input) | ||
# horizontal and vertical flip are commutative. Meaning, although the order in the transform is indeed random, | ||
# we don't need to care here. | ||
expected = F.vertical_flip(F.horizontal_flip(input)) | ||
|
||
assert_equal(actual, expected) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately it also means that this test doesn't check that the order of the transforms is indeed random. At best it checks that the input transforms are applied. I think we should at least acknowledge that in the comment instead of saying "we don't need to care". There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough. All ears if you have an idea to check whether the transforms are actually applied in random order. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a pair of non-commutative transforms, with 2 hand-chosen seeds, asserting that the results are different. What makes it a bit harder is that the transforms themselves must be non-random. Worst case scenario we could just define |
||
|
||
def test_errors(self): | ||
for cls in [transforms.Compose, transforms.RandomChoice, transforms.RandomOrder]: | ||
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"): | ||
cls(lambda x: x) | ||
|
||
with pytest.raises(ValueError, match="at least one transform"): | ||
transforms.Compose([]) | ||
|
||
for p in [-1, 2]: | ||
with pytest.raises(ValueError, match=re.escape("value in the interval [0.0, 1.0]")): | ||
transforms.RandomApply([lambda x: x], p=p) | ||
|
||
for transforms_, p in [ | ||
( | ||
[lambda x: x], | ||
[], | ||
), | ||
([], [1.0]), | ||
]: | ||
with pytest.raises(ValueError, match="Length of p doesn't match the number of transforms"): | ||
transforms.RandomChoice(transforms_, p=p) | ||
|
||
|
||
class TestToDtype: | ||
@pytest.mark.parametrize( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,14 +100,15 @@ def _extract_params_for_v1_transform(self) -> Dict[str, Any]: | |
return {"transforms": self.transforms, "p": self.p} | ||
|
||
def forward(self, *inputs: Any) -> Any: | ||
sample = inputs if len(inputs) > 1 else inputs[0] | ||
needs_unpacking = len(inputs) > 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This got the same treatment as |
||
|
||
if torch.rand(1) >= self.p: | ||
return sample | ||
return inputs if needs_unpacking else inputs[0] | ||
|
||
for transform in self.transforms: | ||
sample = transform(sample) | ||
return sample | ||
outputs = transform(*inputs) | ||
inputs = outputs if needs_unpacking else (outputs,) | ||
return outputs | ||
|
||
def extra_repr(self) -> str: | ||
format_string = [] | ||
|
@@ -173,8 +174,9 @@ def __init__(self, transforms: Sequence[Callable]) -> None: | |
self.transforms = transforms | ||
|
||
def forward(self, *inputs: Any) -> Any: | ||
sample = inputs if len(inputs) > 1 else inputs[0] | ||
needs_unpacking = len(inputs) > 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
for idx in torch.randperm(len(self.transforms)): | ||
transform = self.transforms[idx] | ||
sample = transform(sample) | ||
return sample | ||
outputs = transform(*inputs) | ||
inputs = outputs if needs_unpacking else (outputs,) | ||
return outputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We already compute the
output
incheck_transform
. By returning it, we don't need to recompute in case the test performs additional checks.