Skip to content

Commit

Permalink
Fixes no grad and range bugs in utils. (#3269)
Browse files Browse the repository at this point in the history
* Fixes utils

* don't use any

* slightly simplify logic

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
  • Loading branch information
oke-aditya and datumbox authored Jan 20, 2021
1 parent 631ff91 commit 767b23e
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@
import pathlib
import torch
import math
import warnings
import numpy as np
from PIL import Image, ImageDraw
from PIL import ImageFont

__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]

irange = range


@torch.no_grad()
def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range: Optional[Tuple[int, int]] = None,
value_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
**kwargs
) -> torch.Tensor:
"""Make a grid of images.
Expand All @@ -30,7 +31,7 @@ def make_grid(
padding (int, optional): amount of padding. Default: ``2``.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``.
range (tuple, optional): tuple (min, max) where min and max are numbers,
value_range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
scale_each (bool, optional): If ``True``, scale each image in the batch of
Expand All @@ -43,7 +44,12 @@ def make_grid(
"""
if not (torch.is_tensor(tensor) or
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

if "range" in kwargs.keys():
warning = "range will be deprecated, please use value_range instead."
warnings.warn(warning)
value_range = kwargs["range"]

# if list of tensors, convert to a 4D mini-batch Tensor
if isinstance(tensor, list):
Expand All @@ -61,25 +67,25 @@ def make_grid(

if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if range is not None:
assert isinstance(range, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"
if value_range is not None:
assert isinstance(value_range, tuple), \
"value_range has to be a tuple (min, max) if specified. min and max are numbers"

def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, range):
if range is not None:
norm_ip(t, range[0], range[1])
def norm_range(t, value_range):
if value_range is not None:
norm_ip(t, value_range[0], value_range[1])
else:
norm_ip(t, float(t.min()), float(t.max()))

if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, range)
norm_range(t, value_range)
else:
norm_range(tensor, range)
norm_range(tensor, value_range)

if tensor.size(0) == 1:
return tensor.squeeze(0)
Expand All @@ -92,8 +98,8 @@ def norm_range(t, range):
num_channels = tensor.size(1)
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
k = 0
for y in irange(ymaps):
for x in irange(xmaps):
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
# Tensor.copy_() is a valid method but seems to be missing from the stubs
Expand All @@ -105,16 +111,13 @@ def norm_range(t, range):
return grid


@torch.no_grad()
def save_image(
tensor: Union[torch.Tensor, List[torch.Tensor]],
fp: Union[Text, pathlib.Path, BinaryIO],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
format: Optional[str] = None,
**kwargs
) -> None:
"""Save a given Tensor into an image file.
Expand All @@ -126,8 +129,8 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)

grid = make_grid(tensor, **kwargs)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
Expand Down

0 comments on commit 767b23e

Please sign in to comment.