From fac02067c1a6f73c925582e3d2955bc25ba95c53 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 6 Sep 2021 09:15:16 +0100 Subject: [PATCH 1/2] Add types in transform. --- test/tracing/frcnn/trace_model.py | 2 +- torchvision/models/detection/transform.py | 47 +++++++++-------------- 2 files changed, 20 insertions(+), 29 deletions(-) diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 34961e8684f..84d5e65f014 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -8,7 +8,7 @@ ASSETS = osp.dirname(osp.dirname(HERE)) model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False) -model.eval() +model.train() traced_model = torch.jit.script(model) traced_model.save("fasterrcnn_resnet50_fpn.pt") diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 0ca5273e047..3022c2f4149 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -10,15 +10,13 @@ @torch.jit.unused -def _get_shape_onnx(image): - # type: (Tensor) -> Tensor +def _get_shape_onnx(image: Tensor) -> Tensor: from torch.onnx import operators return operators.shape_as_tensor(image)[-2:] @torch.jit.unused -def _fake_cast_onnx(v): - # type: (Tensor) -> float +def _fake_cast_onnx(v: Tensor) -> float: # ONNX requires a tensor but here we fake its type for JIT. return v @@ -74,7 +72,8 @@ class GeneralizedRCNNTransform(nn.Module): It returns a ImageList for the inputs, and a List[Dict[Tensor]] for the targets """ - def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, fixed_size=None): + def __init__(self, min_size: int, max_size: int, image_mean: List[float], image_std: List[float], + size_divisible: int = 32, fixed_size: Optional[Tuple[int, int]] = None): super(GeneralizedRCNNTransform, self).__init__() if not isinstance(min_size, (list, tuple)): min_size = (min_size,) @@ -86,10 +85,9 @@ def __init__(self, min_size, max_size, image_mean, image_std, size_divisible=32, self.fixed_size = fixed_size def forward(self, - images, # type: List[Tensor] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): - # type: (...) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]] + images: List[Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None + ) -> Tuple[ImageList, Optional[List[Dict[str, Tensor]]]]: images = [img for img in images] if targets is not None: # make a copy of targets to avoid modifying it in-place @@ -126,7 +124,7 @@ def forward(self, image_list = ImageList(images, image_sizes_list) return image_list, targets - def normalize(self, image): + def normalize(self, image: Tensor) -> Tensor: if not image.is_floating_point(): raise TypeError( f"Expected input images to be of floating type (in range [0, 1]), " @@ -137,8 +135,7 @@ def normalize(self, image): std = torch.as_tensor(self.image_std, dtype=dtype, device=device) return (image - mean[:, None, None]) / std[:, None, None] - def torch_choice(self, k): - # type: (List[int]) -> int + def torch_choice(self, k: List[int]) -> int: """ Implements `random.choice` via torch ops so it can be compiled with TorchScript. Remove if https://github.com/pytorch/pytorch/issues/25803 @@ -175,8 +172,7 @@ def resize(self, # _onnx_batch_images() is an implementation of # batch_images() that is supported by ONNX tracing. @torch.jit.unused - def _onnx_batch_images(self, images, size_divisible=32): - # type: (List[Tensor], int) -> Tensor + def _onnx_batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor: max_size = [] for i in range(images[0].dim()): max_size_i = torch.max(torch.stack([img.shape[i] for img in images]).to(torch.float32)).to(torch.int64) @@ -197,16 +193,14 @@ def _onnx_batch_images(self, images, size_divisible=32): return torch.stack(padded_imgs) - def max_by_axis(self, the_list): - # type: (List[List[int]]) -> List[int] + def max_by_axis(self, the_list: List[List[int]]) -> List[int]: maxes = the_list[0] for sublist in the_list[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes - def batch_images(self, images, size_divisible=32): - # type: (List[Tensor], int) -> Tensor + def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor: if torchvision._is_tracing(): # batch_images() does not export well to ONNX # call _onnx_batch_images() instead @@ -226,11 +220,10 @@ def batch_images(self, images, size_divisible=32): return batched_imgs def postprocess(self, - result, # type: List[Dict[str, Tensor]] - image_shapes, # type: List[Tuple[int, int]] - original_image_sizes # type: List[Tuple[int, int]] - ): - # type: (...) -> List[Dict[str, Tensor]] + result: List[Dict[str, Tensor]], + image_shapes: List[Tuple[int, int]], + original_image_sizes: List[Tuple[int, int]] + ) -> List[Dict[str, Tensor]]: if self.training: return result for i, (pred, im_s, o_im_s) in enumerate(zip(result, image_shapes, original_image_sizes)): @@ -247,7 +240,7 @@ def postprocess(self, result[i]["keypoints"] = keypoints return result - def __repr__(self): + def __repr__(self) -> str: format_string = self.__class__.__name__ + '(' _indent = '\n ' format_string += "{0}Normalize(mean={1}, std={2})".format(_indent, self.image_mean, self.image_std) @@ -257,8 +250,7 @@ def __repr__(self): return format_string -def resize_keypoints(keypoints, original_size, new_size): - # type: (Tensor, List[int], List[int]) -> Tensor +def resize_keypoints(keypoints: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: ratios = [ torch.tensor(s, dtype=torch.float32, device=keypoints.device) / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) @@ -276,8 +268,7 @@ def resize_keypoints(keypoints, original_size, new_size): return resized_data -def resize_boxes(boxes, original_size, new_size): - # type: (Tensor, List[int], List[int]) -> Tensor +def resize_boxes(boxes: Tensor, original_size: List[int], new_size: List[int]) -> Tensor: ratios = [ torch.tensor(s, dtype=torch.float32, device=boxes.device) / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) From aa41000b259ad0537ccf9714bbd03b8d6700a5c4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 6 Sep 2021 09:23:23 +0100 Subject: [PATCH 2/2] Trace on eval mode. --- test/tracing/frcnn/trace_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tracing/frcnn/trace_model.py b/test/tracing/frcnn/trace_model.py index 84d5e65f014..34961e8684f 100644 --- a/test/tracing/frcnn/trace_model.py +++ b/test/tracing/frcnn/trace_model.py @@ -8,7 +8,7 @@ ASSETS = osp.dirname(osp.dirname(HERE)) model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False) -model.train() +model.eval() traced_model = torch.jit.script(model) traced_model.save("fasterrcnn_resnet50_fpn.pt")