Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resolve issue 1818 by modifying mean and standard deviation in the transforms.Normalize #2405

Merged
merged 5 commits into from
Jun 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions beginner_source/introyt/introyt1_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def num_flat_features(self, x):

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])


##########################################################################
Expand All @@ -297,9 +297,28 @@ def num_flat_features(self, x):
# - ``transforms.ToTensor()`` converts images loaded by Pillow into
# PyTorch tensors.
# - ``transforms.Normalize()`` adjusts the values of the tensor so
# that their average is zero and their standard deviation is 0.5. Most
# that their average is zero and their standard deviation is 1.0. Most
# activation functions have their strongest gradients around x = 0, so
# centering our data there can speed learning.
# The values passed to the transform are the means (first tuple) and the
# standard deviations (second tuple) of the rgb values of the images in
# the dataset. You can calculate these values yourself by running these
# few lines of code:
# ```
# from torch.utils.data import ConcatDataset
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
# download=True, transform=transform)
#
# #stack all train images together into a tensor of shape
# #(50000, 3, 32, 32)
# x = torch.stack([sample[0] for sample in ConcatDataset([trainset])])
#
# #get the mean of each channel
# mean = torch.mean(x, dim=(0,2,3)) #tensor([0.4914, 0.4822, 0.4465])
# std = torch.std(x, dim=(0,2,3)) #tensor([0.2470, 0.2435, 0.2616])
#
# ```
#
# There are many more transforms available, including cropping, centering,
# rotation, and reflection.
Expand Down