From 05ba3f53d1a40f497d4af4768428b7fb1a8dd550 Mon Sep 17 00:00:00 2001 From: Surgan Jandial Date: Fri, 28 Jun 2019 14:35:31 +0530 Subject: [PATCH] Doc, Test Fixes in Normalize (#1063) * updates on normalize * test fixes * Update test_transforms.py --- test/test_transforms.py | 5 +++++ torchvision/transforms/functional.py | 1 + torchvision/transforms/transforms.py | 2 ++ 3 files changed, 8 insertions(+) diff --git a/test/test_transforms.py b/test/test_transforms.py index 794e7a07c07..79e315a7729 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -820,6 +820,11 @@ def samples_from_standard_normal(tensor): # Checking if Normalize can be printed as string transforms.Normalize(mean, std).__repr__() + # Checking the optional in-place behaviour + tensor = torch.rand((1, 16, 16)) + tensor_inplace = transforms.Normalize((0.5,), (0.5,), inplace=True)(tensor) + assert torch.equal(tensor, tensor_inplace) + def test_normalize_different_dtype(self): for dtype1 in [torch.float32, torch.float64]: img = torch.rand(3, 10, 10, dtype=dtype1) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index b543b14f15e..a75493ad516 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -200,6 +200,7 @@ def normalize(tensor, mean, std, inplace=False): tensor (Tensor): Tensor image of size (C, H, W) to be normalized. mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation inplace. Returns: Tensor: Normalized Tensor image. diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 0b50144c07f..4cf099faeac 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -146,6 +146,8 @@ class Normalize(object): Args: mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. + inplace(bool,optional): Bool to make this operation in-place. + """ def __init__(self, mean, std, inplace=False):