Skip to content

Commit

Permalink
Fix inception v3 input transform for trace & onnx (#621)
Browse files Browse the repository at this point in the history
* Fix inception v3 input transform for trace & onnx

* Input transform are in-place updates, which produce issues for tracing
and exporting to onnx.

* nit
  • Loading branch information
BowenBao authored and soumith committed Oct 25, 2018
1 parent 8f943d4 commit 85369e3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchvision/models/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):

def forward(self, x):
if self.transform_input:
x = x.clone()
x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
# 299 x 299 x 3
x = self.Conv2d_1a_3x3(x)
# 149 x 149 x 32
Expand Down

0 comments on commit 85369e3

Please sign in to comment.