Add typing in GeneralizedRCNNTransform #4369

Merged (3 commits) on Sep 6, 2021
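
The pattern throughout the diff is the same: replace mypy-style "# type:" comments with inline PEP 484 annotations, which both TorchScript and mypy accept. Taking resize_boxes from this diff as the example:

# Before: the signature is typed via a comment
def resize_boxes(boxes, original_size, new_size):
    # type: (Tensor, List[int], List[int]) -> Tensor
    ...

# After: equivalent inline annotations
def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
    ...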
47 changes: 19 additions & 28 deletions torchvision/models/detection/transform.py
@@ -10,15 +10,13 @@


 @torch.jit.unused
-def _get_shape_onnx(image):
-    # type: (Tensor) -> Tensor
+def _get_shape_onnx(image: Tensor) -> Tensor:
     from torch.onnx import operators
     return operators.shape_as_tensor(image)[-2:]


 @torch.jit.unused
-def _fake_cast_onnx(v):
-    # type: (Tensor) -> float
+def _fake_cast_onnx(v: Tensor) -> float:
     # ONNX requires a tensor but here we fake its type for JIT.
     return v

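For context, a rough sketch of how these two helpers are used by the resize path (assumed from the surrounding file; the relevant lines are folded out of this diff):

if torchvision._is_tracing():
    # During ONNX export, read (H, W) as a symbolic tensor op.
    im_shape = _get_shape_onnx(image)
else:
    im_shape = torch.tensor(image.shape[-2:])
# Later, a Tensor scale is "cast" so JIT sees a float while ONNX keeps the op.
scale_factor = _fake_cast_onnx(scale) if torchvision._is_tracing() else scale.item()
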
@@ -74,7 +72,8 @@ class GeneralizedRCNNTransform(nn.Module):
     It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets
     """

-    def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, fixed_size=None):
+    def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float],
+                 size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None):
         super(GeneralizedRCNNTransform, self).__init__()
         if not isinstance(min_size, (list, tuple)):
             min_size = (min_size,)
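
As a usage sketch, the newly annotated constructor is typically called with the torchvision detection defaults (the values below are illustrative, not part of this diff):

transform = GeneralizedRCNNTransform(
    min_size=800, max_size=1333,
    image_mean=[0.485, 0.456, 0.406],  # ImageNet channel means
    image_std=[0.229, 0.224, 0.225],   # ImageNet channel stds
)
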
@@ -86,10 +85,9 @@ def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32,
         self.fixed_size = fixed_size

     def forward(self,
-                images,       # type: List[Tensor]
-                targets=None  # type: Optional[List[Dict[str, Tensor]]]
-                ):
-        # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
+                images: List[Tensor],
+                targets: Optional[List[Dict[str, Tensor]]] = None
+                ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]:
         images = [img for img in images]
         if targets is not None:
             # make a copy of targets to avoid modifying it in-place
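
With the annotated signature, a call on variable-sized inputs reads directly off the types; a small sketch reusing the transform from above (shapes are hypothetical):

images = [torch.rand(3, 480, 640), torch.rand(3, 600, 500)]
image_list, targets = transform(images)  # Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]
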
@@ -126,7 +124,7 @@ def forward(self,
         image_list = ImageList(images, image_sizes_list)
         return image_list, targets

-    def normalize(self, image):
+    def normalize(self, image: Tensor) -> Tensor:
         if not image.is_floating_point():
             raise TypeError(
                 f"Expected input images to be of floating type (in range [0, 1]), "
@@ -137,8 +135,7 @@ def normalize(self, image):
         std = torch.as_tensor(self.image_std, dtype=dtype, device=device)
         return (image - mean[:, None, None]) / std[:, None, None]

-    def torch_choice(self, k):
-        # type: (List[int]) -> int
+    def torch_choice(self, k: List[int]) -> int:
         """
         Implements `random.choice` via torch ops so it can be compiled with
         TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803
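
The method body is folded out of this view; one way to express random.choice with torch ops, consistent with the annotated signature (a sketch, not necessarily the file's exact implementation):

index = int(torch.randint(low=0, high=len(k), size=(1,)).item())  # uniform over 0..len(k)-1
chosen = k[index]
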
@@ -175,8 +172,7 @@ def resize(self,
     # _onnx_batch_images() is an implementation of
     # batch_images() that is supported by ONNX tracing.
     @torch.jit.unused
-    def _onnx_batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         max_size = []
         for i in range(images[0].dim()):
             max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64)
@@ -197,16 +193,14 @@ def _onnx_batch_images(self, images, size_divisible=32):

         return torch.stack(padded_imgs)

-    def max_by_axis(self, the_list):
-        # type: (List[List[int]]) -> List[int]
+    def max_by_axis(self, the_list: List[List[int]]) -> List[int]:
         maxes = the_list[0]
         for sublist in the_list[1:]:
             for index, item in enumerate(sublist):
                 maxes[index] = max(maxes[index], item)
         return maxes

-    def batch_images(self, images, size_divisible=32):
-        # type: (List[Tensor], int) -> Tensor
+    def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor:
         if torchvision._is_tracing():
             # batch_images() does not export well to ONNX
             # call _onnx_batch_images() instead
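
A worked example for the two helpers above (shapes assumed for illustration): max_by_axis takes the per-axis maximum over a list of shapes, and batch_images then rounds each spatial dimension up to a multiple of size_divisible before zero-padding:

# max_by_axis([[3, 480, 640], [3, 600, 500]]) returns [3, 600, 640]
# with size_divisible=32 the padded batch is (2, 3, 608, 640):
#   608 = ceil(600 / 32) * 32, and 640 is already a multiple of 32
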
@@ -226,11 +220,10 @@ def batch_images(self, images, size_divisible=32):
         return batched_imgs

     def postprocess(self,
-                    result,               # type: List[Dict[str, Tensor]]
-                    image_shapes,         # type: List[Tuple[int, int]]
-                    original_image_sizes  # type: List[Tuple[int, int]]
-                    ):
-        # type: (...) -> List[Dict[str, Tensor]]
+                    result: List[Dict[str, Tensor]],
+                    image_shapes: List[Tuple[int, int]],
+                    original_image_sizes: List[Tuple[int, int]]
+                    ) -> List[Dict[str, Tensor]]:
         if self.training:
             return result
         for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)):
@@ -247,7 +240,7 @@ def postprocess(self,
                 result[i]["keypoints"] = keypoints
         return result

-    def __repr__(self):
+    def __repr__(self) -> str:
         format_string = self.__class__.__name__ + '('
         _indent = '\n    '
         format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std)
@@ -257,8 +250,7 @@ def __repr__(self):
         return format_string


-def resize_keypoints(keypoints, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=keypoints.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
@@ -276,8 +268,7 @@ def resize_keypoints(keypoints, original_size, new_size):
     return resized_data


-def resize_boxes(boxes, original_size, new_size):
-    # type: (Tensor, List[int], List[int]) -> Tensor
+def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor:
     ratios = [
         torch.tensor(s, dtype=torch.float32, device=boxes.device) /
         torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
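
A worked example of the ratio math (values assumed): resizing from (height, width) = (400, 600) to (800, 1200) gives ratio_height = ratio_width = 2.0, so each xyxy coordinate scales linearly:

boxes = torch.tensor([[10., 20., 100., 200.]])
resize_boxes(boxes, [400, 600], [800, 1200])  # tensor([[20., 40., 200., 400.]])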