Skip to content

Commit

Permalink
Fix pytorch#3025 : deprecated range argument, vmin and vmax args instead
Browse files Browse the repository at this point in the history
  • Loading branch information
TingsongYu committed Dec 10, 2020
1 parent 4d3a309 commit 497f8a0
Showing 1 changed file with 48 additions and 20 deletions.
68 changes: 48 additions & 20 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
import torch
import math
import numpy as np
import warnings
from PIL import Image, ImageDraw
from PIL import ImageFont

__all__ = ["make_grid", "save_image", "draw_bounding_boxes"]

irange = range
__all__ = ["make_grid", "save_image", "draw_bounding_boxes", "get_range"]


def make_grid(
tensor: Union[torch.Tensor, List[torch.Tensor]],
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range: Optional[Tuple[int, int]] = None,
# range: Optional[Tuple[int, int]] = None,
vmin: int = None,
vmax: int = None,
scale_each: bool = False,
pad_value: int = 0,
**kwargs
) -> torch.Tensor:
"""Make a grid of images.
Expand All @@ -30,9 +33,11 @@ def make_grid(
padding (int, optional): amount of padding. Default: ``2``.
normalize (bool, optional): If True, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``.
range (tuple, optional): tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
range (tuple, optional): Deprecated (see :attr:`vmin, vmax`). tuple (min, max)
where min and max are numbers,then these numbers are used to normalize the image.
By default, min and max are computed from the tensor.
vmin (int, optional): vmin and vmax define the data range that used to normalize the image.
vmax (int, optional): vmin and vmax define the data range that used to normalize the image.
scale_each (bool, optional): If ``True``, scale each image in the batch of
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value (float, optional): Value for the padded pixels. Default: ``0``.
Expand Down Expand Up @@ -61,25 +66,27 @@ def make_grid(

if normalize is True:
tensor = tensor.clone() # avoid modifying tensor in-place
if range is not None:
assert isinstance(range, tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"

if "range" in kwargs.keys():
if kwargs["range"] is not None:
assert isinstance(kwargs["range"], tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"
vmin, vmax = get_range(kwargs["range"])
def norm_ip(img, low, high):
img.clamp_(min=low, max=high)
img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, range):
if range is not None:
norm_ip(t, range[0], range[1])
def norm_range(t, vmin, vmax):
if vmin is not None and vmax is not None:
norm_ip(t, vmin, vmax)
else:
norm_ip(t, float(t.min()), float(t.max()))

if scale_each is True:
for t in tensor: # loop over mini-batch dimension
norm_range(t, range)
for i in range(tensor.size(0)): # loop over mini-batch dimension
t = tensor[i]
norm_range(t, vmin, vmax)
else:
norm_range(tensor, range)
norm_range(tensor, vmin, vmax)

if tensor.size(0) == 1:
return tensor.squeeze(0)
Expand All @@ -92,8 +99,8 @@ def norm_range(t, range):
num_channels = tensor.size(1)
grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
k = 0
for y in irange(ymaps):
for x in irange(xmaps):
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
# Tensor.copy_() is a valid method but seems to be missing from the stubs
Expand All @@ -111,10 +118,13 @@ def save_image(
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
range: Optional[Tuple[int, int]] = None,
# range: Optional[Tuple[int, int]] = None,
vmin: int = None,
vmax: int = None,
scale_each: bool = False,
pad_value: int = 0,
format: Optional[str] = None,
**kwargs
) -> None:
"""Save a given Tensor into an image file.
Expand All @@ -126,8 +136,14 @@ def save_image(
If a file object was used instead of a filename, this parameter should always be used.
**kwargs: Other arguments are documented in ``make_grid``.
"""
if "range" in kwargs.keys():
if kwargs["range"] is not None:
assert isinstance(kwargs["range"], tuple), \
"range has to be a tuple (min, max) if specified. min and max are numbers"
vmin, vmax = get_range(kwargs["range"])

grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
normalize=normalize, range=range, scale_each=scale_each)
normalize=normalize, vmin=vmin, vmax=vmax, scale_each=scale_each)
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
Expand Down Expand Up @@ -187,3 +203,15 @@ def draw_bounding_boxes(
draw.text((bbox[0], bbox[1]), labels[i], fill=color, font=txt_font)

return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1)


def get_range(range_tmp, emit_warning=True) -> (int, int):
"""
In order to support previous versions, accept range argument and convert this into (vmin, vmax).
"""

warning = "range will be deprecated, please use vmin and vmax args instead."
vmin_, vmax_ = range_tmp
if emit_warning:
warnings.warn(warning)
return (vmin_, vmax_ )

0 comments on commit 497f8a0

Please sign in to comment.