Add typing in GeneralizedRCNNTransform (#4369)
* Add types in transform.

* Trace on eval mode.

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
datumbox and fmassa authored Sep 6, 2021
1 parent 5fb36a1 commit 981ccfd
Showing 1 changed file with 19 additions and 28 deletions.
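
Note: the point of inlining the annotations (and of the "trace on eval mode" tweak) is that the transform keeps working under TorchScript. A minimal sketch of exercising that, with illustrative sizes and ImageNet-style statistics, not part of this commit:

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

t = GeneralizedRCNNTransform(min_size=800, max_size=1333,
                             image_mean=[0.485, 0.456, 0.406],
                             image_std=[0.229, 0.224, 0.225])
t.eval()                                 # per the commit note: run in eval mode
scripted = torch.jit.script(t)           # scripts with the new inline annotations
image_list, _ = scripted([torch.rand(3, 300, 400)])
print(image_list.tensors.shape)          # padded batch, sides divisible by 32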
47 changes: 19 additions & 28 deletions torchvision/models/detection/transform.py
@@ -10,15 +10,13 @@


 @torch.jit.unused
-def _get_shape_onnx(image):
-    # type: (Tensor) -> Tensor
+def _get_shape_onnx(image: Tensor) -> Tensor:
     from torch.onnx import operators
     return operators.shape_as_tensor(image)[-2:]


 @torch.jit.unused
-def _fake_cast_onnx(v):
-    # type: (Tensor) -> float
+def _fake_cast_onnx(v: Tensor) -> float:
     # ONNX requires a tensor but here we fake its type for JIT.
     return v

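These two helpers let the resize logic behave under both ONNX tracing and TorchScript: shapes stay tensors while tracing, and the scale "casts" to float for the JIT. A simplified sketch of the pattern, adapted from the resize path in this file (sizes illustrative):

import torch
import torchvision
from torch import Tensor
from torchvision.models.detection.transform import _get_shape_onnx, _fake_cast_onnx

def scale_factor(image: Tensor, min_size: float = 800., max_size: float = 1333.) -> float:
    if torchvision._is_tracing():
        im_shape = _get_shape_onnx(image)    # keep the shape symbolic for ONNX
    else:
        im_shape = torch.tensor(image.shape[-2:])
    scale = torch.min(min_size / torch.min(im_shape).to(torch.float32),
                      max_size / torch.max(im_shape).to(torch.float32))
    return _fake_cast_onnx(scale) if torchvision._is_tracing() else scale.item()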
@@ -74,7 +72,8 @@ class GeneralizedRCNNTransform(nn.Module):
     It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
     """

-    def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, fixed_size=None):
+    def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float],
+                 size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None):
         super(GeneralizedRCNNTransform, self).__init__()
         if not isinstance(min_size, (list, tuple)):
             min_size = (min_size,)
@@ -86,10 +85,9 @@ def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32,
         self.fixed_size = fixed_size

     def forward(self,
-                images,       # type: List[Tensor]
-                targets=None  # type: Optional[List[Dict[str, Tensor]]]
-                ):
-        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
+                images: List[Tensor],
+                targets: Optional[List[Dict[str, Tensor]]] = None
+                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
         images = [img for img in images]
         if targets is not None:
             # make a copy of targets to avoid modifying it in-place
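
A small usage sketch of the annotated forward() (shapes and boxes are illustrative): it normalizes and resizes each image, rescales any targets to match, and returns the padded batch as an ImageList:

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(min_size=800, max_size=1333,
                                     image_mean=[0.485, 0.456, 0.406],
                                     image_std=[0.229, 0.224, 0.225])
images = [torch.rand(3, 480, 640)]
targets = [{"boxes": torch.tensor([[10., 10., 100., 100.]]),
            "labels": torch.tensor([1])}]
image_list, targets = transform(images, targets)
print(image_list.tensors.shape)        # e.g. torch.Size([1, 3, 800, 1088])
print(targets[0]["boxes"])             # boxes rescaled to the resized image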
@@ -126,7 +124,7 @@ def forward(self,
         image_list = ImageList(images, image_sizes_list)
         return image_list, targets

-    def normalize(self, image):
+    def normalize(self, image: Tensor) -> Tensor:
         if not image.is_floating_point():
             raise TypeError(
                 f"Expected input images to be of floating type (in range [0, 1]), "
@@ -137,8 +135,7 @@ def normalize(self, image):
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
         return (image - mean[:, None, None]) / std[:, None, None]

-    def torch_choice(self, k):
-        # type: (List[int]) -> int
+    def torch_choice(self, k: List[int]) -> int:
         """
         Implements `random.choice` via torch ops so it can be compiled with
         TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
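
The body of torch_choice() amounts to drawing a uniform index with torch ops only, so it survives TorchScript compilation; roughly:

import torch

def torch_choice(k):
    # draw an index in [0, len(k)) using torch ops instead of random.choice
    index = int(torch.empty(1).uniform_(0., float(len(k))).item())
    return k[index]

print(torch_choice([480, 600, 800]))   # one of the candidate min sizes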
@@ -175,8 +172,7 @@ def resize(self,
     # _onnx_batch_images() is an implementation of
     # batch_images() that is supported by ONNX tracing.
     @torch.jit.unused
-    def _onnx_batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         max_size = []
         for i in range(images[0].dim()):
             max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
@@ -197,16 +193,14 @@ def _onnx_batch_images(self, images, size_divisible=32):

         return torch.stack(padded_imgs)

-    def max_by_axis(self, the_list):
-        # type: (List[List[int]]) -> List[int]
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
         maxes = the_list[0]
         for sublist in the_list[1:]:
             for index, item in enumerate(sublist):
                 maxes[index] = max(maxes[index], item)
         return maxes

-    def batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         if torchvision._is_tracing():
             # batch_images() does not export well to ONNX
             # call _onnx_batch_images() instead
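
max_by_axis() feeds batch_images(): the per-axis maximum is rounded up to a multiple of size_divisible, and every image is copied into the top-left corner of a zero-filled canvas of that size. A worked sketch with illustrative shapes:

import torch

imgs = [torch.rand(3, 200, 300), torch.rand(3, 250, 150)]
max_size = [3, 250, 300]                     # max_by_axis over the shapes
stride = 32.0                                # size_divisible
max_size[1] = int(torch.ceil(torch.tensor(max_size[1] / stride)) * stride)  # 256
max_size[2] = int(torch.ceil(torch.tensor(max_size[2] / stride)) * stride)  # 320
batched = imgs[0].new_full([len(imgs)] + max_size, 0.0)
for img, pad_img in zip(imgs, batched):
    pad_img[:img.shape[0], :img.shape[1], :img.shape[2]].copy_(img)
print(batched.shape)                         # torch.Size([2, 3, 256, 320])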
@@ -226,11 +220,10 @@ def batch_images(self, images, size_divisible=32):
         return batched_imgs

     def postprocess(self,
-                    result,               # type: List[Dict[str, Tensor]]
-                    image_shapes,         # type: List[Tuple[int, int]]
-                    original_image_sizes  # type: List[Tuple[int, int]]
-                    ):
-        # type: (...) -> List[Dict[str, Tensor]]
+                    result: List[Dict[str, Tensor]],
+                    image_shapes: List[Tuple[int, int]],
+                    original_image_sizes: List[Tuple[int, int]]
+                    ) -> List[Dict[str, Tensor]]:
         if self.training:
             return result
         for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
@@ -247,7 +240,7 @@ def postprocess(self,
             result[i]["keypoints"] = keypoints
         return result

-    def __repr__(self):
+    def __repr__(self) -> str:
         format_string = self.__class__.__name__ + '('
         _indent = '\n    '
         format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
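
For the postprocess() change above: it maps predictions from the resized space back to each original frame, and returns early in training mode. An illustrative call, with made-up sizes and boxes:

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(800, 1333, [0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
transform.eval()                       # postprocess returns early when training
preds = [{"boxes": torch.tensor([[50., 50., 250., 250.]])}]
preds = transform.postprocess(preds, image_shapes=[(800, 1066)],
                              original_image_sizes=[(480, 640)])
print(preds[0]["boxes"])               # rescaled back into the 480x640 frame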
@@ -257,8 +250,7 @@ def __repr__(self):
         return format_string


-def resize_keypoints(keypoints, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
@@ -276,8 +268,7 @@ def resize_keypoints(keypoints, original_size, new_size):
     return resized_data


-def resize_boxes(boxes, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=boxes.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
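
The ratio math in resize_boxes() scales x-coordinates by new_w / old_w and y-coordinates by new_h / old_h; a worked example with illustrative numbers:

import torch

boxes = torch.tensor([[10., 20., 110., 220.]])     # (xmin, ymin, xmax, ymax)
original_size, new_size = [400, 600], [800, 1200]  # (h, w), both sides doubled
ratio_h = new_size[0] / original_size[0]           # 2.0
ratio_w = new_size[1] / original_size[1]           # 2.0
xmin, ymin, xmax, ymax = boxes.unbind(1)
print(torch.stack((xmin * ratio_w, ymin * ratio_h,
                   xmax * ratio_w, ymax * ratio_h), dim=1))
# tensor([[ 20.,  40., 220., 440.]])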
