From d86c93e58f68b5f04a7f993ae65c41539a8b8d79 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Mon, 8 Oct 2018 15:49:43 -0700 Subject: [PATCH 1/2] Fix inception v3 input transform for trace & onnx * Input transform are in-place updates, which produce issues for tracing and exporting to onnx. --- torchvision/models/inception.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index 425c5b88bcb..c17bf3d321a 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -70,10 +70,11 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): def forward(self, x): if self.transform_input: - x = x.clone() - x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 - x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 - x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat(( + torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5, + torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5, + torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + ), 1) # 299 x 299 x 3 x = self.Conv2d_1a_3x3(x) # 149 x 149 x 32 From fe8fe7099086316f391d5427845e8e90da41f502 Mon Sep 17 00:00:00 2001 From: Bowen Bao Date: Mon, 8 Oct 2018 16:25:48 -0700 Subject: [PATCH 2/2] nit --- torchvision/models/inception.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py index c17bf3d321a..1d72319d0f7 100644 --- a/torchvision/models/inception.py +++ b/torchvision/models/inception.py @@ -70,11 +70,10 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False): def forward(self, x): if self.transform_input: - x = torch.cat(( - torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5, - torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5, - torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 - ), 1) + x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 + x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 + x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 + x = torch.cat((x_ch0, x_ch1, x_ch2), 1) # 299 x 299 x 3 x = self.Conv2d_1a_3x3(x) # 149 x 149 x 32