From 25ec3f2617450ed7c5f848e18c0f10354e9786d6 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 30 Aug 2023 12:00:27 +0100
Subject: [PATCH] tv_tensor -> TVTensor where it matters (#7904)

Co-authored-by: Philip Meier
---
 docs/source/transforms.rst                   |  2 +-
 docs/source/tv_tensors.rst                   |  9 ++--
 gallery/transforms/plot_custom_transforms.py |  4 +-
 gallery/transforms/plot_custom_tv_tensors.py |  6 +--
 .../plot_transforms_getting_started.py       | 14 ++---
 gallery/transforms/plot_tv_tensors.py        | 42 +++++++--------
 test/test_tv_tensors.py                      | 53 ++++++++++++-------
 .../tv_tensors/_torch_function_helpers.py    | 17 +++---
 torchvision/tv_tensors/_tv_tensor.py         |  2 +-
 9 files changed, 84 insertions(+), 65 deletions(-)

diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index fe9258d7382..3cae407a70a 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -183,7 +183,7 @@ Transforms are available as classes like
 This is very much like the :mod:`torch.nn` package which defines both classes
 and functional equivalents in :mod:`torch.nn.functional`.
 
-The functionals support PIL images, pure tensors, or :ref:`tv_tensors
+The functionals support PIL images, pure tensors, or :ref:`TVTensors
 <tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
 valid.
 
diff --git a/docs/source/tv_tensors.rst b/docs/source/tv_tensors.rst
index d9a96b98161..e80a1ed88fb 100644
--- a/docs/source/tv_tensors.rst
+++ b/docs/source/tv_tensors.rst
@@ -5,10 +5,11 @@ TVTensors
 
 .. currentmodule:: torchvision.tv_tensors
 
-TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
-dispatch their inputs to the appropriate lower-level kernels. Most users do not
-need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
-see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
+TVTensors are :class:`torch.Tensor` subclasses which the v2 :ref:`transforms
+<transforms>` use under the hood to dispatch their inputs to the appropriate
+lower-level kernels. Most users do not need to manipulate TVTensors directly and
+can simply rely on dataset wrapping - see e.g.
+:ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
 
 .. autosummary::
     :toctree: generated/
diff --git a/gallery/transforms/plot_custom_transforms.py b/gallery/transforms/plot_custom_transforms.py
index 789de0ea6b9..898c2cd0bea 100644
--- a/gallery/transforms/plot_custom_transforms.py
+++ b/gallery/transforms/plot_custom_transforms.py
@@ -74,7 +74,7 @@ def forward(self, img, bboxes, label):  # we assume inputs are always structured
 print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
 # %%
 # .. note::
-#     While working with tv_tensor classes in your code, make sure to
+#     While working with TVTensor classes in your code, make sure to
 #     familiarize yourself with this section:
 #     :ref:`tv_tensor_unwrapping_behaviour`
 #
@@ -111,7 +111,7 @@ def forward(self, img, bboxes, label):  # we assume inputs are always structured
 # In brief, the core logic is to unpack the input into a flat list using `pytree
 # `_, and
 # then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all tv_tensors are
+# based on the **class** of the entries, as all TVTensors are
 # tensor-subclasses) plus some custom logic that is out of scope here - check the
 # code for details. The (potentially transformed) entries are then repacked and
 # returned, in the same structure as the input.
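The gallery file above only gets a naming pass in this patch, but the pattern it documents is easiest to grasp with a concrete transform in hand. A minimal sketch of such a structured custom transform - the `AddNoise` class, its three-part input convention, and all values are illustrative, not torchvision API:

```python
import torch
from torchvision import tv_tensors


class AddNoise(torch.nn.Module):
    # A v2-style transform is just an nn.Module receiving the structured input.
    def forward(self, img, bboxes, label):  # we assume inputs are always structured
        noisy = img + 0.05 * torch.randn_like(img)
        # Arithmetic on an Image returns a pure Tensor (see the unwrapping
        # section patched above), so wrap the result back to keep the type.
        return tv_tensors.wrap(noisy, like=img), bboxes, label


img = tv_tensors.Image(torch.rand(3, 8, 8))
bboxes = tv_tensors.BoundingBoxes([[0, 0, 4, 4]], format="XYXY", canvas_size=(8, 8))
out_img, out_bboxes, out_label = AddNoise()(img, bboxes, label=3)
assert isinstance(out_img, tv_tensors.Image)  # still a TVTensor
```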
diff --git a/gallery/transforms/plot_custom_tv_tensors.py b/gallery/transforms/plot_custom_tv_tensors.py
index 75c4e82547a..bf5ee198837 100644
--- a/gallery/transforms/plot_custom_tv_tensors.py
+++ b/gallery/transforms/plot_custom_tv_tensors.py
@@ -1,14 +1,14 @@
 """
-=====================================
+====================================
 How to write your own TVTensor class
-=====================================
+====================================
 
 .. note::
     Try on `collab `_
     or :ref:`go to the end ` to download the full example code.
 
 This guide is intended for advanced users and downstream library maintainers. We explain how to
-write your own tv_tensor class, and how to make it compatible with the built-in
+write your own TVTensor class, and how to make it compatible with the built-in
 Torchvision v2 transforms. Before continuing, make sure you have read
 :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
 """
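Since the guide renamed above is about subclassing, here is its gist as a sketch. It assumes `register_kernel` keeps the `functional`/`tv_tensor_cls` keywords it has around the time of this PR; the class, kernel name, and print are illustrative:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F


class MyTVTensor(tv_tensors.TVTensor):
    # Inheriting is enough: v2 transforms dispatch on the class of each input.
    pass


my_dp = torch.rand(4).as_subclass(MyTVTensor)

# Teach an existing functional how to handle the new class.
@F.register_kernel(functional=F.horizontal_flip, tv_tensor_cls=MyTVTensor)
def horizontal_flip_my_tv_tensor(inpt, *args, **kwargs):
    print("Flipping a MyTVTensor")
    return tv_tensors.wrap(inpt.flip(-1), like=inpt)


out = v2.RandomHorizontalFlip(p=1)(my_dp)  # routed to our kernel
assert isinstance(out, MyTVTensor)
```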
diff --git a/gallery/transforms/plot_transforms_getting_started.py b/gallery/transforms/plot_transforms_getting_started.py
index 62ee13643ad..cbaab3dc97d 100644
--- a/gallery/transforms/plot_transforms_getting_started.py
+++ b/gallery/transforms/plot_transforms_getting_started.py
@@ -115,7 +115,7 @@
 # segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
 # passed them to the transforms in exactly the same way.
 #
-# By now you likely have a few questions: what are these tv_tensors, how do we
+# By now you likely have a few questions: what are these TVTensors, how do we
 # use them, and what is the expected input/output of those transforms? We'll
 # answer these in the next sections.
 
@@ -126,7 +126,7 @@
 # What are TVTensors?
 # --------------------
 #
-# TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
+# TVTensors are :class:`torch.Tensor` subclasses. The available TVTensors are
 # :class:`~torchvision.tv_tensors.Image`,
 # :class:`~torchvision.tv_tensors.BoundingBoxes`,
 # :class:`~torchvision.tv_tensors.Mask`, and
@@ -134,7 +134,7 @@
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
-# or any ``torch.*`` operator will also work on a tv_tensor:
+# or any ``torch.*`` operator will also work on a TVTensor:
 
 img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
 
@@ -146,7 +146,7 @@
 # transform a given input, the transforms first look at the **class** of the
 # object, and dispatch to the appropriate implementation accordingly.
 #
-# You don't need to know much more about tv_tensors at this point, but advanced
+# You don't need to know much more about TVTensors at this point, but advanced
 # users who want to learn more can refer to
 # :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
 #
@@ -234,9 +234,9 @@
 # Torchvision also supports datasets for object detection or segmentation like
 # :class:`torchvision.datasets.CocoDetection`. Those datasets predate
 # the existence of the :mod:`torchvision.transforms.v2` module and of the
-# tv_tensors, so they don't return tv_tensors out of the box.
+# TVTensors, so they don't return TVTensors out of the box.
 #
-# An easy way to force those datasets to return tv_tensors and to make them
+# An easy way to force those datasets to return TVTensors and to make them
 # compatible with v2 transforms is to use the
 # :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
 #
@@ -246,7 +246,7 @@
 #
 #     dataset = CocoDetection(..., transforms=my_transforms)
 #     dataset = wrap_dataset_for_transforms_v2(dataset)
-#     # Now the dataset returns tv_tensors!
+#     # Now the dataset returns TVTensors!
 #
 # Using your own datasets
 # ^^^^^^^^^^^^^^^^^^^^^^^
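As a concrete picture of what "the dataset returns TVTensors" buys you downstream - a sketch with a made-up detection-style sample, where only the classes and transforms come from torchvision:

```python
import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

# An arbitrary structure: v2 transforms only touch entries they recognize by class.
sample = {
    "image": tv_tensors.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)),
    "boxes": tv_tensors.BoundingBoxes([[2, 2, 20, 20]], format="XYXY", canvas_size=(32, 32)),
    "label": 5,  # plain Python objects pass through untouched
}

out = v2.Resize(size=(64, 64), antialias=True)(sample)
print(out["image"].shape)  # torch.Size([3, 64, 64])
print(out["boxes"])        # coordinates and canvas_size rescaled accordingly
print(out["label"])        # 5, unchanged
```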
diff --git a/gallery/transforms/plot_tv_tensors.py b/gallery/transforms/plot_tv_tensors.py
index c5813189e53..0cdbe9d0831 100644
--- a/gallery/transforms/plot_tv_tensors.py
+++ b/gallery/transforms/plot_tv_tensors.py
@@ -9,18 +9,18 @@
 
 TVTensors are Tensor subclasses introduced together with
-``torchvision.transforms.v2``. This example showcases what these tv_tensors are
+``torchvision.transforms.v2``. This example showcases what these TVTensors are
 and how they behave.
 
 .. warning::
 
-    **Intended Audience** Unless you're writing your own transforms or your own tv_tensors, you
+    **Intended Audience** Unless you're writing your own transforms or your own TVTensors, you
     probably do not need to read this guide. This is a fairly low-level topic
     that most users will not need to worry about: you do not need to understand
-    the internals of tv_tensors to efficiently rely on
+    the internals of TVTensors to efficiently rely on
     ``torchvision.transforms.v2``. It may however be useful for advanced users
     trying to implement their own datasets, transforms, or work directly with
-    the tv_tensors.
+    the TVTensors.
 """
 
 # %%
 import torch
 from torchvision import tv_tensors
@@ -31,8 +31,8 @@
 
 
 # %%
-# What are tv_tensors?
-# --------------------
+# What are TVTensors?
+# -------------------
 #
 # TVTensors are zero-copy tensor subclasses:
 
@@ -46,31 +46,31 @@
 # Under the hood, they are needed in :mod:`torchvision.transforms.v2` to correctly dispatch to the appropriate function
 # for the input data.
 #
-# :mod:`torchvision.tv_tensors` supports four types of tv_tensors:
+# :mod:`torchvision.tv_tensors` supports four types of TVTensors:
 #
 # * :class:`~torchvision.tv_tensors.Image`
 # * :class:`~torchvision.tv_tensors.Video`
 # * :class:`~torchvision.tv_tensors.BoundingBoxes`
 # * :class:`~torchvision.tv_tensors.Mask`
 #
-# What can I do with a tv_tensor?
-# -------------------------------
+# What can I do with a TVTensor?
+# ------------------------------
 #
 # TVTensors look and feel just like regular tensors - they **are** tensors.
 # Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()`` or
-# any ``torch.*`` operator will also work on tv_tensors. See
+# any ``torch.*`` operator will also work on TVTensors. See
 # :ref:`tv_tensor_unwrapping_behaviour` for a few gotchas.
 
 # %%
 # .. _tv_tensor_creation:
 #
-# How do I construct a tv_tensor?
-# -------------------------------
+# How do I construct a TVTensor?
+# ------------------------------
 #
 # Using the constructor
 # ^^^^^^^^^^^^^^^^^^^^^
 #
-# Each tv_tensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
+# Each TVTensor class takes any tensor-like data that can be turned into a :class:`~torch.Tensor`
 
 image = tv_tensors.Image([[[[0, 1], [1, 0]]]])
 print(image)
@@ -92,7 +92,7 @@
 print(image.shape, image.dtype)
 
 # %%
-# Some tv_tensors require additional metadata to be passed in ordered to be constructed. For example,
+# Some TVTensors require additional metadata to be passed in order to be constructed. For example,
 # :class:`~torchvision.tv_tensors.BoundingBoxes` requires the coordinate format as well as the size of the
 # corresponding image (``canvas_size``) alongside the actual values. These
 # metadata are required to properly transform the bounding boxes.
@@ -109,7 +109,7 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # You can also use the :func:`~torchvision.tv_tensors.wrap` function to wrap a tensor object
-# into a tv_tensor. This is useful when you already have an object of the
+# into a TVTensor. This is useful when you already have an object of the
 # desired type, which typically happens when writing transforms: you just want
 # to wrap the output like the input.
 
@@ -125,7 +125,7 @@
 # .. _tv_tensor_unwrapping_behaviour:
 #
 # I had a TVTensor but now I have a Tensor. Help!
-# ------------------------------------------------
+# -----------------------------------------------
 #
 # By default, operations on :class:`~torchvision.tv_tensors.TVTensor` objects
 # will return a pure Tensor:
@@ -151,7 +151,7 @@
 # But I want a TVTensor back!
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
-# You can re-wrap a pure tensor into a tv_tensor by just calling the tv_tensor
+# You can re-wrap a pure tensor into a TVTensor by just calling the TVTensor
 # constructor, or by using the :func:`~torchvision.tv_tensors.wrap` function
 # (see more details above in :ref:`tv_tensor_creation`):
 
@@ -164,7 +164,7 @@
 # as a global config setting for the whole program, or as a context manager
 # (read its docs to learn more about caveats):
 
-with tv_tensors.set_return_type("tv_tensor"):
+with tv_tensors.set_return_type("TVTensor"):
     new_bboxes = bboxes + 3
     assert isinstance(new_bboxes, tv_tensors.BoundingBoxes)
 
@@ -203,9 +203,9 @@
 # There are a few exceptions to this "unwrapping" rule:
 # :meth:`~torch.Tensor.clone`, :meth:`~torch.Tensor.to`,
 # :meth:`torch.Tensor.detach`, and :meth:`~torch.Tensor.requires_grad_` retain
-# the tv_tensor type.
+# the TVTensor type.
 #
-# Inplace operations on tv_tensors like ``obj.add_()`` will preserve the type of
+# Inplace operations on TVTensors like ``obj.add_()`` will preserve the type of
 # ``obj``. However, the **returned** value of inplace operations will be a pure
 # tensor:
 
@@ -213,7 +213,7 @@
 
 new_image = image.add_(1).mul_(2)
 
-# image got transformed in-place and is still an Image tv_tensor, but new_image
+# image got transformed in-place and is still a TVTensor Image, but new_image
 # is a Tensor. They share the same underlying data and they're equal, just
 # different classes.
 assert isinstance(image, tv_tensors.Image)
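The unwrap/re-wrap behaviour that this gallery walks through, condensed into one runnable snippet (behaviour exactly as the patched docs describe it):

```python
import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 4, 4))

# Default: torch operations unwrap to a pure Tensor...
assert type(img + 1) is torch.Tensor

# ...unless you re-wrap the result explicitly...
assert type(tv_tensors.wrap(img + 1, like=img)) is tv_tensors.Image

# ...or opt in to subclass-preserving returns for a limited scope.
with tv_tensors.set_return_type("TVTensor"):
    assert type(img + 1) is tv_tensors.Image
assert type(img + 1) is torch.Tensor  # back to the default outside the block
```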
diff --git a/test/test_tv_tensors.py b/test/test_tv_tensors.py
index 92747f7ebf0..ed75ae35ecd 100644
--- a/test/test_tv_tensors.py
+++ b/test/test_tv_tensors.py
@@ -91,7 +91,7 @@ def test_to_wrapping(make_input):
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_to_tv_tensor_reference(make_input, return_type):
     tensor = torch.rand((3, 16, 16), dtype=torch.float64)
     dp = make_input()
@@ -99,13 +99,13 @@ def test_to_tv_tensor_reference(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         tensor_to = tensor.to(dp)
 
-    assert type(tensor_to) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(tensor_to) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert tensor_to.dtype is dp.dtype
     assert type(tensor) is torch.Tensor
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_clone_wrapping(make_input, return_type):
     dp = make_input()
 
@@ -117,7 +117,7 @@ def test_clone_wrapping(make_input, return_type):
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_requires_grad__wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float)
 
@@ -132,7 +132,7 @@ def test_requires_grad__wrapping(make_input, return_type):
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_detach_wrapping(make_input, return_type):
     dp = make_input(dtype=torch.float).requires_grad_(True)
 
@@ -142,7 +142,7 @@ def test_detach_wrapping(make_input, return_type):
     assert type(dp_detached) is type(dp)
 
 
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_force_subclass_with_metadata(return_type):
     # Sanity checks for the ops in _FORCE_TORCHFUNCTION_SUBCLASS and tv_tensors with metadata
     # Largely the same as above, we additionally check that the metadata is preserved
@@ -151,27 +151,27 @@ def test_force_subclass_with_metadata(return_type):
     tv_tensors.set_return_type(return_type)
     bbox = bbox.clone()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
 
     bbox = bbox.to(torch.float64)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
 
     bbox = bbox.detach()
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
 
     assert not bbox.requires_grad
     bbox.requires_grad_(True)
-    if return_type == "tv_tensor":
+    if return_type == "TVTensor":
         assert bbox.format, bbox.canvas_size == (format, canvas_size)
         assert bbox.requires_grad
 
     tv_tensors.set_return_type("tensor")
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_other_op_no_wrapping(make_input, return_type):
     dp = make_input()
 
@@ -179,7 +179,7 @@ def test_other_op_no_wrapping(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
         output = dp * 2
 
-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
@@ -200,7 +200,7 @@ def test_no_tensor_output_op_no_wrapping(make_input, op):
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 def test_inplace_op_no_wrapping(make_input, return_type):
     dp = make_input()
     original_type = type(dp)
@@ -208,7 +208,7 @@ def test_inplace_op_no_wrapping(make_input, return_type):
     with tv_tensors.set_return_type(return_type):
         output = dp.add_(0)
 
-    assert type(output) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
+    assert type(output) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
     assert type(dp) is original_type
 
 
@@ -243,7 +243,7 @@ def test_deepcopy(make_input, requires_grad):
 
 
 @pytest.mark.parametrize("make_input", [make_image, make_bounding_boxes, make_segmentation_mask, make_video])
-@pytest.mark.parametrize("return_type", ["Tensor", "tv_tensor"])
+@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
 @pytest.mark.parametrize(
     "op",
     (
@@ -267,8 +267,8 @@ def test_usual_operations(make_input, return_type, op):
     dp = make_input()
     with tv_tensors.set_return_type(return_type):
         out = op(dp)
-    assert type(out) is (type(dp) if return_type == "tv_tensor" else torch.Tensor)
-    if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "tv_tensor":
+    assert type(out) is (type(dp) if return_type == "TVTensor" else torch.Tensor)
+    if isinstance(dp, tv_tensors.BoundingBoxes) and return_type == "TVTensor":
         assert hasattr(out, "format")
         assert hasattr(out, "canvas_size")
 
@@ -286,16 +286,16 @@ def test_set_return_type():
 
     assert type(img + 3) is torch.Tensor
 
-    with tv_tensors.set_return_type("tv_tensor"):
+    with tv_tensors.set_return_type("TVTensor"):
         assert type(img + 3) is tv_tensors.Image
     assert type(img + 3) is torch.Tensor
 
-    tv_tensors.set_return_type("tv_tensor")
+    tv_tensors.set_return_type("TVTensor")
     assert type(img + 3) is tv_tensors.Image
     with tv_tensors.set_return_type("tensor"):
         assert type(img + 3) is torch.Tensor
 
-    with tv_tensors.set_return_type("tv_tensor"):
+    with tv_tensors.set_return_type("TVTensor"):
         assert type(img + 3) is tv_tensors.Image
     tv_tensors.set_return_type("tensor")
     assert type(img + 3) is torch.Tensor
@@ -305,3 +305,16 @@ def test_set_return_type():
         assert type(img + 3) is tv_tensors.Image
 
     tv_tensors.set_return_type("tensor")
+
+
+def test_return_type_input():
+    img = make_image()
+
+    # Case-insensitive
+    with tv_tensors.set_return_type("tvtensor"):
+        assert type(img + 3) is tv_tensors.Image
+
+    with pytest.raises(ValueError, match="return_type must be"):
+        tv_tensors.set_return_type("typo")
+
+    tv_tensors.set_return_type("tensor")
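The new `test_return_type_input` above pins the argument handling down; in user-facing terms the patched `set_return_type` behaves like this (usage sketch mirroring that test):

```python
from torchvision import tv_tensors

# Case-insensitive: any casing of "tvtensor" or "tensor" is accepted.
for arg in ("TVTensor", "tvtensor", "TVTENSOR"):
    tv_tensors.set_return_type(arg)

# Anything else now raises a ValueError instead of a bare KeyError.
try:
    tv_tensors.set_return_type("typo")
except ValueError as e:
    print(e)  # return_type must be 'TVTensor' or 'Tensor', got typo

tv_tensors.set_return_type("Tensor")  # restore the default
```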
diff --git a/torchvision/tv_tensors/_torch_function_helpers.py b/torchvision/tv_tensors/_torch_function_helpers.py
index 106c4260505..7edc471b110 100644
--- a/torchvision/tv_tensors/_torch_function_helpers.py
+++ b/torchvision/tv_tensors/_torch_function_helpers.py
@@ -16,7 +16,7 @@ def __exit__(self, *args):
 
 
 def set_return_type(return_type: str):
-    """[BETA] Set the return type of torch operations on tv_tensors.
+    """[BETA] Set the return type of torch operations on :class:`~torchvision.tv_tensors.TVTensor`.
 
     This only affects the behaviour of torch operations. It has no effect on
     ``torchvision`` transforms or functionals, which will always return as
@@ -26,7 +26,7 @@ def set_return_type(return_type: str):
 
     We recommend using :class:`~torchvision.transforms.v2.ToPureTensor` at
     the end of your transform pipelines if you use
-    ``set_return_type("dataptoint")``. This will avoid the
+    ``set_return_type("TVTensor")``. This will avoid the
     ``__torch_function__`` overhead in the model's ``forward()``.
 
     Can be used as a global flag for the entire program:
@@ -36,7 +36,7 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor (default behaviour)
 
-        set_return_type("tv_tensors")
+        set_return_type("TVTensor")
         img + 2  # This is an Image
 
     or as a context manager to restrict the scope:
@@ -45,16 +45,21 @@ def set_return_type(return_type: str):
         img = tv_tensors.Image(torch.rand(3, 5, 5))
         img + 2  # This is a pure Tensor
 
-        with set_return_type("tv_tensors"):
+        with set_return_type("TVTensor"):
             img + 2  # This is an Image
         img + 2  # This is a pure Tensor
 
     Args:
-        return_type (str): Can be "tv_tensor" or "tensor". Default is "tensor".
+        return_type (str): Can be "TVTensor" or "Tensor" (case-insensitive).
+            Default is "Tensor" (i.e. pure :class:`torch.Tensor`).
     """
     global _TORCHFUNCTION_SUBCLASS
 
     to_restore = _TORCHFUNCTION_SUBCLASS
-    _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tv_tensor": True}[return_type.lower()]
+
+    try:
+        _TORCHFUNCTION_SUBCLASS = {"tensor": False, "tvtensor": True}[return_type.lower()]
+    except KeyError:
+        raise ValueError(f"return_type must be 'TVTensor' or 'Tensor', got {return_type}") from None
 
     return _ReturnTypeCM(to_restore)
diff --git a/torchvision/tv_tensors/_tv_tensor.py b/torchvision/tv_tensors/_tv_tensor.py
index abeab9ae017..0c6af95af87 100644
--- a/torchvision/tv_tensors/_tv_tensor.py
+++ b/torchvision/tv_tensors/_tv_tensor.py
@@ -13,7 +13,7 @@
 
 
 class TVTensor(torch.Tensor):
-    """[Beta] Base class for all tv_tensors.
+    """[Beta] Base class for all TVTensors.
 
     You probably don't want to use this class unless you're defining your own
     custom TVTensors. See