diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 425c5b88bcb..1d72319d0f7 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -70,10 +70,10 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): def forward(self, x): if self.transform_input: - x = x.clone() - x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 - x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 - x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) # 299 x 299 x 3 x = self.Conv2d_1a_3x3(x) # 149 x 149 x 32