diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index c4e83fa364f..386513b4071 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -1,3 +1,4 @@ +import torch from torch import nn from torch import Tensor from .utils import load_state_dict_from_url @@ -189,8 +190,9 @@ def _forward_impl(self, x: Tensor) -> Tensor: # This exists since TorchScript doesn't support inheritance, so the superclass method # (this one) needs to have a name other than `forward` that can be accessed in a subclass x = self.features(x) - # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] - x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1) + # Cannot use "squeeze" as batch-size can be 1 + x = nn.functional.adaptive_avg_pool2d(x, (1, 1)) + x = torch.flatten(x, 1) x = self.classifier(x) return x