diff --git a/torchvision/utils.py b/torchvision/utils.py index e277d0c7253..22ee699df7a 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -74,7 +74,7 @@ def norm_range(t, range): xmaps = min(nrow, nmaps) ymaps = int(math.ceil(float(nmaps) / xmaps)) height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) - grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value) + grid = tensor.new_full((3, height * ymaps + padding, width * xmaps + padding), pad_value) k = 0 for y in irange(ymaps): for x in irange(xmaps): @@ -99,6 +99,7 @@ def save_image(tensor, filename, nrow=8, padding=2, from PIL import Image grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, normalize=normalize, range=range, scale_each=scale_each) - ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() + # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer + ndarr = grid.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() im = Image.fromarray(ndarr) im.save(filename)