diff --git a/torchvision/utils.py b/torchvision/utils.py index 5e548496b05..5aea1c98eb7 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -3,12 +3,12 @@ import torch import math import numpy as np +import warnings from PIL import Image, ImageDraw from PIL import ImageFont -__all__ = ["make_grid", "save_image", "draw_bounding_boxes"] -irange = range +__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "get_range"] def make_grid( @@ -16,9 +16,12 @@ def make_grid( nrow: int = 8, padding: int = 2, normalize: bool = False, - range: Optional[Tuple[int, int]] = None, + # range: Optional[Tuple[int, int]] = None, + vmin: int = None, + vmax: int = None, scale_each: bool = False, pad_value: int = 0, + **kwargs ) -> torch.Tensor: """Make a grid of images. @@ -30,9 +33,11 @@ def make_grid( padding (int, optional): amount of padding. Default: ``2``. normalize (bool, optional): If True, shift the image to the range (0, 1), by the min and max values specified by :attr:`range`. Default: ``False``. - range (tuple, optional): tuple (min, max) where min and max are numbers, - then these numbers are used to normalize the image. By default, min and max - are computed from the tensor. + range (tuple, optional): Deprecated (see :attr:`vmin` and :attr:`vmax`). tuple (min, max) + where min and max are numbers, then these numbers are used to normalize the image. + By default, min and max are computed from the tensor. + vmin (int, optional): Lower bound of the data range used to normalize the image; only applied when :attr:`vmax` is also given. + vmax (int, optional): Upper bound of the data range used to normalize the image; only applied when :attr:`vmin` is also given. scale_each (bool, optional): If ``True``, scale each image in the batch of images separately rather than the (min, max) over all images. Default: ``False``. pad_value (float, optional): Value for the padded pixels. Default: ``0``.
@@ -61,25 +66,27 @@ def make_grid( if normalize is True: tensor = tensor.clone() # avoid modifying tensor in-place - if range is not None: - assert isinstance(range, tuple), \ - "range has to be a tuple (min, max) if specified. min and max are numbers" - + if "range" in kwargs.keys(): + if kwargs["range"] is not None: + assert isinstance(kwargs["range"], tuple), \ + "range has to be a tuple (min, max) if specified. min and max are numbers" + vmin, vmax = get_range(kwargs["range"]) def norm_ip(img, low, high): img.clamp_(min=low, max=high) img.sub_(low).div_(max(high - low, 1e-5)) - def norm_range(t, range): - if range is not None: - norm_ip(t, range[0], range[1]) + def norm_range(t, vmin, vmax): + if vmin is not None and vmax is not None: + norm_ip(t, vmin, vmax) else: norm_ip(t, float(t.min()), float(t.max())) if scale_each is True: - for t in tensor: # loop over mini-batch dimension - norm_range(t, range) + for i in range(tensor.size(0)): # loop over mini-batch dimension + t = tensor[i] + norm_range(t, vmin, vmax) else: - norm_range(tensor, range) + norm_range(tensor, vmin, vmax) if tensor.size(0) == 1: return tensor.squeeze(0) @@ -92,8 +99,8 @@ def norm_range(t, range): num_channels = tensor.size(1) grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) k = 0 - for y in irange(ymaps): - for x in irange(xmaps): + for y in range(ymaps): + for x in range(xmaps): if k >= nmaps: break # Tensor.copy_() is a valid method but seems to be missing from the stubs @@ -111,10 +118,13 @@ def save_image( nrow: int = 8, padding: int = 2, normalize: bool = False, - range: Optional[Tuple[int, int]] = None, + # range: Optional[Tuple[int, int]] = None, + vmin: int = None, + vmax: int = None, scale_each: bool = False, pad_value: int = 0, format: Optional[str] = None, + **kwargs ) -> None: """Save a given Tensor into an image file. 
@@ -126,8 +136,14 @@ def save_image(
             If a file object was used instead of a filename, this parameter should always be used.
         **kwargs: Other arguments are documented in ``make_grid``.
     """
+    if "range" in kwargs.keys():
+        if kwargs["range"] is not None:
+            assert isinstance(kwargs["range"], tuple), \
+                "range has to be a tuple (min, max) if specified. min and max are numbers"
+            vmin, vmax = get_range(kwargs["range"])
+
     grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
-                     normalize=normalize, range=range, scale_each=scale_each)
+                     normalize=normalize, vmin=vmin, vmax=vmax, scale_each=scale_each)
     # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
     ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
     im = Image.fromarray(ndarr)
@@ -187,3 +203,26 @@ def draw_bounding_boxes(
             draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)
 
     return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)
+
+
+def get_range(range_tmp, emit_warning=True) -> Tuple[int, int]:
+    """Convert a legacy ``range`` argument into a ``(vmin, vmax)`` pair.
+
+    Kept so that callers still passing ``range=(min, max)`` keep working.
+
+    Args:
+        range_tmp (tuple): pair ``(min, max)`` as formerly passed via ``range``.
+        emit_warning (bool, optional): If ``True``, warn that ``range`` is
+            going to be deprecated. Default: ``True``.
+
+    Returns:
+        tuple: ``(vmin, vmax)`` unpacked from ``range_tmp``.
+
+    Raises:
+        ValueError: If ``range_tmp`` does not hold exactly two values.
+    """
+    # Unpack first so a malformed ``range`` fails before any warning is issued.
+    vmin_, vmax_ = range_tmp
+    if emit_warning:
+        warnings.warn("range will be deprecated, please use vmin and vmax args instead.")
+    return vmin_, vmax_