Skip to content

Commit

Permalink
Remove interpolate in favor of PyTorch's implementation (#2252)
Browse files Browse the repository at this point in the history
* Remove interpolate in favor of PyTorch's implementation

* Bugfix

* Bugfix
  • Loading branch information
fmassa authored Jun 1, 2020
1 parent 98aa805 commit b40f49f
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 64 deletions.
4 changes: 1 addition & 3 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import torch
from torch import nn

from torchvision.ops import misc as misc_nn_ops

from torchvision.ops import MultiScaleRoIAlign

from ..utils import load_state_dict_from_url
Expand Down Expand Up @@ -253,7 +251,7 @@ def __init__(self, in_channels, num_keypoints):

def forward(self, x):
x = self.kps_score_lowres(x)
x = misc_nn_ops.interpolate(
x = torch.nn.functional.interpolate(
x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False
)
return x
Expand Down
13 changes: 6 additions & 7 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from torch import nn, Tensor

from torchvision.ops import boxes as box_ops
from torchvision.ops import misc as misc_nn_ops

from torchvision.ops import roi_align

Expand Down Expand Up @@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
width_correction = widths_i / roi_map_width
height_correction = heights_i / roi_map_height

roi_map = torch.nn.functional.interpolate(
maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0]
roi_map = F.interpolate(
maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0]

w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
Expand Down Expand Up @@ -256,8 +255,8 @@ def heatmaps_to_keypoints(maps, rois):
roi_map_height = int(heights_ceil[i].item())
width_correction = widths[i] / roi_map_width
height_correction = heights[i] / roi_map_height
roi_map = torch.nn.functional.interpolate(
maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0]
roi_map = F.interpolate(
maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0]
# roi_map_probs = scores_to_probs(roi_map.copy())
w = roi_map.shape[2]
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
Expand Down Expand Up @@ -392,7 +391,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, -1, -1))

# Resize mask
mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
mask = mask[0][0]

im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
Expand Down Expand Up @@ -420,7 +419,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
mask = mask.expand((1, 1, mask.size(0), mask.size(1)))

# Resize mask
mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
mask = mask[0][0]

x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
Expand Down
6 changes: 3 additions & 3 deletions torchvision/models/detection/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import math
import torch
from torch import nn, Tensor
from torch.nn import functional as F
import torchvision
from torch.jit.annotations import List, Tuple, Dict, Optional

from torchvision.ops import misc as misc_nn_ops
from .image_list import ImageList
from .roi_heads import paste_masks_in_image

Expand All @@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):

if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
target["masks"] = mask
return image, target

Expand All @@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target):

if "masks" in target:
mask = target["masks"]
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
target["masks"] = mask
return image, target

Expand Down
52 changes: 1 addition & 51 deletions torchvision/ops/misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
from collections import OrderedDict
from torch.jit.annotations import Optional, List
from torch import Tensor

"""
helper class that supports empty tensors on some nn functions.
Expand All @@ -12,10 +8,8 @@
is implemented
"""

import math
import warnings
import torch
from torchvision.ops import _new_empty_tensor


class Conv2d(torch.nn.Conv2d):
Expand All @@ -42,51 +36,7 @@ def __init__(self, *args, **kwargs):
"removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)


def _check_size_scale_factor(dim, size, scale_factor):
# type: (int, Optional[List[int]], Optional[float]) -> None
if size is None and scale_factor is None:
raise ValueError("either size or scale_factor should be defined")
if size is not None and scale_factor is not None:
raise ValueError("only one of size or scale_factor should be defined")
if scale_factor is not None:
if isinstance(scale_factor, (list, tuple)):
if len(scale_factor) != dim:
raise ValueError(
"scale_factor shape must match input shape. "
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
)


def _output_size(dim, input, size, scale_factor):
    # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
    """Compute the spatial (H, W) output size for a 2-D interpolation.

    Either returns *size* verbatim, or scales the last two dimensions of
    *input* by the scalar *scale_factor*. Only dim == 2 is supported.
    """
    assert dim == 2
    _check_size_scale_factor(dim, size, scale_factor)
    if size is not None:
        return size
    # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
    assert scale_factor is not None and isinstance(scale_factor, (int, float))
    # math.floor might return float in py2.7
    out_h = int(math.floor(input.size(2) * scale_factor))
    out_w = int(math.floor(input.size(3) * scale_factor))
    return [out_h, out_w]


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if input.numel() == 0:
        # Empty tensor: PyTorch's interpolate cannot handle it here, so
        # construct an empty tensor with the correct output shape instead.
        spatial = _output_size(2, input, size, scale_factor)
        out_shape = list(input.shape[:-2]) + list(spatial)
        return _new_empty_tensor(input, out_shape)
    return torch.nn.functional.interpolate(
        input, size, scale_factor, mode, align_corners
    )
# Per this commit, the empty-batch workaround above is no longer needed:
# torchvision re-exports PyTorch's native implementation directly, keeping
# `torchvision.ops.misc.interpolate` as a backward-compatible alias.
interpolate = torch.nn.functional.interpolate


# This is not in nn
Expand Down

0 comments on commit b40f49f

Please sign in to comment.