diff --git a/torchvision/utils.py b/torchvision/utils.py index 5e548496b05..b6880c83f2d 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -2,23 +2,24 @@ import pathlib import torch import math +import warnings import numpy as np from PIL import Image, ImageDraw from PIL import ImageFont __all__ = ["make_grid", "save_image", "draw_bounding_boxes"] -irange = range - +@torch.no_grad() def make_grid( tensor: Union[torch.Tensor, List[torch.Tensor]], nrow: int = 8, padding: int = 2, normalize: bool = False, - range: Optional[Tuple[int, int]] = None, + value_range: Optional[Tuple[int, int]] = None, scale_each: bool = False, pad_value: int = 0, + **kwargs ) -> torch.Tensor: """Make a grid of images. @@ -30,7 +31,7 @@ def make_grid( padding (int, optional): amount of padding. Default: ``2``. normalize (bool, optional): If True, shift the image to the range (0, 1), by the min and max values specified by :attr:`range`. Default: ``False``. - range (tuple, optional): tuple (min, max) where min and max are numbers, + value_range (tuple, optional): tuple (min, max) where min and max are numbers, then these numbers are used to normalize the image. By default, min and max are computed from the tensor. scale_each (bool, optional): If ``True``, scale each image in the batch of @@ -43,7 +44,12 @@ def make_grid( """ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): - raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor))) + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if "range" in kwargs.keys(): + warning = "range will be deprecated, please use value_range instead." + warnings.warn(warning) + value_range = kwargs["range"] # if list of tensors, convert to a 4D mini-batch Tensor if isinstance(tensor, list): @@ -61,25 +67,25 @@ def make_grid( if normalize is True: tensor = tensor.clone() # avoid modifying tensor in-place - if range is not None: - assert isinstance(range, tuple), \ - "range has to be a tuple (min, max) if specified. min and max are numbers" + if value_range is not None: + assert isinstance(value_range, tuple), \ + "value_range has to be a tuple (min, max) if specified. min and max are numbers" def norm_ip(img, low, high): img.clamp_(min=low, max=high) img.sub_(low).div_(max(high - low, 1e-5)) - def norm_range(t, range): - if range is not None: - norm_ip(t, range[0], range[1]) + def norm_range(t, value_range): + if value_range is not None: + norm_ip(t, value_range[0], value_range[1]) else: norm_ip(t, float(t.min()), float(t.max())) if scale_each is True: for t in tensor: # loop over mini-batch dimension - norm_range(t, range) + norm_range(t, value_range) else: - norm_range(tensor, range) + norm_range(tensor, value_range) if tensor.size(0) == 1: return tensor.squeeze(0) @@ -92,8 +98,8 @@ def norm_range(t, range): num_channels = tensor.size(1) grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) k = 0 - for y in irange(ymaps): - for x in irange(xmaps): + for y in range(ymaps): + for x in range(xmaps): if k >= nmaps: break # Tensor.copy_() is a valid method but seems to be missing from the stubs @@ -105,16 +111,13 @@ def norm_range(t, range): return grid +@torch.no_grad() def save_image( tensor: Union[torch.Tensor, List[torch.Tensor]], fp: Union[Text, pathlib.Path, BinaryIO], - nrow: int = 8, - padding: int = 2, normalize: bool = False, - range: Optional[Tuple[int, int]] = None, - scale_each: bool = False, - pad_value: int = 0, format: Optional[str] = None, + **kwargs ) -> None: """Save a given Tensor into an image file. @@ -126,8 +129,8 @@ def save_image( If a file object was used instead of a filename, this parameter should always be used. **kwargs: Other arguments are documented in ``make_grid``. """ - grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, - normalize=normalize, range=range, scale_each=scale_each) + + grid = make_grid(tensor, **kwargs) # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() im = Image.fromarray(ndarr)