From 1d646d41f72aac87bd3a85ce39ce159d3fd4180c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 5 Oct 2023 12:05:01 +0200 Subject: [PATCH] port prototype tests to new utilities (#8022) --- test/prototype_common_utils.py | 82 ------------------------------ test/test_prototype_transforms.py | 67 ++++++++++++------------ test/transforms_v2_legacy_utils.py | 1 - 3 files changed, 34 insertions(+), 116 deletions(-) delete mode 100644 test/prototype_common_utils.py diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py deleted file mode 100644 index b26bcff3246..00000000000 --- a/test/prototype_common_utils.py +++ /dev/null @@ -1,82 +0,0 @@ -import collections.abc -import dataclasses -from typing import Optional, Sequence - -import pytest -import torch -from torch.nn.functional import one_hot - -from torchvision.prototype import tv_tensors - -from transforms_v2_legacy_utils import combinations_grid, DEFAULT_EXTRA_DIMS, from_loader, from_loaders, TensorLoader - - -@dataclasses.dataclass -class LabelLoader(TensorLoader): - categories: Optional[Sequence[str]] - - -def _parse_categories(categories): - if categories is None: - num_categories = int(torch.randint(1, 11, ())) - elif isinstance(categories, int): - num_categories = categories - categories = [f"category{idx}" for idx in range(num_categories)] - elif isinstance(categories, collections.abc.Sequence) and all(isinstance(category, str) for category in categories): - categories = list(categories) - num_categories = len(categories) - else: - raise pytest.UsageError( - f"`categories` can either be `None` (default), an integer, or a sequence of strings, " - f"but got '{categories}' instead." - ) - return categories, num_categories - - -def make_label_loader(*, extra_dims=(), categories=None, dtype=torch.int64): - categories, num_categories = _parse_categories(categories) - - def fn(shape, dtype, device): - # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, - # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123 - data = torch.testing.make_tensor(shape, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) - return tv_tensors.Label(data, categories=categories) - - return LabelLoader(fn, shape=extra_dims, dtype=dtype, categories=categories) - - -make_label = from_loader(make_label_loader) - - -@dataclasses.dataclass -class OneHotLabelLoader(TensorLoader): - categories: Optional[Sequence[str]] - - -def make_one_hot_label_loader(*, categories=None, extra_dims=(), dtype=torch.int64): - categories, num_categories = _parse_categories(categories) - - def fn(shape, dtype, device): - if num_categories == 0: - data = torch.empty(shape, dtype=dtype, device=device) - else: - # The idiom `make_label_loader(..., dtype=torch.int64); ...; one_hot(...).to(dtype)` is intentional - # since `one_hot` only supports int64 - label = make_label_loader(extra_dims=extra_dims, categories=num_categories, dtype=torch.int64).load(device) - data = one_hot(label, num_classes=num_categories).to(dtype) - return tv_tensors.OneHotLabel(data, categories=categories) - - return OneHotLabelLoader(fn, shape=(*extra_dims, num_categories), dtype=dtype, categories=categories) - - -def make_one_hot_label_loaders( - *, - categories=(1, 0, None), - extra_dims=DEFAULT_EXTRA_DIMS, - dtypes=(torch.int64, torch.float32), -): - for params in combinations_grid(categories=categories, extra_dims=extra_dims, dtype=dtypes): - yield make_one_hot_label_loader(**params) - - -make_one_hot_labels = from_loaders(make_one_hot_label_loaders) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 9794b196a70..3f2e5015863 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,41 +1,42 @@ +import collections.abc import re import PIL.Image import pytest import torch -from common_utils import assert_equal +from common_utils import assert_equal, make_bounding_boxes, make_detection_masks, make_image, make_video -from prototype_common_utils import make_label from torchvision.prototype import transforms, tv_tensors from torchvision.transforms.v2._utils import check_type, is_pure_tensor from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image from torchvision.tv_tensors import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video -from transforms_v2_legacy_utils import ( - DEFAULT_EXTRA_DIMS, - make_bounding_boxes, - make_detection_mask, - make_image, - make_video, -) -BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] +def _parse_categories(categories): + if categories is None: + num_categories = int(torch.randint(1, 11, ())) + elif isinstance(categories, int): + num_categories = categories + categories = [f"category{idx}" for idx in range(num_categories)] + elif isinstance(categories, collections.abc.Sequence) and all(isinstance(category, str) for category in categories): + categories = list(categories) + num_categories = len(categories) + else: + raise pytest.UsageError( + f"`categories` can either be `None` (default), an integer, or a sequence of strings, " + f"but got '{categories}' instead." + ) + return categories, num_categories -def parametrize(transforms_with_inputs): - return pytest.mark.parametrize( - ("transform", "input"), - [ - pytest.param( - transform, - input, - id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}", - ) - for transform, inputs in transforms_with_inputs - for idx, input in enumerate(inputs) - ], - ) + +def make_label(*, extra_dims=(), categories=10, dtype=torch.int64, device="cpu"): + categories, num_categories = _parse_categories(categories) + # The idiom `make_tensor(..., dtype=torch.int64).to(dtype)` is intentional to only get integer values, + # regardless of the requested dtype, e.g. 0 or 0.0 rather than 0 or 0.123 + data = torch.testing.make_tensor(extra_dims, low=0, high=num_categories, dtype=torch.int64, device=device).to(dtype) + return tv_tensors.Label(data, categories=categories) class TestSimpleCopyPaste: @@ -167,7 +168,7 @@ def test__get_params(self, mocker): flat_inputs = [ make_image(size=canvas_size, color_space="RGB"), - make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape), + make_bounding_boxes(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_shape[0]), ] params = transform._get_params(flat_inputs) @@ -203,9 +204,9 @@ def test__transform_culling(self, mocker): ) bounding_boxes = make_bounding_boxes( - format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,) + format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size ) - masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,)) + masks = make_detection_masks(size=canvas_size, num_masks=batch_size) labels = make_label(extra_dims=(batch_size,)) transform = transforms.FixedSizeCrop((-1, -1)) @@ -241,7 +242,7 @@ def test__transform_bounding_boxes_clamping(self, mocker): ) bounding_boxes = make_bounding_boxes( - format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,) + format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, num_boxes=batch_size ) mock = mocker.patch( "torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes", wraps=clamp_bounding_boxes @@ -389,27 +390,27 @@ def make_tv_tensors(): pil_image = to_pil_image(make_image(size=size, color_space="RGB")) target = { - "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), - "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), + "masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long), } yield (pil_image, target) tensor_image = torch.Tensor(make_image(size=size, color_space="RGB")) target = { - "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), - "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), + "masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long), } yield (tensor_image, target) tv_tensor_image = make_image(size=size, color_space="RGB") target = { - "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), + "boxes": make_bounding_boxes(canvas_size=size, format="XYXY", num_boxes=num_objects, dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), - "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long), + "masks": make_detection_masks(size=size, num_masks=num_objects, dtype=torch.long), } yield (tv_tensor_image, target) diff --git a/test/transforms_v2_legacy_utils.py b/test/transforms_v2_legacy_utils.py index 0cf31f93641..1d121cd1963 100644 --- a/test/transforms_v2_legacy_utils.py +++ b/test/transforms_v2_legacy_utils.py @@ -6,7 +6,6 @@ The following legacy modules depend on this module - test_transforms_v2_consistency.py -- test_prototype_transforms.py """ import collections.abc