Skip to content

Commit

Permalink
Replace reshape with flatten (#3462)
Browse files Browse the repository at this point in the history
Summary:
Current implementation is generating bad graph after onnx conversion. So replacing with flatten like in mobilenetv3 code.

Reviewed By: fmassa

Differential Revision: D26756271

fbshipit-source-id: 68751201436147c179532b4d35e1140cb0f56967

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
  • Loading branch information
2 people authored and facebook-github-bot committed Mar 4, 2021
1 parent ebcd2f3 commit 828395a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torchvision/models/mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from torch import nn
from torch import Tensor
from .utils import load_state_dict_from_url
Expand Down Expand Up @@ -189,8 +190,9 @@ def _forward_impl(self, x: Tensor) -> Tensor:
# This exists since TorchScript doesn't support inheritance, so the superclass method
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
x = self.features(x)
# Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
# Cannot use "squeeze" as batch-size can be 1
x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

Expand Down

0 comments on commit 828395a

Please sign in to comment.