Datapoint -> TVTensor; datapoint[s] -> tv_tensor[s] (#7894)
NicolasHug authored Aug 30, 2023
1 parent b9447fd commit d5f4cc3
Showing 85 changed files with 1,121 additions and 1,121 deletions.
4 changes: 2 additions & 2 deletions docs/source/conf.py
@@ -88,8 +88,8 @@ def __init__(self, src_dir):
"plot_transforms_e2e.py",
"plot_cutmix_mixup.py",
"plot_custom_transforms.py",
"plot_datapoints.py",
"plot_custom_datapoints.py",
"plot_tv_tensors.py",
"plot_custom_tv_tensors.py",
]

def __call__(self, filename):
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -32,7 +32,7 @@ architectures, and common image transformations for computer vision.
:caption: Package Reference

transforms
-datapoints
+tv_tensors
models
datasets
utils
8 changes: 4 additions & 4 deletions docs/source/transforms.rst
@@ -30,12 +30,12 @@ tasks (image classification, detection, segmentation, video classification).
.. code:: python
# Detection (re-using imports and transforms from above)
-from torchvision import datapoints
+from torchvision import tv_tensors
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
bboxes = torch.randint(0, H // 2, size=(3, 4))
bboxes[:, 2:] += bboxes[:, :2]
-bboxes = datapoints.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
+bboxes = tv_tensors.BoundingBoxes(bboxes, format="XYXY", canvas_size=(H, W))
# The same transforms can be used!
img, bboxes = transforms(img, bboxes)
@@ -183,8 +183,8 @@ Transforms are available as classes like
This is very much like the :mod:`torch.nn` package which defines both classes
and functional equivalents in :mod:`torch.nn.functional`.

-The functionals support PIL images, pure tensors, or :ref:`datapoints
-<datapoints>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
+The functionals support PIL images, pure tensors, or :ref:`tv_tensors
+<tv_tensors>`, e.g. both ``resize(image_tensor)`` and ``resize(bboxes)`` are
valid.
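As a quick, untested sketch of that claim, using the post-rename ``tv_tensors`` names introduced by this commit:

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

H, W = 256, 256
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[10, 10, 50, 50]]), format="XYXY", canvas_size=(H, W)
)

# Both calls are valid: resizing BoundingBoxes rescales the coordinates
# consistently with the resized image.
resized_img = F.resize(img, size=[128, 128])
resized_boxes = F.resize(boxes, size=[128, 128])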

.. note::
12 changes: 6 additions & 6 deletions docs/source/datapoints.rst → docs/source/tv_tensors.rst
@@ -1,13 +1,13 @@
-.. _datapoints:
+.. _tv_tensors:

-Datapoints
+TVTensors
==========

-.. currentmodule:: torchvision.datapoints
+.. currentmodule:: torchvision.tv_tensors

-Datapoints are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
+TVTensors are tensor subclasses which the :mod:`~torchvision.transforms.v2` v2 transforms use under the hood to
dispatch their inputs to the appropriate lower-level kernels. Most users do not
-need to manipulate datapoints directly and can simply rely on dataset wrapping -
+need to manipulate tv_tensors directly and can simply rely on dataset wrapping -
see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.

.. autosummary::
@@ -19,6 +19,6 @@ see e.g. :ref:`sphx_glr_auto_examples_transforms_plot_transforms_e2e.py`.
BoundingBoxFormat
BoundingBoxes
Mask
-Datapoint
+TVTensor
set_return_type
wrap
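A brief sketch of the last two helpers in this list, ``wrap`` and ``set_return_type`` (behaviour as documented for the renamed module, not taken from this diff):

import torch
from torchvision import tv_tensors

img = tv_tensors.Image(torch.rand(3, 32, 32))

# Most tensor operations on a TVTensor return a plain torch.Tensor;
# wrap() re-wraps such a result into the same TVTensor subclass as `like`.
doubled = img * 2
print(type(doubled))                              # torch.Tensor
print(type(tv_tensors.wrap(doubled, like=img)))   # tv_tensors.Image

# set_return_type("TVTensor") asks operations to preserve the subclass instead.
with tv_tensors.set_return_type("TVTensor"):
    print(type(img * 2))                          # tv_tensors.Image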
4 changes: 2 additions & 2 deletions gallery/transforms/helpers.py
@@ -1,7 +1,7 @@
import matplotlib.pyplot as plt
import torch
from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F


@@ -22,7 +22,7 @@ def plot(imgs, row_title=None, **imshow_kwargs):
if isinstance(target, dict):
boxes = target.get("boxes")
masks = target.get("masks")
-elif isinstance(target, datapoints.BoundingBoxes):
+elif isinstance(target, tv_tensors.BoundingBoxes):
boxes = target
else:
raise ValueError(f"Unexpected target type: {type(target)}")
10 changes: 5 additions & 5 deletions gallery/transforms/plot_custom_transforms.py
@@ -13,7 +13,7 @@

# %%
import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms import v2


@@ -62,7 +62,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured

H, W = 256, 256
img = torch.rand(3, H, W)
-bboxes = datapoints.BoundingBoxes(
+bboxes = tv_tensors.BoundingBoxes(
torch.tensor([[0, 10, 10, 20], [50, 50, 70, 70]]),
format="XYXY",
canvas_size=(H, W)
@@ -74,9 +74,9 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
print(f"Output image shape: {out_img.shape}\nout_bboxes = {out_bboxes}\n{out_label = }")
# %%
# .. note::
-# While working with datapoint classes in your code, make sure to
+# While working with tv_tensor classes in your code, make sure to
# familiarize yourself with this section:
-# :ref:`datapoint_unwrapping_behaviour`
+# :ref:`tv_tensor_unwrapping_behaviour`
#
# Supporting arbitrary input structures
# =====================================
@@ -111,7 +111,7 @@ def forward(self, img, bboxes, label): # we assume inputs are always structured
# In brief, the core logic is to unpack the input into a flat list using `pytree
# <https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py>`_, and
# then transform only the entries that can be transformed (the decision is made
-# based on the **class** of the entries, as all datapoints are
+# based on the **class** of the entries, as all tv_tensors are
# tensor-subclasses) plus some custom logic that is out of scope here - check the
# code for details. The (potentially transformed) entries are then repacked and
# returned, in the same structure as the input.
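A rough, self-contained sketch of that flatten/transform/repack idea (``torch.utils._pytree`` is a private utility, used here purely for illustration, and the transform applied is arbitrary):

import torch
from torch.utils import _pytree as pytree  # private API, for illustration only
from torchvision import tv_tensors

sample = {
    "image": torch.rand(3, 8, 8),
    "boxes": tv_tensors.BoundingBoxes(torch.zeros(1, 4), format="XYXY", canvas_size=(8, 8)),
    "label": 3,
}

# Unpack the arbitrary input structure into a flat list plus a spec ...
flat, spec = pytree.tree_flatten(sample)
# ... transform only the entries whose *class* marks them as transformable
# (tensors here), pass everything else (e.g. the integer label) through ...
flat = [x.flip(-1) if isinstance(x, torch.Tensor) else x for x in flat]
# ... and repack into the same structure as the input.
out = pytree.tree_unflatten(flat, spec)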
gallery/transforms/plot_custom_datapoints.py → gallery/transforms/plot_custom_tv_tensors.py
@@ -1,62 +1,62 @@
"""
=====================================
-How to write your own Datapoint class
+How to write your own TVTensor class
=====================================
.. note::
-Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_datapoints.ipynb>`_
-or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_datapoints.py>` to download the full example code.
+Try on `collab <https://colab.research.google.com/github/pytorch/vision/blob/gh-pages/main/_generated_ipynb_notebooks/plot_custom_tv_tensors.ipynb>`_
+or :ref:`go to the end <sphx_glr_download_auto_examples_transforms_plot_custom_tv_tensors.py>` to download the full example code.
This guide is intended for advanced users and downstream library maintainers. We explain how to
-write your own datapoint class, and how to make it compatible with the built-in
+write your own tv_tensor class, and how to make it compatible with the built-in
Torchvision v2 transforms. Before continuing, make sure you have read
-:ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`.
+:ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
"""

# %%
import torch
-from torchvision import datapoints
+from torchvision import tv_tensors
from torchvision.transforms import v2

# %%
# We will create a very simple class that just inherits from the base
-# :class:`~torchvision.datapoints.Datapoint` class. It will be enough to cover
+# :class:`~torchvision.tv_tensors.TVTensor` class. It will be enough to cover
# what you need to know to implement your more elaborate use-cases. If you need
# to create a class that carries meta-data, take a look at how the
-# :class:`~torchvision.datapoints.BoundingBoxes` class is `implemented
-# <https://github.com/pytorch/vision/blob/main/torchvision/datapoints/_bounding_box.py>`_.
+# :class:`~torchvision.tv_tensors.BoundingBoxes` class is `implemented
+# <https://github.com/pytorch/vision/blob/main/torchvision/tv_tensors/_bounding_box.py>`_.


-class MyDatapoint(datapoints.Datapoint):
+class MyTVTensor(tv_tensors.TVTensor):
pass


-my_dp = MyDatapoint([1, 2, 3])
+my_dp = MyTVTensor([1, 2, 3])
my_dp

# %%
-# Now that we have defined our custom Datapoint class, we want it to be
+# Now that we have defined our custom TVTensor class, we want it to be
# compatible with the built-in torchvision transforms, and the functional API.
# For that, we need to implement a kernel which performs the core of the
# transformation, and then "hook" it to the functional that we want to support
# via :func:`~torchvision.transforms.v2.functional.register_kernel`.
#
# We illustrate this process below: we create a kernel for the "horizontal flip"
-# operation of our MyDatapoint class, and register it to the functional API.
+# operation of our MyTVTensor class, and register it to the functional API.

from torchvision.transforms.v2 import functional as F


-@F.register_kernel(functional="hflip", datapoint_cls=MyDatapoint)
-def hflip_my_datapoint(my_dp, *args, **kwargs):
+@F.register_kernel(functional="hflip", tv_tensor_cls=MyTVTensor)
+def hflip_my_tv_tensor(my_dp, *args, **kwargs):
print("Flipping!")
out = my_dp.flip(-1)
-return datapoints.wrap(out, like=my_dp)
+return tv_tensors.wrap(out, like=my_dp)


# %%
-# To understand why :func:`~torchvision.datapoints.wrap` is used, see
-# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
+# To understand why :func:`~torchvision.tv_tensors.wrap` is used, see
+# :ref:`tv_tensor_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
# .. note::
@@ -67,9 +67,9 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# ``@register_kernel(functional=F.hflip, ...)``.
#
# Now that we have registered our kernel, we can call the functional API on a
-# ``MyDatapoint`` instance:
+# ``MyTVTensor`` instance:

-my_dp = MyDatapoint(torch.rand(3, 256, 256))
+my_dp = MyTVTensor(torch.rand(3, 256, 256))
_ = F.hflip(my_dp)

# %%
@@ -102,10 +102,10 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
# to its :func:`~torchvision.transforms.v2.functional.hflip` functional. If you
# already defined and registered your own kernel as

-def hflip_my_datapoint(my_dp): # noqa
+def hflip_my_tv_tensor(my_dp): # noqa
print("Flipping!")
out = my_dp.flip(-1)
-return datapoints.wrap(out, like=my_dp)
+return tv_tensors.wrap(out, like=my_dp)


# %%
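The reason ``tv_tensors.wrap`` shows up in both kernels above: operations on a TVTensor generally return a plain ``torch.Tensor``, so a kernel has to re-wrap its output to preserve the subclass. A small sketch of that behaviour:

import torch
from torchvision import tv_tensors


class MyTVTensor(tv_tensors.TVTensor):
    pass


my_dp = MyTVTensor(torch.rand(3, 4, 4))
flipped = my_dp.flip(-1)

print(type(flipped))                               # plain torch.Tensor
print(type(tv_tensors.wrap(flipped, like=my_dp)))  # MyTVTensor again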
6 changes: 3 additions & 3 deletions gallery/transforms/plot_transforms_e2e.py
@@ -23,7 +23,7 @@
import torch
import torch.utils.data

-from torchvision import models, datasets, datapoints
+from torchvision import models, datasets, tv_tensors
from torchvision.transforms import v2

torch.manual_seed(0)
@@ -72,7 +72,7 @@
# %%
# We used the ``target_keys`` parameter to specify the kind of output we're
# interested in. Our dataset now returns a target which is a dict where the values
-# are :ref:`Datapoints <what_are_datapoints>` (all are :class:`torch.Tensor`
+# are :ref:`TVTensors <what_are_tv_tensors>` (all are :class:`torch.Tensor`
# subclasses). We dropped all unnecessary keys from the previous output, but
# if you need any of the original keys e.g. "image_id", you can still ask for
# it.
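For context, a hedged sketch of the dataset wrapping this example relies on (the paths are placeholders):

from torchvision import datasets
from torchvision.datasets import wrap_dataset_for_transforms_v2

dataset = datasets.CocoDetection(root="path/to/images", annFile="path/to/annotations.json")
dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=("boxes", "labels"))

img, target = dataset[0]
# target["boxes"] is a tv_tensors.BoundingBoxes; target["labels"] is a plain tensor.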
@@ -103,7 +103,7 @@
[
v2.ToImage(),
v2.RandomPhotometricDistort(p=1),
-v2.RandomZoomOut(fill={datapoints.Image: (123, 117, 104), "others": 0}),
+v2.RandomZoomOut(fill={tv_tensors.Image: (123, 117, 104), "others": 0}),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=1),
v2.SanitizeBoundingBoxes(),
56 changes: 28 additions & 28 deletions gallery/transforms/plot_transforms_getting_started.py
@@ -88,9 +88,9 @@
#
# Let's briefly look at a detection example with bounding boxes.

-from torchvision import datapoints # we'll describe this a bit later, bear with us
+from torchvision import tv_tensors # we'll describe this a bit later, bear with us

-boxes = datapoints.BoundingBoxes(
+boxes = tv_tensors.BoundingBoxes(
[
[15, 10, 370, 510],
[275, 340, 510, 510],
@@ -111,44 +111,44 @@
# %%
#
# The example above focuses on object detection. But if we had masks
-# (:class:`torchvision.datapoints.Mask`) for object segmentation or semantic
-# segmentation, or videos (:class:`torchvision.datapoints.Video`), we could have
+# (:class:`torchvision.tv_tensors.Mask`) for object segmentation or semantic
+# segmentation, or videos (:class:`torchvision.tv_tensors.Video`), we could have
# passed them to the transforms in exactly the same way.
#
-# By now you likely have a few questions: what are these datapoints, how do we
+# By now you likely have a few questions: what are these tv_tensors, how do we
# use them, and what is the expected input/output of those transforms? We'll
# answer these in the next sections.

# %%
#
-# .. _what_are_datapoints:
+# .. _what_are_tv_tensors:
#
-# What are Datapoints?
+# What are TVTensors?
# --------------------
#
-# Datapoints are :class:`torch.Tensor` subclasses. The available datapoints are
-# :class:`~torchvision.datapoints.Image`,
-# :class:`~torchvision.datapoints.BoundingBoxes`,
-# :class:`~torchvision.datapoints.Mask`, and
-# :class:`~torchvision.datapoints.Video`.
+# TVTensors are :class:`torch.Tensor` subclasses. The available tv_tensors are
+# :class:`~torchvision.tv_tensors.Image`,
+# :class:`~torchvision.tv_tensors.BoundingBoxes`,
+# :class:`~torchvision.tv_tensors.Mask`, and
+# :class:`~torchvision.tv_tensors.Video`.
#
-# Datapoints look and feel just like regular tensors - they **are** tensors.
+# TVTensors look and feel just like regular tensors - they **are** tensors.
# Everything that is supported on a plain :class:`torch.Tensor` like ``.sum()``
-# or any ``torch.*`` operator will also work on a datapoint:
+# or any ``torch.*`` operator will also work on a tv_tensor:

-img_dp = datapoints.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))
+img_dp = tv_tensors.Image(torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8))

print(f"{isinstance(img_dp, torch.Tensor) = }")
print(f"{img_dp.dtype = }, {img_dp.shape = }, {img_dp.sum() = }")

# %%
-# These Datapoint classes are at the core of the transforms: in order to
+# These TVTensor classes are at the core of the transforms: in order to
# transform a given input, the transforms first look at the **class** of the
# object, and dispatch to the appropriate implementation accordingly.
#
-# You don't need to know much more about datapoints at this point, but advanced
+# You don't need to know much more about tv_tensors at this point, but advanced
# users who want to learn more can refer to
-# :ref:`sphx_glr_auto_examples_transforms_plot_datapoints.py`.
+# :ref:`sphx_glr_auto_examples_transforms_plot_tv_tensors.py`.
#
# What do I pass as input?
# ------------------------
@@ -196,17 +196,17 @@
# Pure :class:`torch.Tensor` objects are, in general, treated as images (or
# as videos for video-specific transforms). Indeed, you may have noticed
# that in the code above we haven't used the
-# :class:`~torchvision.datapoints.Image` class at all, and yet our images
+# :class:`~torchvision.tv_tensors.Image` class at all, and yet our images
# got transformed properly. Transforms follow the following logic to
# determine whether a pure Tensor should be treated as an image (or video),
# or just ignored:
#
-# * If there is an :class:`~torchvision.datapoints.Image`,
-# :class:`~torchvision.datapoints.Video`,
+# * If there is an :class:`~torchvision.tv_tensors.Image`,
+# :class:`~torchvision.tv_tensors.Video`,
# or :class:`PIL.Image.Image` instance in the input, all other pure
# tensors are passed-through.
-# * If there is no :class:`~torchvision.datapoints.Image` or
-# :class:`~torchvision.datapoints.Video` instance, only the first pure
+# * If there is no :class:`~torchvision.tv_tensors.Image` or
+# :class:`~torchvision.tv_tensors.Video` instance, only the first pure
# :class:`torch.Tensor` will be transformed as image or video, while all
# others will be passed-through. Here "first" means "first in a depth-wise
# traversal".
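A small sketch of that heuristic (post-rename names; untested):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

transform = v2.Resize(size=(32, 32))

img = tv_tensors.Image(torch.rand(3, 64, 64))
extra = torch.rand(10)  # a pure tensor that is not meant to be an image

out_img, out_extra = transform(img, extra)
# Because an Image instance is present, `extra` is passed through unchanged,
# while `img` is resized to (3, 32, 32).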
@@ -234,9 +234,9 @@
# Torchvision also supports datasets for object detection or segmentation like
# :class:`torchvision.datasets.CocoDetection`. Those datasets predate
# the existence of the :mod:`torchvision.transforms.v2` module and of the
-# datapoints, so they don't return datapoints out of the box.
+# tv_tensors, so they don't return tv_tensors out of the box.
#
-# An easy way to force those datasets to return datapoints and to make them
+# An easy way to force those datasets to return tv_tensors and to make them
# compatible with v2 transforms is to use the
# :func:`torchvision.datasets.wrap_dataset_for_transforms_v2` function:
#
@@ -246,14 +246,14 @@
#
# dataset = CocoDetection(..., transforms=my_transforms)
# dataset = wrap_dataset_for_transforms_v2(dataset)
-# # Now the dataset returns datapoints!
+# # Now the dataset returns tv_tensors!
#
# Using your own datasets
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# If you have a custom dataset, then you'll need to convert your objects into
-# the appropriate Datapoint classes. Creating Datapoint instances is very easy,
-# refer to :ref:`datapoint_creation` for more details.
+# the appropriate TVTensor classes. Creating TVTensor instances is very easy,
+# refer to :ref:`tv_tensor_creation` for more details.
#
# There are two main places where you can implement that conversion logic:
#
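One of those places is the dataset's ``__getitem__`` itself; a minimal, hypothetical sketch (class and attribute names are made up):

import torch
import torch.utils.data
from torchvision import tv_tensors


class MyDetectionDataset(torch.utils.data.Dataset):
    # Hypothetical dataset: `samples` holds (image_tensor, raw_xyxy_boxes) pairs.
    def __init__(self, samples, transforms=None):
        self.samples = samples
        self.transforms = transforms

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img, raw_boxes = self.samples[idx]
        boxes = tv_tensors.BoundingBoxes(
            raw_boxes, format="XYXY", canvas_size=img.shape[-2:]
        )
        if self.transforms is not None:
            img, boxes = self.transforms(img, boxes)
        return img, boxes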