Improved transforms, native image IO, new video API and more
This release brings new additions to torchvision that improve support for model deployment. Most notably, transforms in torchvision are now torchscript-compatible, and can thus be serialized together with your model for simpler deployment. Additionally, we provide native image IO with torchscript support, and a new video reading API (released as Beta) which is more flexible than torchvision.io.read_video.
Highlights
Transforms now support Tensor, batch computation, GPU and TorchScript
torchvision transforms now inherit from nn.Module and can be torchscripted and applied on torch Tensor inputs as well as on PIL images. They also support Tensors with a batch dimension and work seamlessly on CPU/GPU devices:
import torch
import torchvision.transforms as T
# to fix random seed, use torch.manual_seed
# instead of random.seed
torch.manual_seed(12)
transforms = torch.nn.Sequential(
T.RandomCrop(224),
T.RandomHorizontalFlip(p=0.3),
T.ConvertImageDtype(torch.float),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)
scripted_transforms = torch.jit.script(transforms)
# Note: we can similarly use T.Compose to define transforms
# transforms = T.Compose([...]) and
# scripted_transforms = torch.jit.script(torch.nn.Sequential(*transforms.transforms))
tensor_image = torch.randint(0, 256, size=(3, 256, 256), dtype=torch.uint8)
# works directly on Tensors
out_image1 = transforms(tensor_image)
# on the GPU
out_image1_cuda = transforms(tensor_image.cuda())
# with batches
batched_image = torch.randint(0, 256, size=(4, 3, 256, 256), dtype=torch.uint8)
out_image_batched = transforms(batched_image)
# and has torchscript support
out_image2 = scripted_transforms(tensor_image)
These improvements enable the following new features:
- support for GPU acceleration
- batched transformations, e.g. as needed for videos
- transforms for multi-band torch tensor images (with more than 3-4 channels)
- torchscripting transforms together with your model for deployment (see the sketch after the note below)
Note: Exceptions for TorchScript support include Compose, RandomChoice, RandomOrder, Lambda and those applied on PIL images, such as ToPILImage.
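For example, here is a minimal sketch of bundling scripted transforms with a pretrained model into a single TorchScript artifact for deployment; the Predictor wrapper class is hypothetical and not part of torchvision:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet18

class Predictor(nn.Module):
    # hypothetical wrapper combining preprocessing and model
    def __init__(self):
        super().__init__()
        self.transforms = nn.Sequential(
            T.CenterCrop(224),
            T.ConvertImageDtype(torch.float),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        )
        self.model = resnet18(pretrained=True).eval()

    def forward(self, x):
        # x is an NxCxHxW uint8 batch of images
        return self.model(self.transforms(x))

scripted_predictor = torch.jit.script(Predictor())
scripted_predictor.save("predictor.pt")  # preprocessing + weights in one file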
Native image IO for JPEG and PNG formats
torchvision 0.8.0 introduces native image reading and writing operations for JPEG and PNG formats. Those operators support TorchScript and return CxHxW tensors in uint8 format, and can thus now be part of your model for deployment in C++ environments.
from torchvision.io import read_image
# tensor_image is a CxHxW uint8 Tensor
tensor_image = read_image('path_to_image.jpeg')
# or equivalently
from torchvision.io.image import read_file, decode_image
# raw_data is a 1d uint8 Tensor with the raw bytes
raw_data = read_file('path_to_image.jpeg')
tensor_image = decode_image(raw_data)
# all operators are torchscriptable and can be
# serialized together with your model torchscript code
scripted_read_image = torch.jit.script(read_image)
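The write path works analogously; a short sketch, assuming write_png mirrors read_image, in line with the JPEG/PNG encoding ops listed in this release:
from torchvision.io.image import write_png
# tensor_image is a CxHxW uint8 Tensor
write_png(tensor_image, 'path_to_output.png')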
New detection model
This release adds a pretrained RetinaNet model with a ResNet50 backbone from "Focal Loss for Dense Object Detection", with the following accuracies on COCO val2017:
IoU metric: bbox
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.364
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.558
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.383
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.193
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.400
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.490
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.315
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.506
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.558
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.386
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.595
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.699
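A minimal usage sketch; the constructor name retinanet_resnet50_fpn follows the naming convention of the other detection models and should be treated as an assumption:
import torch
from torchvision.models.detection import retinanet_resnet50_fpn

model = retinanet_resnet50_fpn(pretrained=True)
model.eval()
# a list of 3xHxW float images in [0, 1]
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(images)
# each prediction is a dict with "boxes", "labels" and "scores"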
[BETA] New Video Reader API
This release introduces a new video reading abstraction, which gives more fine-grained control over how to iterate over videos. It supports video and audio streams, and implements an iterator interface so that it can be combined with the rest of the Python ecosystem, such as itertools.
from torchvision.io import VideoReader
# stream indicates if reading from audio or video
reader = VideoReader('path_to_video.mp4', stream='video')
# can change the stream after construction
# via reader.set_current_stream
# to read all frames in a video starting at 2 seconds
for frame in reader.seek(2):
    # frame is a dict with "data" and "pts" metadata
    print(frame["data"], frame["pts"])
# because reader is an iterator you can combine it with
# itertools
from itertools import takewhile, islice
# read 10 frames starting from 2 seconds
for frame in islice(reader.seek(2), 10):
    pass
# or to return all frames between 2 and 5 seconds
for frame in takewhile(lambda x: x["pts"] < 5, reader.seek(2)):
    pass
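The reader can also switch streams after construction via set_current_stream; a short sketch, assuming the file contains an audio track:
reader.set_current_stream("audio")
# frames now come from the audio stream
for frame in reader.seek(0):
    # frame["data"] holds the decoded audio samples
    print(frame["data"].shape, frame["pts"])
    break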
Note: In order to use the Video Reader API, you need to compile torchvision from source and make sure that ffmpeg is installed on your system.
Note: the VideoReader API is currently released as Beta and may change based on user feedback.
Backwards Incompatible Changes
- [Transforms] The random seed should now be set with torch.manual_seed instead of random.seed (#2292)
- [Transforms] The value argument of RandomErasing.get_params was previously value=0 and is now value=None, which is interpreted as Gaussian random noise (#2386)
- [Transforms] RandomPerspective and F.perspective changed the default interpolation from BICUBIC to BILINEAR (#2558, #2561) (see the adaptation sketch after this list)
- [Transforms] Fixed incoherence in the affine transformation when the center is defined as half the image size + 0.5 (#2468)
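A hedged sketch of adapting existing code to these changes (interpolation is specified via the PIL constants in this release):
import torch
import torchvision.transforms as T
from PIL import Image

# seed transform randomness with torch.manual_seed instead of random.seed
torch.manual_seed(42)

# pass interpolation explicitly to keep the previous BICUBIC behavior
transform = T.RandomPerspective(interpolation=Image.BICUBIC)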
New Features
- [Ops] Added focal loss (#2784)
- [Ops] Added bounding boxes conversion function (#2710, #2737)
- [Ops] Added Generalized IoU (#2642) (see the ops sketch after this list)
- [Models] Added RetinaNet object detection model (#2784)
- [Datasets] Added Places365 dataset (#2610, #2625)
- [Transforms] Added GaussianBlur transform (#2658)
- [Transforms] Added torchscript, batch and GPU and tensor support for transforms (#2769, #2767, #2749, #2755, #2485, #2721, #2645, #2694, #2584, #2661, #2566, #2345, #2342, #2356, #2368, #2373, #2496, #2553, #2495, #2561, #2518, #2478, #2459, #2444, #2396, #2401, #2394, #2586, #2371, #2477, #2456, #2628, #2569, #2639, #2620, #2595, #2456, #2403, #2729)
- [Transforms] Added example notebook for tensor transforms (#2730)
- [IO] Added JPEG/PNG encoding/decoding ops
- [IO] Added file reading / writing ops (#2728, #2765, #2768)
- [IO] [BETA] Added new VideoReader API (#2683, #2781, #2778, #2802, #2596, #2612, #2734, #2770)
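A short sketch exercising the new ops named above; the exact signatures (box_convert, generalized_box_iou, sigmoid_focal_loss) are assumptions based on the PR titles:
import torch
from torchvision import ops

# convert boxes from (x, y, w, h) to (x1, y1, x2, y2)
boxes_xywh = torch.tensor([[10.0, 10.0, 20.0, 20.0]])
boxes_xyxy = ops.box_convert(boxes_xywh, in_fmt='xywh', out_fmt='xyxy')

# pairwise generalized IoU between two sets of boxes
giou = ops.generalized_box_iou(boxes_xyxy, boxes_xyxy)

# focal loss on raw logits with binary targets
inputs = torch.randn(8, requires_grad=True)
targets = torch.randint(0, 2, (8,)).float()
loss = ops.sigmoid_focal_loss(inputs, targets, reduction='mean')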
Improvements
Datasets
- Added error message if Google Drive download quota is exceeded (#2321)
- Optimized LSUN initialization time by only pulling keys from db (#2544)
- Use more precise return type for gzip.open() (#2792)
- Added UCF101 dataset tests (#2548)
- Added download tests on a schedule (#2665, #2675, #2699, #2706, #2747, #2731)
- Added typehints for datasets (#2487, #2521, #2522, #2523, #2524, #2526, #2528, #2529, #2525, #2527, #2530, #2533, #2534, #2535, #2536, #2532, #2538, #2537, #2539, #2531, #2540, #2667)
Models
- Removed hard coded value in DeepLabV3 (#2793)
- Changed the anchor generator default argument to an equivalent one (#2722)
- Moved model construction in resnet_fpn_backbone to after the docstring (#2482)
- Partially enabled type hints for models (#2668)
Ops
- Moved RoIs shape check to C++ (#2794)
- Use autocast built-in cast-helper functions (#2646)
- Added type annotations for torchvision.ops (#2331, #2462)
References
- [References] Removed redundant target send to device in detection evaluation (#2503)
- [References] Removed obsolete import in segmentation (#2399)
Misc
- [Transforms] Added support for negative padding in pad (#2744) (see the sketch after this list)
- [IO] Added type hints for torchvision.io (#2543)
- [ONNX] Export ROIAlign with aligned=True (#2613)
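A minimal sketch of the negative padding behavior; negative values remove pixels from the corresponding sides instead of adding them:
import torch
import torchvision.transforms.functional as F

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
out = F.pad(img, padding=[-8, -8])  # left/right and top/bottom
print(out.shape)  # torch.Size([3, 48, 48])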
Internal
- [Binaries] Added CUDA 11 binary builds (#2671)
- [Binaries] Added DEBUG=1 option to build torchvision (#2603)
- [Binaries] Unpin ninja version (#2358)
- Warn if torchvision imported from repo root (#2759)
- Added compatibility checks for C++ extensions (#2467)
- Added probot (#2448)
- Added ipynb to git attributes file (#2772)
- CI improvements (#2328, #2346, #2374, #2437, #2465, #2579, #2577, #2633, #2640, #2727, #2754, #2674, #2678)
- CMakeList improvements (#2739, #2684, #2626, #2585, #2587)
- Documentation improvements (#2659, #2615, #2614, #2542, #2685, #2507, #2760, #2550, #2656, #2723, #2601, #2654, #2757, #2592, #2606)
Bug Fixes
- [Ops] Fixed crash in deformable convolutions (#2604)
- [Ops] Added empty batch support for DeformConv2d (#2782)
- [Transforms] Enforced contiguous output in to_tensor (#2483)
- [Transforms] Fixed fill parameter for PIL pad (#2515)
- [Models] Fixed deprecation warning in nonzero for R-CNN models (#2705)
- [IO] Explicitly cast to size_t in video decoder (#2389)
- [ONNX] Fixed dynamic resize in Mask R-CNN (#2488)
- [C++ API] Fixed function signatures for torch::nn::Functional (#2463)
Deprecations
- [Transforms] Deprecated the dedicated functional_tensor implementations F_t.center_crop, F_t.five_crop and F_t.ten_crop, as they can be implemented as a function of crop (#2568)
- [Transforms] Deprecated explicit usage of F_pil and F_t functions; users should instead use the general functional API (#2664) (see the sketch below)
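A minimal sketch of the general functional API, which dispatches on the input type:
import torch
from PIL import Image
import torchvision.transforms.functional as F

# works on a tensor image ...
tensor_img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
out_tensor = F.center_crop(tensor_img, [32, 32])

# ... and on a PIL image with the same call
pil_img = Image.new('RGB', (64, 64))
out_pil = F.center_crop(pil_img, [32, 32])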