Skip to content

Commit

Permalink
resolve issue 1818 by modifying mean and standard deviation in the tr…
Browse files Browse the repository at this point in the history
…ansforms.Normalize (#2405)

* Fixes #2083 - explain model.eval, torch.no_grad
* set norm to mean & std of CIFAR10(#1818)
---------

Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
  • Loading branch information
zabboud and Svetlana Karslioglu authored Jun 2, 2023
1 parent d41e23b commit dd6a55d
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions beginner_source/introyt/introyt1_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def num_flat_features(self, x):

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


##########################################################################
Expand All @@ -297,9 +297,28 @@ def num_flat_features(self, x):
# - ``transforms.ToTensor()`` converts images loaded by Pillow into
# PyTorch tensors.
# - ``transforms.Normalize()`` adjusts the values of the tensor so
# that their average is zero and their standard deviation is 0.5. Most
# that their average is zero and their standard deviation is 1.0. Most
# activation functions have their strongest gradients around x = 0, so
# centering our data there can speed learning.
# The values passed to the transform are the means (first tuple) and the
# standard deviations (second tuple) of the rgb values of the images in
# the dataset. You can calculate these values yourself by running these
# few lines of code:
# ```
# from torch.utils.data import ConcatDataset
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
# download=True, transform=transform)
#
# #stack all train images together into a tensor of shape
# #(50000, 3, 32, 32)
# x = torch.stack([sample[0] for sample in ConcatDataset([trainset])])
#
# #get the mean of each channel
# mean = torch.mean(x, dim=(0,2,3)) #tensor([0.4914, 0.4822, 0.4465])
# std = torch.std(x, dim=(0,2,3)) #tensor([0.2470, 0.2435, 0.2616])
#
# ```
#
# There are many more transforms available, including cropping, centering,
# rotation, and reflection.
Expand Down

0 comments on commit dd6a55d

Please sign in to comment.