From e48f3f738e4ed2102e8bdb089da9a578d3e5bf7f Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 30 Jan 2020 09:10:02 -0800 Subject: [PATCH 01/44] fixes and tests for variable input size --- test/test_onnx.py | 108 ++++++++++++++-------- torchvision/models/detection/rpn.py | 25 +++-- torchvision/models/detection/transform.py | 3 +- torchvision/ops/boxes.py | 17 +++- 4 files changed, 105 insertions(+), 48 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index e3cd55b6e14..18176dacb35 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -28,14 +28,17 @@ class ONNXExporterTester(unittest.TestCase): def setUpClass(cls): torch.manual_seed(123) - def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True): + def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None, + output_names=None, input_names=None): model.eval() onnx_io = io.BytesIO() + # onnx_io = '/home/neraoof/test/results/transform.onnx' # export to onnx with the first input - torch.onnx.export(model, inputs_list[0], onnx_io, - do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version) + torch.onnx.export(model, inputs_list[0], onnx_io, + do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, + dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) # validate the exported model with onnx runtime for test_inputs in inputs_list: with torch.no_grad(): @@ -65,6 +68,7 @@ def to_numpy(tensor): # compute onnxruntime output prediction ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs)) ort_outs = ort_session.run(None, ort_inputs) + for i in range(0, len(outputs)): try: torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05) @@ -123,9 +127,11 @@ def __init__(self_module): def forward(self_module, images): return self_module.transform(images)[0].tensors - input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)] - input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)] - self.run_model(TransformModule(), [input, input_test]) + input = torch.rand(3, 100, 200), torch.rand(3, 200, 200) + input_test = torch.rand(3, 130, 230), torch.rand(3, 230, 230) + self.run_model(TransformModule(), [(input,), (input_test,)], + input_names=["input1", "input2"], + dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3]}) def _init_test_generalized_rcnn_transform(self): min_size = 100 @@ -207,22 +213,29 @@ def get_features(self, images): def test_rpn(self): class RPNModule(torch.nn.Module): - def __init__(self_module, images): + def __init__(self_module): super(RPNModule, self_module).__init__() self_module.rpn = self._init_test_rpn() - self_module.images = ImageList(images, [i.shape[-2:] for i in images]) - def forward(self_module, features): - return self_module.rpn(self_module.images, features) + def forward(self_module, images, features): + images = ImageList(images, [i.shape[-2:] for i in images]) + return self_module.rpn(images, features) + images = torch.rand(2, 3, 600, 600) features = self.get_features(images) - test_features = self.get_features(images) + images2 = torch.rand(2, 3, 1000, 1000) + test_features = self.get_features(images2) - model = RPNModule(images) + model = RPNModule() model.eval() - model(features) - self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True) + model(images, features) + + self.run_model(model, [(images, features), (images2, test_features)], 
tolerate_small_mismatch=True, + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], + "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], + "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) def test_multi_scale_roi_align(self): @@ -251,32 +264,38 @@ def forward(self, input, boxes): def test_roi_heads(self): class RoiHeadsModule(torch.nn.Module): - def __init__(self_module, images): + def __init__(self_module): super(RoiHeadsModule, self_module).__init__() self_module.transform = self._init_test_generalized_rcnn_transform() self_module.rpn = self._init_test_rpn() self_module.roi_heads = self._init_test_roi_heads_faster_rcnn() - self_module.original_image_sizes = [img.shape[-2:] for img in images] - self_module.images = ImageList(images, [i.shape[-2:] for i in images]) - def forward(self_module, features): - proposals, _ = self_module.rpn(self_module.images, features) - detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes) + def forward(self_module, images, features): + original_image_sizes = [img.shape[-2:] for img in images] + images = ImageList(images, [i.shape[-2:] for i in images]) + proposals, _ = self_module.rpn(images, features) + detections, _ = self_module.roi_heads(features, proposals, images.image_sizes) detections = self_module.transform.postprocess(detections, - self_module.images.image_sizes, - self_module.original_image_sizes) + images.image_sizes, + original_image_sizes) return detections images = torch.rand(2, 3, 600, 600) features = self.get_features(images) - test_features = self.get_features(images) + images2 = torch.rand(2, 3, 1000, 1000) + test_features = self.get_features(images2) - model = RoiHeadsModule(images) + model = RoiHeadsModule() model.eval() - model(features) - self.run_model(model, [(features,), (test_features,)]) + model(images, features) - def get_image_from_url(self, url): + self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True, + input_names=["input1", "input2", "input3", "input4", "input5", "input6"], + dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], + "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) + + + def get_image_from_url(self, url, size=None): import requests import numpy from PIL import Image @@ -285,16 +304,22 @@ def get_image_from_url(self, url): data = requests.get(url) image = Image.open(BytesIO(data.content)).convert("RGB") - image = image.resize((300, 200), Image.BILINEAR) + + if size is None: + size = (300, 200) + image = image.resize(size, Image.BILINEAR) to_tensor = transforms.ToTensor() return to_tensor(image) + def get_test_images(self): image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg" - image = self.get_image_from_url(url=image_url) + image = self.get_image_from_url(url=image_url, size=(800, 1201)) + image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png" - image2 = self.get_image_from_url(url=image_url2) + image2 = self.get_image_from_url(url=image_url2, size=(873, 800)) + images = [image] test_images = [image2] return images, test_images @@ -302,12 +327,13 @@ def get_test_images(self): def test_faster_rcnn(self): images, test_images = self.get_test_images() - model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, - min_size=200, - max_size=300) + model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True) 
model.eval() model(images) - self.run_model(model, [(images,), (test_images,)]) + self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, + tolerate_small_mismatch=True) # Verify that paste_mask_in_image beahves the same in tracing. # This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image @@ -346,10 +372,14 @@ def test_paste_mask_in_image(self): def test_mask_rcnn(self): images, test_images = self.get_test_images() - model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300) + model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True) model.eval() model(images) - self.run_model(model, [(images,), (test_images,)]) + self.run_model(model, [(images,), (test_images,)], + input_names=["images_tensors"], + output_names=["outputs"], + dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, + tolerate_small_mismatch=True) # Verify that heatmaps_to_keypoints behaves the same in tracing. # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints @@ -398,7 +428,11 @@ def forward(self, images): model = KeyPointRCNN() model.eval() model(test_images) - self.run_model(model, [(images,), (test_images,)]) + self.run_model(model, [(images,), (test_images,)], + input_names=["images_tensors"], + output_names=["outputs1", "outputs2", "outputs3", "outputs4"], + dynamic_axes={"images_tensors": [0, 1, 2, 3]}, + tolerate_small_mismatch=True) if __name__ == '__main__': diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index f1c720bf748..8058a5ba0bb 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -17,8 +17,7 @@ @torch.jit.unused def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): # type: (Tensor, int) -> Tuple[int, int] - from torch.onnx import operators - num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) + num_anchors = ob.shape[1].unsqueeze(0) # TODO : remove cast to IntTensor/num_anchors.dtype when # ONNX Runtime version is updated with ReduceMin int64 support pre_nms_top_n = torch.min(torch.cat( @@ -157,7 +156,8 @@ def forward(self, image_list, feature_maps): # type: (ImageList, List[Tensor]) grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) image_size = image_list.tensors.shape[-2:] - strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes] + strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64), + torch.tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes] dtype, device = feature_maps[0].dtype, feature_maps[0].device self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) @@ -354,7 +354,17 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level): # type: (Tensor, List[int]) r = [] offset = 0 - for ob in objectness.split(num_anchors_per_level, 1): + if torchvision._is_tracing(): + # Split's split_size is traced as constant in onnx exporting, use Gather + start_list = [torch.tensor(0)] + end_list = [num_anchors_per_level[0].clone()] + for cnt in num_anchors_per_level[1:]: + start_list.append(end_list[-1].clone()) + end_list.append(end_list[-1] + cnt) + objectness_per_level = [objectness[:, s:e] for s, e in zip(start_list, end_list)] + else: + objectness_per_level = objectness.split(num_anchors_per_level, 1) 
+ for ob in objectness_per_level: if torchvision._is_tracing(): num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n()) else: @@ -372,14 +382,12 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ # do not backprop throught objectness objectness = objectness.detach() objectness = objectness.reshape(num_images, -1) - levels = [ torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level) ] levels = torch.cat(levels, 0) levels = levels.reshape(1, -1).expand_as(objectness) - # select top_n boxes independently per level before applying nms top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level) @@ -466,7 +474,10 @@ def forward(self, images, features, targets=None): anchors = self.anchor_generator(images, features) num_images = len(anchors) - num_anchors_per_level = [o[0].numel() for o in objectness] + # num_anchors_per_level = [torch.prod(shape_as_tensor(o[0])) for o in objectness] + num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] + num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] + objectness, pred_bbox_deltas = \ concat_box_prediction_layers(objectness, pred_bbox_deltas) # apply pred_bbox_deltas to anchors to obtain the decoded proposals diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index f1cf8a41bfd..ffcce555492 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -90,7 +90,8 @@ def resize(self, image, target): if max_size * scale_factor > self.max_size: scale_factor = self.max_size / max_size image = torch.nn.functional.interpolate( - image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0] + image[None], scale_factor=scale_factor, mode='bilinear', + align_corners=False, recompute_scale_factor=False)[0] if target is None: return image, target diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 44dee79497f..56be6a2220e 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -3,7 +3,7 @@ import torch from torch.jit.annotations import Tuple from torch import Tensor - +import torchvision def nms(boxes, scores, iou_threshold): # type: (Tensor, Tensor, float) @@ -112,8 +112,19 @@ def clip_boxes_to_image(boxes, size): boxes_x = boxes[..., 0::2] boxes_y = boxes[..., 1::2] height, width = size - boxes_x = boxes_x.clamp(min=0, max=width) - boxes_y = boxes_y.clamp(min=0, max=height) + + if torchvision._is_tracing(): + height = height.to(torch.float32) + width = width.to(torch.float32) + + boxes_x = torch.max(boxes_x, torch.tensor(0.)) + boxes_x = torch.min(boxes_x, width) + boxes_y = torch.max(boxes_y, torch.tensor(0.)) + boxes_y = torch.min(boxes_y, height) + else: + boxes_x = boxes_x.clamp(min=0, max=width) + boxes_y = boxes_y.clamp(min=0, max=height) + clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim) return clipped_boxes.reshape(boxes.shape) From 8df5790a210cbc464940268bacb1d38b47b4475f Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 30 Jan 2020 09:40:40 -0800 Subject: [PATCH 02/44] transform test fix --- test/test_onnx.py | 9 ++------- torchvision/ops/boxes.py | 1 + 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index 18176dacb35..ef272eef4f8 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -128,10 +128,8 @@ def forward(self_module, images): return 
self_module.transform(images)[0].tensors input = torch.rand(3, 100, 200), torch.rand(3, 200, 200) - input_test = torch.rand(3, 130, 230), torch.rand(3, 230, 230) - self.run_model(TransformModule(), [(input,), (input_test,)], - input_names=["input1", "input2"], - dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3]}) + input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200) + self.run_model(TransformModule(), [(input,), (input_test,)]) def _init_test_generalized_rcnn_transform(self): min_size = 100 @@ -221,7 +219,6 @@ def forward(self_module, images, features): images = ImageList(images, [i.shape[-2:] for i in images]) return self_module.rpn(images, features) - images = torch.rand(2, 3, 600, 600) features = self.get_features(images) images2 = torch.rand(2, 3, 1000, 1000) @@ -294,7 +291,6 @@ def forward(self_module, images, features): dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]}) - def get_image_from_url(self, url, size=None): import requests import numpy @@ -312,7 +308,6 @@ def get_image_from_url(self, url, size=None): to_tensor = transforms.ToTensor() return to_tensor(image) - def get_test_images(self): image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg" image = self.get_image_from_url(url=image_url, size=(800, 1201)) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 56be6a2220e..e4f274c419c 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -5,6 +5,7 @@ from torch import Tensor import torchvision + def nms(boxes, scores, iou_threshold): # type: (Tensor, Tensor, float) """ From 03bc884dadfa8978d65b31d711e7de29492091c3 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Thu, 30 Jan 2020 09:50:39 -0800 Subject: [PATCH 03/44] Fix comment --- torchvision/models/detection/rpn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 8058a5ba0bb..79b1a549707 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -474,7 +474,6 @@ def forward(self, images, features, targets=None): anchors = self.anchor_generator(images, features) num_images = len(anchors) - # num_anchors_per_level = [torch.prod(shape_as_tensor(o[0])) for o in objectness] num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] From 37888e20da504dcb33c8cb7563813c02ce00b56a Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Thu, 30 Jan 2020 11:24:17 -0800 Subject: [PATCH 04/44] Dynamic shape for keypoint_rcnn --- test/test_onnx.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index ef272eef4f8..ce48a0bf5c6 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -408,9 +408,7 @@ def test_keypoint_rcnn(self): class KeyPointRCNN(torch.nn.Module): def __init__(self): super(KeyPointRCNN, self).__init__() - self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, - min_size=200, - max_size=300) + self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True) def forward(self, images): output = self.model(images) From ebcd45b6b3137c2fa08d887efe3db1889f902f9e Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 4 Feb 2020 11:59:49 -0800 Subject: [PATCH 05/44] Update test_onnx.py --- test/test_onnx.py | 1 - 1 file changed, 1 
deletion(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index ce48a0bf5c6..834a7758358 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -33,7 +33,6 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_consta model.eval() onnx_io = io.BytesIO() - # onnx_io = '/home/neraoof/test/results/transform.onnx' # export to onnx with the first input torch.onnx.export(model, inputs_list[0], onnx_io, From 89404f313b5ef1906cdc68e0a5f2ef3b53f063c5 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 4 Feb 2020 16:16:47 -0800 Subject: [PATCH 06/44] Update rpn.py --- torchvision/models/detection/rpn.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 79b1a549707..bfe0c2f8c21 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -121,10 +121,6 @@ def grid_anchors(self, grid_sizes, strides): ): grid_height, grid_width = size stride_height, stride_width = stride - if torchvision._is_tracing(): - # required in ONNX export for mult operation with float32 - stride_width = torch.tensor(stride_width, dtype=torch.float32) - stride_height = torch.tensor(stride_height, dtype=torch.float32) device = base_anchors.device shifts_x = torch.arange( 0, grid_width, dtype=torch.float32, device=device From 60ba5e7a3d0e257b9fa2e7839c3164950446d7b6 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Tue, 18 Feb 2020 11:57:32 -0800 Subject: [PATCH 07/44] Fix for split on RPN --- torchvision/models/detection/rpn.py | 31 +++-------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 79b1a549707..c74c317237a 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -14,19 +14,6 @@ from torch.jit.annotations import List, Optional, Dict, Tuple -@torch.jit.unused -def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): - # type: (Tensor, int) -> Tuple[int, int] - num_anchors = ob.shape[1].unsqueeze(0) - # TODO : remove cast to IntTensor/num_anchors.dtype when - # ONNX Runtime version is updated with ReduceMin int64 support - pre_nms_top_n = torch.min(torch.cat( - (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), - num_anchors), 0).to(torch.int32)).to(num_anchors.dtype) - - return num_anchors, pre_nms_top_n - - class AnchorGenerator(nn.Module): __annotations__ = { "cell_anchors": Optional[List[torch.Tensor]], @@ -354,22 +341,10 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level): # type: (Tensor, List[int]) r = [] offset = 0 - if torchvision._is_tracing(): - # Split's split_size is traced as constant in onnx exporting, use Gather - start_list = [torch.tensor(0)] - end_list = [num_anchors_per_level[0].clone()] - for cnt in num_anchors_per_level[1:]: - start_list.append(end_list[-1].clone()) - end_list.append(end_list[-1] + cnt) - objectness_per_level = [objectness[:, s:e] for s, e in zip(start_list, end_list)] - else: - objectness_per_level = objectness.split(num_anchors_per_level, 1) + objectness_per_level = objectness.split(num_anchors_per_level, 1) for ob in objectness_per_level: - if torchvision._is_tracing(): - num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n()) - else: - num_anchors = ob.shape[1] - pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors) + num_anchors = ob.shape[1] + pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors) _, top_n_idx = 
ob.topk(pre_nms_top_n, dim=1) r.append(top_n_idx + offset) offset += num_anchors From ea5cf6e93d770a59481740da131631f3334ba100 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Wed, 19 Feb 2020 17:24:46 -0800 Subject: [PATCH 08/44] Fixes for feedbacks --- test/test_onnx.py | 4 ++-- torchvision/models/detection/transform.py | 8 ++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index 834a7758358..c358a53da1b 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -34,7 +34,6 @@ def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_consta onnx_io = io.BytesIO() # export to onnx with the first input - torch.onnx.export(model, inputs_list[0], onnx_io, do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version, dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names) @@ -375,6 +374,7 @@ def test_mask_rcnn(self): dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, tolerate_small_mismatch=True) + # Verify that heatmaps_to_keypoints behaves the same in tracing. # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints # (since jit_trace witll call _heatmaps_to_keypoints). @@ -419,7 +419,7 @@ def forward(self, images): images, test_images = self.get_test_images() model = KeyPointRCNN() model.eval() - model(test_images) + model(images) self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"], output_names=["outputs1", "outputs2", "outputs3", "outputs4"], diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index be2b8cabe31..a4e13df57bc 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -91,7 +91,7 @@ def resize(self, image, target): scale_factor = self.max_size / max_size image = torch.nn.functional.interpolate( image[None], scale_factor=scale_factor, mode='bilinear', - align_corners=False, recompute_scale_factor=False)[0] + align_corners=False)[0] if target is None: return image, target @@ -209,7 +209,11 @@ def resize_keypoints(keypoints, original_size, new_size): def resize_boxes(boxes, original_size, new_size): # type: (Tensor, List[int], List[int]) - ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)] + if torchvision._is_tracing(): + ratios = [s.to(dtype=torch.float32) / s_orig.to(dtype=torch.float32) for s, s_orig in + zip(new_size, original_size)] + else: + ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)] ratio_height, ratio_width = ratios xmin, ymin, xmax, ymax = boxes.unbind(1) From dced83af62006fb4cc1346c4c63409b35f891c23 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Wed, 19 Feb 2020 17:27:01 -0800 Subject: [PATCH 09/44] flake8 --- test/test_onnx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index c358a53da1b..fd478e79fd9 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -374,7 +374,6 @@ def test_mask_rcnn(self): dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]}, tolerate_small_mismatch=True) - # Verify that heatmaps_to_keypoints behaves the same in tracing. # This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints # (since jit_trace witll call _heatmaps_to_keypoints). 
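
The patches up to this point all serve one goal: keeping the exported ONNX graph shape-agnostic, so a model traced with one input size still runs correctly at another. Each test now passes dynamic_axes to torch.onnx.export and re-runs the serialized graph in ONNX Runtime on a second, differently sized input. Below is a minimal, self-contained sketch of that export-and-validate loop, under the same assumption as the CI config above (onnxruntime installed); TinyModel and every name in it are illustrative stand-ins, not torchvision code or part of any patch in this series.

import io

import onnxruntime
import torch


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for the detection models under test.
    def forward(self, x):
        # Reduce over the spatial dims; nothing here bakes a concrete
        # height/width into the trace.
        return x.mean(dim=(2, 3))


model = TinyModel().eval()
onnx_io = io.BytesIO()
# Export with one concrete input, but mark all four input dims as dynamic,
# mirroring the dynamic_axes arguments threaded through run_model() above.
torch.onnx.export(model, torch.rand(1, 3, 100, 200), onnx_io,
                  input_names=["input"], output_names=["output"],
                  dynamic_axes={"input": [0, 1, 2, 3]})

# Validate the single exported graph against inputs of different sizes.
ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
for shape in [(1, 3, 100, 200), (2, 3, 130, 230)]:
    x = torch.rand(*shape)
    expected = model(x)
    (actual,) = ort_session.run(None, {"input": x.numpy()})
    torch.testing.assert_allclose(expected, actual, rtol=1e-03, atol=1e-05)

When a size-dependent value is instead captured as a Python int during tracing (as in objectness.split(num_anchors_per_level, 1) and min(self.pre_nms_top_n(), num_anchors) in rpn.py), the second run either fails or silently reuses the first input's sizes; that is why these patches route such computations through tensor ops guarded by torchvision._is_tracing(), as the next patch restores below.
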
From da44102cfe270d5031958c68e71cae112b0b3b23 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Wed, 19 Feb 2020 17:44:04 -0800 Subject: [PATCH 10/44] topk fix --- test/test_onnx.py | 1 - torchvision/models/detection/rpn.py | 24 ++++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index fd478e79fd9..da67f61a7ac 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -66,7 +66,6 @@ def to_numpy(tensor): # compute onnxruntime output prediction ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs)) ort_outs = ort_session.run(None, ort_inputs) - for i in range(0, len(outputs)): try: torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 9cfbc2e69f8..9d9986e6741 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -14,6 +14,20 @@ from torch.jit.annotations import List, Optional, Dict, Tuple +@torch.jit.unused +def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n): + # type: (Tensor, int) -> Tuple[int, int] + from torch.onnx import operators + num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0) + # TODO : remove cast to IntTensor/num_anchors.dtype when + # ONNX Runtime version is updated with ReduceMin int64 support + pre_nms_top_n = torch.min(torch.cat( + (torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), + num_anchors), 0).to(torch.int32)).to(num_anchors.dtype) + + return num_anchors, pre_nms_top_n + + class AnchorGenerator(nn.Module): __annotations__ = { "cell_anchors": Optional[List[torch.Tensor]], @@ -337,10 +351,12 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level): # type: (Tensor, List[int]) r = [] offset = 0 - objectness_per_level = objectness.split(num_anchors_per_level, 1) - for ob in objectness_per_level: - num_anchors = ob.shape[1] - pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors) + for ob in objectness.split(num_anchors_per_level, 1): + if torchvision._is_tracing(): + num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n()) + else: + num_anchors = ob.shape[1] + pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors) _, top_n_idx = ob.topk(pre_nms_top_n, dim=1) r.append(top_n_idx + offset) offset += num_anchors From cd7943538561d9a4bd50504db82817659563c830 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 20 Feb 2020 09:31:06 -0800 Subject: [PATCH 11/44] Fix build --- torchvision/models/detection/transform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index a4e13df57bc..69a1f5d4843 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -89,6 +89,7 @@ def resize(self, image, target): scale_factor = size / min_size if max_size * scale_factor > self.max_size: scale_factor = self.max_size / max_size + image = torch.nn.functional.interpolate( image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0] @@ -210,7 +211,7 @@ def resize_keypoints(keypoints, original_size, new_size): def resize_boxes(boxes, original_size, new_size): # type: (Tensor, List[int], List[int]) if torchvision._is_tracing(): - ratios = [s.to(dtype=torch.float32) / s_orig.to(dtype=torch.float32) for s, s_orig in + ratios = [torch.tensor(s, dtype=torch.float32) / torch.tensor(s_orig, 
dtype=torch.float32) for s, s_orig in zip(new_size, original_size)] else: ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)] From fbe46802d57d3330c9165aa65c9b9fcaff75c7fd Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 20 Feb 2020 10:28:05 -0800 Subject: [PATCH 12/44] branch on tracing --- torchvision/models/detection/transform.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 69a1f5d4843..9ec202bb535 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -211,10 +211,9 @@ def resize_keypoints(keypoints, original_size, new_size): def resize_boxes(boxes, original_size, new_size): # type: (Tensor, List[int], List[int]) if torchvision._is_tracing(): - ratios = [torch.tensor(s, dtype=torch.float32) / torch.tensor(s_orig, dtype=torch.float32) for s, s_orig in - zip(new_size, original_size)] + ratios = [s.to(dtype=torch.float32) / s_orig.to(dtype=torch.float32) for s, s_orig in zip(new_size, original_size)] else: - ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)] + ratios = [float(s) / (s_orig) for s, s_orig in zip(new_size, original_size)] ratio_height, ratio_width = ratios xmin, ymin, xmax, ymax = boxes.unbind(1) From 2b4ad0738b1f6987276059ffef3a127c231db83f Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 20 Feb 2020 13:17:12 -0800 Subject: [PATCH 13/44] fix for scalar tensor --- torchvision/models/detection/rpn.py | 4 ++-- torchvision/models/detection/transform.py | 7 ++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 9d9986e6741..cb698feb450 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -153,8 +153,8 @@ def forward(self, image_list, feature_maps): # type: (ImageList, List[Tensor]) grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) image_size = image_list.tensors.shape[-2:] - strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64), - torch.tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes] + strides = [[torch.scalar_tensor(image_size[0] / g[0], dtype=torch.int64), + torch.scalar_tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes] dtype, device = feature_maps[0].dtype, feature_maps[0].device self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 9ec202bb535..c566fa60b6c 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -195,7 +195,7 @@ def __repr__(self): def resize_keypoints(keypoints, original_size, new_size): # type: (Tensor, List[int], List[int]) - ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)] + ratios = [torch.scalar_tensor(s) / torch.scalar_tensor(s_orig) for s, s_orig in zip(new_size, original_size)] ratio_h, ratio_w = ratios resized_data = keypoints.clone() if torch._C._get_tracing_state(): @@ -210,10 +210,7 @@ def resize_keypoints(keypoints, original_size, new_size): def resize_boxes(boxes, original_size, new_size): # type: (Tensor, List[int], List[int]) - if torchvision._is_tracing(): - ratios = [s.to(dtype=torch.float32) / s_orig.to(dtype=torch.float32) for s, s_orig in 
zip(new_size, original_size)] - else: - ratios = [float(s) / (s_orig) for s, s_orig in zip(new_size, original_size)] + ratios = [torch.scalar_tensor(s) / torch.scalar_tensor(s_orig) for s, s_orig in zip(new_size, original_size)] ratio_height, ratio_width = ratios xmin, ymin, xmax, ymax = boxes.unbind(1) From be0ae7e213bf9970884fa075bf15f479d0a6ed72 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Tue, 25 Feb 2020 10:00:40 -0800 Subject: [PATCH 14/44] Fixes for script type annotations --- torchvision/models/detection/rpn.py | 6 +++--- torchvision/ops/boxes.py | 7 ++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index cb698feb450..52911c149bc 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -112,7 +112,7 @@ def num_anchors_per_location(self): return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] def grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[int]]) + # type: (List[List[int]], List[List[Tensor]]) anchors = [] cell_anchors = self.cell_anchors assert cell_anchors is not None @@ -141,8 +141,8 @@ def grid_anchors(self, grid_sizes, strides): return anchors def cached_grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[int]]) - key = str(grid_sizes + strides) + # type: (List[List[int]], List[List[Tensor]]) + key = str(str(grid_sizes) + str(strides)) if key in self._cache: return self._cache[key] anchors = self.grid_anchors(grid_sizes, strides) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index e4f274c419c..821e2681a0c 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -115,13 +115,10 @@ def clip_boxes_to_image(boxes, size): height, width = size if torchvision._is_tracing(): - height = height.to(torch.float32) - width = width.to(torch.float32) - boxes_x = torch.max(boxes_x, torch.tensor(0.)) - boxes_x = torch.min(boxes_x, width) + boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=torch.float32)) boxes_y = torch.max(boxes_y, torch.tensor(0.)) - boxes_y = torch.min(boxes_y, height) + boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=torch.float32)) else: boxes_x = boxes_x.clamp(min=0, max=width) boxes_y = boxes_y.clamp(min=0, max=height) From 7999e5597b4e2ea1c12cd3200e3aaf386d33428f Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 3 Mar 2020 09:53:59 -0800 Subject: [PATCH 15/44] Update rpn.py --- torchvision/models/detection/rpn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 52911c149bc..5bef031984d 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -142,7 +142,7 @@ def grid_anchors(self, grid_sizes, strides): def cached_grid_anchors(self, grid_sizes, strides): # type: (List[List[int]], List[List[Tensor]]) - key = str(str(grid_sizes) + str(strides)) + key = str(grid_sizes) + str(strides) if key in self._cache: return self._cache[key] anchors = self.grid_anchors(grid_sizes, strides) From 94b1ac62768b1f53127f042ab2afcb38cf1b5eee Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 3 Mar 2020 10:03:59 -0800 Subject: [PATCH 16/44] clean up --- torchvision/models/detection/transform.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index c566fa60b6c..fcaf574fe97 100644 --- 
a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -89,7 +89,6 @@ def resize(self, image, target): scale_factor = size / min_size if max_size * scale_factor > self.max_size: scale_factor = self.max_size / max_size - image = torch.nn.functional.interpolate( image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0] From 050e756a9aaceb86e2a3ed880965082ec2892373 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 3 Mar 2020 10:07:09 -0800 Subject: [PATCH 17/44] clean up --- torchvision/models/detection/rpn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 5bef031984d..63cd58f5fb5 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -369,10 +369,12 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ # do not backprop throught objectness objectness = objectness.detach() objectness = objectness.reshape(num_images, -1) + levels = [ torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level) ] + levels = torch.cat(levels, 0) levels = levels.reshape(1, -1).expand_as(objectness) # select top_n boxes independently per level before applying nms @@ -463,7 +465,6 @@ def forward(self, images, features, targets=None): num_images = len(anchors) num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] - objectness, pred_bbox_deltas = \ concat_box_prediction_layers(objectness, pred_bbox_deltas) # apply pred_bbox_deltas to anchors to obtain the decoded proposals From a445d4a5ab18b6ecedba6bfe997b00647a586b0a Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 3 Mar 2020 10:08:07 -0800 Subject: [PATCH 18/44] Update rpn.py --- torchvision/models/detection/rpn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 63cd58f5fb5..56d099b06e1 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -374,9 +374,9 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ torch.full((n,), idx, dtype=torch.int64, device=device) for idx, n in enumerate(num_anchors_per_level) ] - levels = torch.cat(levels, 0) levels = levels.reshape(1, -1).expand_as(objectness) + # select top_n boxes independently per level before applying nms top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level) From b0c79bbfa0fd95ea12b2471781d279bc27407cca Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Tue, 3 Mar 2020 10:09:53 -0800 Subject: [PATCH 19/44] Updated for feedback --- torchvision/ops/boxes.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 821e2681a0c..ea2fd2405b6 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -115,10 +115,10 @@ def clip_boxes_to_image(boxes, size): height, width = size if torchvision._is_tracing(): - boxes_x = torch.max(boxes_x, torch.tensor(0.)) - boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=torch.float32)) - boxes_y = torch.max(boxes_y, torch.tensor(0.)) - boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=torch.float32)) + boxes_x = torch.max(boxes_x, torch.scalar_tensor(0.)) + boxes_x = torch.min(boxes_x, torch.scalar_tensor(width, 
dtype=torch.float32)) + boxes_y = torch.max(boxes_y, torch.scalar_tensor(0.)) + boxes_y = torch.min(boxes_y, torch.scalar_tensor(height, dtype=torch.float32)) else: boxes_x = boxes_x.clamp(min=0, max=width) boxes_y = boxes_y.clamp(min=0, max=height) From 04ff4302d8e3eede702637a3aa49d62f7aa725ae Mon Sep 17 00:00:00 2001 From: neginraoof Date: Tue, 24 Mar 2020 00:40:00 -0700 Subject: [PATCH 20/44] Fix for comments --- torchvision/models/detection/rpn.py | 10 +++++----- torchvision/models/detection/transform.py | 4 ++-- torchvision/ops/boxes.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index 2d01f7980df..d525367b0db 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -116,7 +116,7 @@ def num_anchors_per_location(self): # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. def grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) + # type: (List[List[int]], List[List[int]]) anchors = [] cell_anchors = self.cell_anchors assert cell_anchors is not None @@ -149,8 +149,8 @@ def grid_anchors(self, grid_sizes, strides): return anchors def cached_grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) - key = str(grid_sizes) + str(strides) + # type: (List[List[int]], List[List[int]]) + key = str(grid_sizes + strides) if key in self._cache: return self._cache[key] anchors = self.grid_anchors(grid_sizes, strides) @@ -161,8 +161,8 @@ def forward(self, image_list, feature_maps): # type: (ImageList, List[Tensor]) grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) image_size = image_list.tensors.shape[-2:] - strides = [[torch.scalar_tensor(image_size[0] / g[0], dtype=torch.int64), - torch.scalar_tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes] + strides = [[image_size[0] / g[0], + image_size[1] / g[1]] for g in grid_sizes] dtype, device = feature_maps[0].dtype, feature_maps[0].device self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index fcaf574fe97..78042a0241c 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -194,7 +194,7 @@ def __repr__(self): def resize_keypoints(keypoints, original_size, new_size): # type: (Tensor, List[int], List[int]) - ratios = [torch.scalar_tensor(s) / torch.scalar_tensor(s_orig) for s, s_orig in zip(new_size, original_size)] + ratios = [torch.tensor(s) / torch.tensor(s_orig) for s, s_orig in zip(new_size, original_size)] ratio_h, ratio_w = ratios resized_data = keypoints.clone() if torch._C._get_tracing_state(): @@ -209,7 +209,7 @@ def resize_keypoints(keypoints, original_size, new_size): def resize_boxes(boxes, original_size, new_size): # type: (Tensor, List[int], List[int]) - ratios = [torch.scalar_tensor(s) / torch.scalar_tensor(s_orig) for s, s_orig in zip(new_size, original_size)] + ratios = [torch.tensor(s, dtype=torch.float32) / torch.tensor(s_orig, dtype=torch.float32) for s, s_orig in zip(new_size, original_size)] ratio_height, ratio_width = ratios xmin, ymin, xmax, ymax = boxes.unbind(1) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py 
index ea2fd2405b6..155c7fc3fb7 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -115,10 +115,10 @@ def clip_boxes_to_image(boxes, size): height, width = size if torchvision._is_tracing(): - boxes_x = torch.max(boxes_x, torch.scalar_tensor(0.)) - boxes_x = torch.min(boxes_x, torch.scalar_tensor(width, dtype=torch.float32)) - boxes_y = torch.max(boxes_y, torch.scalar_tensor(0.)) - boxes_y = torch.min(boxes_y, torch.scalar_tensor(height, dtype=torch.float32)) + boxes_x = torch.max(boxes_x, torch.tensor(0., device=boxes.device)) + boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=torch.float32, device=boxes.device)) + boxes_y = torch.max(boxes_y, torch.tensor(0., device=boxes.device)) + boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=torch.float32, device=boxes.device)) else: boxes_x = boxes_x.clamp(min=0, max=width) boxes_y = boxes_y.clamp(min=0, max=height) From 23bff5964fa274f1b5e2b3d6c7036b4fe082ce77 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Tue, 24 Mar 2020 09:22:20 -0700 Subject: [PATCH 21/44] revert to use tensor --- torchvision/models/detection/rpn.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index d525367b0db..02ebd6b8380 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -116,7 +116,7 @@ def num_anchors_per_location(self): # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. def grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[int]]) + # type: (List[List[int]], List[List[Tensor]]) anchors = [] cell_anchors = self.cell_anchors assert cell_anchors is not None @@ -149,8 +149,8 @@ def grid_anchors(self, grid_sizes, strides): return anchors def cached_grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[int]]) - key = str(grid_sizes + strides) + # type: (List[List[int]], List[List[Tensor]]) + key = str(grid_sizes) + str(strides) if key in self._cache: return self._cache[key] anchors = self.grid_anchors(grid_sizes, strides) @@ -161,8 +161,8 @@ def forward(self, image_list, feature_maps): # type: (ImageList, List[Tensor]) grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) image_size = image_list.tensors.shape[-2:] - strides = [[image_size[0] / g[0], - image_size[1] / g[1]] for g in grid_sizes] + strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64), + torch.tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes] dtype, device = feature_maps[0].dtype, feature_maps[0].device self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) From b9ff797fbc15012979750061a11f880368525c97 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Tue, 24 Mar 2020 10:42:13 -0700 Subject: [PATCH 22/44] Added test for box clip --- test/test_onnx.py | 11 +++++++++++ torchvision/ops/boxes.py | 8 ++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index ec71fc2cc51..43d991c215b 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -100,6 +100,17 @@ def forward(self, boxes, scores): self.run_model(Module(), [(boxes, scores)]) + def test_clip_boxes_to_image(self): + boxes = torch.randint(10, (5, 4)) + boxes[:, 2:] += torch.randint(500, (5, 2)) + size = torch.randn(200, 
300) + + class Module(torch.nn.Module): + def forward(self, boxes, size): + return ops.boxes.clip_boxes_to_image(boxes, size.shape) + + self.run_model(Module(), [(boxes, size)]) + def test_roi_align(self): x = torch.rand(1, 1, 10, 10, dtype=torch.float32) single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 155c7fc3fb7..dfd1e9d815c 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -115,10 +115,10 @@ def clip_boxes_to_image(boxes, size): height, width = size if torchvision._is_tracing(): - boxes_x = torch.max(boxes_x, torch.tensor(0., device=boxes.device)) - boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=torch.float32, device=boxes.device)) - boxes_y = torch.max(boxes_y, torch.tensor(0., device=boxes.device)) - boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=torch.float32, device=boxes.device)) + boxes_x = torch.max(boxes_x.to(torch.float32), torch.tensor(0., device=boxes.device)) + boxes_x = torch.min(boxes_x.to(torch.float32), torch.tensor(width, dtype=torch.float32, device=boxes.device)) + boxes_y = torch.max(boxes_y.to(torch.float32), torch.tensor(0., device=boxes.device)) + boxes_y = torch.min(boxes_y.to(torch.float32), torch.tensor(height, dtype=torch.float32, device=boxes.device)) else: boxes_x = boxes_x.clamp(min=0, max=width) boxes_y = boxes_y.clamp(min=0, max=height) From 572e5227dda6434a35660420410f0d1f9b367764 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 26 Mar 2020 15:42:33 -0700 Subject: [PATCH 23/44] Fixes for feedback --- .travis.yml | 2 +- test/test_onnx.py | 5 ++--- torchvision/models/detection/transform.py | 6 ++++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.travis.yml b/.travis.yml index 1b6ecb7a65b..354cf6a02fe 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,7 @@ before_install: - pip install typing - | if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then - pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.0.0.dev1123 + pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202003231 fi - conda install av -c conda-forge diff --git a/test/test_onnx.py b/test/test_onnx.py index 43d991c215b..f1a9b5dd5a8 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -285,9 +285,9 @@ def forward(self_module, images, features): original_image_sizes) return detections - images = torch.rand(2, 3, 600, 600) + images = torch.rand(2, 3, 400, 400) features = self.get_features(images) - images2 = torch.rand(2, 3, 1000, 1000) + images2 = torch.rand(2, 3, 600, 600) test_features = self.get_features(images2) model = RoiHeadsModule() @@ -301,7 +301,6 @@ def forward(self_module, images, features): def get_image_from_url(self, url, size=None): import requests - import numpy from PIL import Image from io import BytesIO from torchvision import transforms diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 78042a0241c..b04b6f13d76 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -194,7 +194,8 @@ def __repr__(self): def resize_keypoints(keypoints, original_size, new_size): # type: (Tensor, List[int], List[int]) - ratios = [torch.tensor(s) / torch.tensor(s_orig) for s, s_orig in zip(new_size, original_size)] + ratios = [torch.tensor(s, dtype=torch.float32, device=keypoints.device) / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device) + for s, s_orig in zip(new_size, original_size)] 
ratio_h, ratio_w = ratios resized_data = keypoints.clone() if torch._C._get_tracing_state(): @@ -209,7 +210,8 @@ def resize_keypoints(keypoints, original_size, new_size): def resize_boxes(boxes, original_size, new_size): # type: (Tensor, List[int], List[int]) - ratios = [torch.tensor(s, dtype=torch.float32) / torch.tensor(s_orig, dtype=torch.float32) for s, s_orig in zip(new_size, original_size)] + ratios = [torch.tensor(s, dtype=torch.float32, device=boxes.device) / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device) + for s, s_orig in zip(new_size, original_size)] ratio_height, ratio_width = ratios xmin, ymin, xmax, ymax = boxes.unbind(1) From d0ff3f300bd066c7903262e8c40e41ad9c74dd3c Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 26 Mar 2020 15:47:05 -0700 Subject: [PATCH 24/44] Fix for feedback --- test/test_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index f1a9b5dd5a8..464bcabe828 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -227,9 +227,9 @@ def forward(self_module, images, features): images = ImageList(images, [i.shape[-2:] for i in images]) return self_module.rpn(images, features) - images = torch.rand(2, 3, 600, 600) + images = torch.rand(2, 3, 400, 400) features = self.get_features(images) - images2 = torch.rand(2, 3, 1000, 1000) + images2 = torch.rand(2, 3, 600, 600) test_features = self.get_features(images2) model = RPNModule() From a66381a2938b7d6236fad12b9283d6d714609354 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 26 Mar 2020 16:06:47 -0700 Subject: [PATCH 25/44] ORT version revert --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 354cf6a02fe..1b6ecb7a65b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,7 @@ before_install: - pip install typing - | if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then - pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202003231 + pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.0.0.dev1123 fi - conda install av -c conda-forge From 8d4b1ee898c63927677673e318bf7740e89f31cd Mon Sep 17 00:00:00 2001 From: neginraoof Date: Thu, 26 Mar 2020 17:19:15 -0700 Subject: [PATCH 26/44] Update ort --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1b6ecb7a65b..42a9c5a5b36 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,7 @@ before_install: - pip install typing - | if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then - pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.0.0.dev1123 + pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202003241 fi - conda install av -c conda-forge From 16cd2eb42803bf7d82bfce5c1a4e05b0b2b8b916 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Thu, 26 Mar 2020 17:27:05 -0700 Subject: [PATCH 27/44] Update .travis.yml --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 42a9c5a5b36..10101f603e6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -46,7 +46,7 @@ before_install: - pip install typing - | if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then - pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202003241 + pip install -q --user onnxruntime fi - conda install av -c conda-forge From c9fc98cfbd2c6dbba5f73632c3e37c521266ccd4 Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Thu, 26 Mar 2020 18:02:27 -0700 Subject: [PATCH 28/44] 
Update test_onnx.py --- test/test_onnx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index 464bcabe828..d042de8021b 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -227,9 +227,9 @@ def forward(self_module, images, features): images = ImageList(images, [i.shape[-2:] for i in images]) return self_module.rpn(images, features) - images = torch.rand(2, 3, 400, 400) + images = torch.rand(1, 3, 50, 50) features = self.get_features(images) - images2 = torch.rand(2, 3, 600, 600) + images2 = torch.rand(1, 3, 150, 150) test_features = self.get_features(images2) model = RPNModule() @@ -285,9 +285,9 @@ def forward(self_module, images, features): original_image_sizes) return detections - images = torch.rand(2, 3, 400, 400) + images = torch.rand(1, 3, 50, 50) features = self.get_features(images) - images2 = torch.rand(2, 3, 600, 600) + images2 = torch.rand(1, 3, 150, 150) test_features = self.get_features(images2) model = RoiHeadsModule() From f12e10c1dc3a6608c101c667a94cd42aeae48bee Mon Sep 17 00:00:00 2001 From: Negin Raoof Date: Fri, 27 Mar 2020 00:52:42 -0700 Subject: [PATCH 29/44] Update test_onnx.py --- test/test_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index d042de8021b..f79c160d98a 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -317,10 +317,10 @@ def get_image_from_url(self, url, size=None): def get_test_images(self): image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg" - image = self.get_image_from_url(url=image_url, size=(800, 1201)) + image = self.get_image_from_url(url=image_url, size=(200, 300)) image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png" - image2 = self.get_image_from_url(url=image_url2, size=(873, 800)) + image2 = self.get_image_from_url(url=image_url2, size=(250, 200)) images = [image] test_images = [image2] From a3a9cd6114d22c20f55e02e3bfdbeed22dd944ff Mon Sep 17 00:00:00 2001 From: neginraoof Date: Fri, 27 Mar 2020 01:06:27 -0700 Subject: [PATCH 30/44] Tensor sizes --- test/test_onnx.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index f79c160d98a..42f45d0ac08 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -227,9 +227,9 @@ def forward(self_module, images, features): images = ImageList(images, [i.shape[-2:] for i in images]) return self_module.rpn(images, features) - images = torch.rand(1, 3, 50, 50) + images = torch.rand(2, 3, 150, 150) features = self.get_features(images) - images2 = torch.rand(1, 3, 150, 150) + images2 = torch.rand(2, 3, 80, 80) test_features = self.get_features(images2) model = RPNModule() @@ -285,9 +285,9 @@ def forward(self_module, images, features): original_image_sizes) return detections - images = torch.rand(1, 3, 50, 50) + images = torch.rand(2, 3, 200, 200) features = self.get_features(images) - images2 = torch.rand(1, 3, 150, 150) + images2 = torch.rand(2, 3, 300, 300) test_features = self.get_features(images2) model = RoiHeadsModule() From 088d5146a7421a73fd1a6ded6d0eb42ed4c18110 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Fri, 27 Mar 2020 01:51:45 -0700 Subject: [PATCH 31/44] Fix for dynamic split --- test/test_onnx.py | 4 ++-- torchvision/models/detection/roi_heads.py | 13 +++---------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index 42f45d0ac08..77ee315fe93 100644 --- a/test/test_onnx.py +++ 
b/test/test_onnx.py @@ -285,9 +285,9 @@ def forward(self_module, images, features): original_image_sizes) return detections - images = torch.rand(2, 3, 200, 200) + images = torch.rand(2, 3, 100, 100) features = self.get_features(images) - images2 = torch.rand(2, 3, 300, 300) + images2 = torch.rand(2, 3, 200, 200) test_features = self.get_features(images2) model = RoiHeadsModule() diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 0dbb0119144..344305ea8f5 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -679,20 +679,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_ device = class_logits.device num_classes = class_logits.shape[-1] - boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals] + boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals] pred_boxes = self.box_coder.decode(box_regression, proposals) pred_scores = F.softmax(class_logits, -1) - # split boxes and scores per image - if len(boxes_per_image) == 1: - # TODO : remove this when ONNX support dynamic split sizes - # and just assign to pred_boxes instead of pred_boxes_list - pred_boxes_list = [pred_boxes] - pred_scores_list = [pred_scores] - else: - pred_boxes_list = pred_boxes.split(boxes_per_image, 0) - pred_scores_list = pred_scores.split(boxes_per_image, 0) + pred_boxes_list = pred_boxes.split(boxes_per_image, 0) + pred_scores_list = pred_scores.split(boxes_per_image, 0) all_boxes = [] all_scores = [] From ce5d781a2137547a5683430484555e2ace869298 Mon Sep 17 00:00:00 2001 From: neginraoof Date: Fri, 27 Mar 2020 02:57:31 -0700 Subject: [PATCH 32/44] Try disable tests --- test/test_onnx.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_onnx.py b/test/test_onnx.py index 77ee315fe93..928bbd1a3f0 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -100,6 +100,7 @@ def forward(self, boxes, scores): self.run_model(Module(), [(boxes, scores)]) + @unittest.skip def test_clip_boxes_to_image(self): boxes = torch.randint(10, (5, 4)) boxes[:, 2:] += torch.randint(500, (5, 2)) @@ -267,6 +268,7 @@ def forward(self, input, boxes): self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)]) + @unittest.skip def test_roi_heads(self): class RoiHeadsModule(torch.nn.Module): def __init__(self_module): From e826afef313b1b707d6dc6fa24ef03cb725cd39b Mon Sep 17 00:00:00 2001 From: neginraoof Date: Fri, 27 Mar 2020 03:14:20 -0700 Subject: [PATCH 33/44] pytest verbose --- .travis.yml | 2 +- test/test_onnx.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 10101f603e6..b9dde83362b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -67,7 +67,7 @@ install: cd - script: - - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -k 'not TestVideoReader and not TestVideoTransforms' test + - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -k 'not TestVideoReader and not TestVideoTransforms' test - pytest test/test_hub.py after_success: diff --git a/test/test_onnx.py b/test/test_onnx.py index 928bbd1a3f0..77ee315fe93 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -100,7 +100,6 @@ def forward(self, boxes, scores): self.run_model(Module(), [(boxes, scores)]) - @unittest.skip def test_clip_boxes_to_image(self): boxes = torch.randint(10, (5, 4)) boxes[:, 2:] += torch.randint(500, (5, 2)) @@ -268,7 +267,6 @@ def forward(self, input, boxes): 
From ce5d781a2137547a5683430484555e2ace869298 Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Fri, 27 Mar 2020 02:57:31 -0700
Subject: [PATCH 32/44] Try disable tests

---
 test/test_onnx.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/test_onnx.py b/test/test_onnx.py
index 77ee315fe93..928bbd1a3f0 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -100,6 +100,7 @@ def forward(self, boxes, scores):
 
         self.run_model(Module(), [(boxes, scores)])
 
+    @unittest.skip
     def test_clip_boxes_to_image(self):
         boxes = torch.randint(10, (5, 4))
         boxes[:, 2:] += torch.randint(500, (5, 2))
@@ -267,6 +268,7 @@ def forward(self, input, boxes):
 
         self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
 
+    @unittest.skip
     def test_roi_heads(self):
         class RoiHeadsModule(torch.nn.Module):
             def __init__(self_module):

From e826afef313b1b707d6dc6fa24ef03cb725cd39b Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Fri, 27 Mar 2020 03:14:20 -0700
Subject: [PATCH 33/44] pytest verbose

---
 .travis.yml       | 2 +-
 test/test_onnx.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 10101f603e6..b9dde83362b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -67,7 +67,7 @@ install:
       cd -
 
 script:
-  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -k 'not TestVideoReader and not TestVideoTransforms' test
+  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -k 'not TestVideoReader and not TestVideoTransforms' test
   - pytest test/test_hub.py
 
 after_success:
diff --git a/test/test_onnx.py b/test/test_onnx.py
index 928bbd1a3f0..77ee315fe93 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -100,7 +100,6 @@ def forward(self, boxes, scores):
 
         self.run_model(Module(), [(boxes, scores)])
 
-    @unittest.skip
     def test_clip_boxes_to_image(self):
         boxes = torch.randint(10, (5, 4))
         boxes[:, 2:] += torch.randint(500, (5, 2))
@@ -268,7 +267,6 @@ def forward(self, input, boxes):
 
         self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
 
-    @unittest.skip
     def test_roi_heads(self):
         class RoiHeadsModule(torch.nn.Module):
             def __init__(self_module):

From 8f9d8d511df92e076babeee6381abc678a34a08b Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Fri, 27 Mar 2020 03:15:20 -0700
Subject: [PATCH 34/44] revert one test

---
 test/test_onnx.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/test/test_onnx.py b/test/test_onnx.py
index 77ee315fe93..1f4f3b25ae9 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -267,6 +267,7 @@ def forward(self, input, boxes):
 
         self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
 
+    @unittest.skip
     def test_roi_heads(self):
         class RoiHeadsModule(torch.nn.Module):
             def __init__(self_module):

From faa16aa616fc7fa002ff591ed2d3145a84256a1c Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Fri, 27 Mar 2020 03:40:31 -0700
Subject: [PATCH 35/44] enable tests

---
 test/test_onnx.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/test_onnx.py b/test/test_onnx.py
index 1f4f3b25ae9..b7e78ccb855 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -267,7 +267,6 @@ def forward(self, input, boxes):
 
         self.run_model(TransformModule(), [(i, [boxes],), (i1, [boxes1],)])
 
-    @unittest.skip
     def test_roi_heads(self):
         class RoiHeadsModule(torch.nn.Module):
             def __init__(self_module):
@@ -288,7 +287,7 @@ def forward(self_module, images, features):
 
         images = torch.rand(2, 3, 100, 100)
         features = self.get_features(images)
-        images2 = torch.rand(2, 3, 200, 200)
+        images2 = torch.rand(2, 3, 150, 150)
         test_features = self.get_features(images2)
 
         model = RoiHeadsModule()

From d8ec5c5befa936ff4ecbade949139b15e09f738e Mon Sep 17 00:00:00 2001
From: Negin Raoof
Date: Fri, 27 Mar 2020 09:48:21 -0700
Subject: [PATCH 36/44] Update .travis.yml

---
 .travis.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.travis.yml b/.travis.yml
index b9dde83362b..c5d90698860 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -67,7 +67,7 @@ install:
       cd -
 
 script:
-  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -k 'not TestVideoReader and not TestVideoTransforms' test
+  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -ra -k 'not TestVideoReader and not TestVideoTransforms' test
   - pytest test/test_hub.py
 
 after_success:

From 85daf174548df683e11b68a8520acbd26ff0d6dc Mon Sep 17 00:00:00 2001
From: Negin Raoof
Date: Fri, 27 Mar 2020 10:26:34 -0700
Subject: [PATCH 37/44] Update .travis.yml

---
 .travis.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.travis.yml b/.travis.yml
index c5d90698860..4bf66c2e52a 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -67,7 +67,7 @@ install:
      cd -
 
 script:
-  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -ra -k 'not TestVideoReader and not TestVideoTransforms' test
+  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -rA -k 'not TestVideoReader and not TestVideoTransforms' test
   - pytest test/test_hub.py
 
 after_success:

From 52aaa67140c6d3d0f04f3216b76ba8f2457ebf47 Mon Sep 17 00:00:00 2001
From: Negin Raoof
Date: Fri, 27 Mar 2020 15:41:40 -0700
Subject: [PATCH 38/44] Update .travis.yml

---
 .travis.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.travis.yml b/.travis.yml
index 4bf66c2e52a..1d8beae5161 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -67,7 +67,7 @@ install:
      cd -
 
 script:
-  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -rA -k 'not TestVideoReader and not TestVideoTransforms' test
+  - pytest -rA --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -k 'not TestVideoReader and not TestVideoTransforms' test
   - pytest test/test_hub.py
 
 after_success:
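
The Travis churn in patches 32 through 38 is about visibility while the flaky tests were being isolated: pytest's -r flag selects which outcomes appear in the short test summary, with -ra reporting everything except passes and -rA including passed tests as well, so skips and failures show up plainly in the CI log. Stripped of the coverage flags, the local equivalent of the final script line is:

    pytest -rA -v -k 'not TestVideoReader and not TestVideoTransforms' test
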
From c30cce70aceabfd65a4555511077f00ec08cdb0e Mon Sep 17 00:00:00 2001
From: Negin Raoof
Date: Sat, 28 Mar 2020 22:59:45 -0700
Subject: [PATCH 39/44] Update test_onnx.py

---
 test/test_onnx.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/test_onnx.py b/test/test_onnx.py
index b7e78ccb855..dd5af85f37d 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -329,7 +329,7 @@ def get_test_images(self):
     def test_faster_rcnn(self):
         images, test_images = self.get_test_images()
 
-        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True)
+        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
         model.eval()
         model(images)
         self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"],
@@ -375,7 +375,7 @@ def test_paste_mask_in_image(self):
     def test_mask_rcnn(self):
         images, test_images = self.get_test_images()
 
-        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True)
+        model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
         model.eval()
         model(images)
         self.run_model(model, [(images,), (test_images,)],
@@ -417,7 +417,7 @@ def test_keypoint_rcnn(self):
         class KeyPointRCNN(torch.nn.Module):
             def __init__(self):
                 super(KeyPointRCNN, self).__init__()
-                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True)
+                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
 
             def forward(self, images):
                 output = self.model(images)

From 3032e8e0b1dc38005ca2d0d85529a203ce3a97b9 Mon Sep 17 00:00:00 2001
From: Negin Raoof
Date: Sat, 28 Mar 2020 23:38:41 -0700
Subject: [PATCH 40/44] Update .travis.yml

---
 .travis.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index 1d8beae5161..1b6ecb7a65b 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -46,7 +46,7 @@ before_install:
   - pip install typing
   - |
       if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
-        pip install -q --user onnxruntime
+        pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.0.0.dev1123
       fi
 
   - conda install av -c conda-forge
@@ -67,7 +67,7 @@ install:
      cd -
 
 script:
-  - pytest -rA --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -v -k 'not TestVideoReader and not TestVideoTransforms' test
+  - pytest --cov-config .coveragerc --cov torchvision --cov $TV_INSTALL_PATH -k 'not TestVideoReader and not TestVideoTransforms' test
   - pytest test/test_hub.py
 
 after_success:
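
Passing min_size=200, max_size=300 works because the detection constructors forward these keywords to GeneralizedRCNNTransform, which rescales every input so its shorter side reaches min_size while the longer side stays within max_size; smaller bounds mean the pretrained-model export tests push far fewer pixels through the backbone. (The switch to an ort-nightly wheel presumably picks up exporter support not yet in a released onnxruntime.) A quick sketch of the effect, with an illustrative input shape:

    import torch
    from torchvision import models

    model = models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True, min_size=200, max_size=300)
    model.eval()

    # A 3x600x800 image is rescaled to roughly 3x200x267 inside the model:
    # the shorter side is brought up to 200 and the longer side, at about
    # 267, stays under the 300 cap.
    with torch.no_grad():
        detections = model([torch.rand(3, 600, 800)])
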
From 04a380629d89722333e3f2ff36344dc6b0800e8c Mon Sep 17 00:00:00 2001
From: Negin Raoof
Date: Sun, 29 Mar 2020 01:23:39 -0700
Subject: [PATCH 41/44] Passing device

---
 torchvision/models/detection/rpn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py
index 02ebd6b8380..ddebcdc5461 100644
--- a/torchvision/models/detection/rpn.py
+++ b/torchvision/models/detection/rpn.py
@@ -161,9 +161,9 @@ def forward(self, image_list, feature_maps):
         # type: (ImageList, List[Tensor])
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
-        strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64),
-                    torch.tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
+                    torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
         anchors = torch.jit.annotate(List[List[torch.Tensor]], [])

From dbc2c9d0273e0dff36c7e2aef1d3aa33ec58df4f Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Mon, 30 Mar 2020 13:47:26 -0700
Subject: [PATCH 42/44] Fixes for test

---
 test/test_onnx.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/test/test_onnx.py b/test/test_onnx.py
index dd5af85f37d..002d6df7836 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -102,14 +102,18 @@ def forward(self, boxes, scores):
 
     def test_clip_boxes_to_image(self):
         boxes = torch.randint(10, (5, 4))
-        boxes[:, 2:] += torch.randint(500, (5, 2))
+        boxes[:, 2:] += boxes[:, :2] + torch.randint(500, (5, 2))
         size = torch.randn(200, 300)
+        size_2 = torch.randn(300, 400)
+
         class Module(torch.nn.Module):
             def forward(self, boxes, size):
                 return ops.boxes.clip_boxes_to_image(boxes, size.shape)
 
-        self.run_model(Module(), [(boxes, size)])
+        self.run_model(Module(), [(boxes, size), (boxes, size_2)],
+                       input_names=["boxes", "size"],
+                       dynamic_axes={"size": [0, 1]})
 
     def test_roi_align(self):
         x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
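
The rpn.py reorder makes device available before the stride constants are built: torch.tensor(...) without an explicit device always allocates on CPU, and under tracing that CPU constant is captured in the exported graph even when the feature maps live on CUDA. The companion test fix passes a second size tensor and marks both of its dimensions dynamic, so the clipping bound itself may vary between runs. A stripped-down sketch of the corrected stride construction (the argument values below are placeholders):

    import torch

    def make_strides(image_size, grid_sizes, device):
        # Build each stride tensor directly on the feature-map device so a
        # CUDA trace does not capture stray CPU constants.
        return [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
                 torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)]
                for g in grid_sizes]

    strides = make_strides((224, 224), [(56, 56), (28, 28)], torch.device("cpu"))
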
From c82321b55b06bde5e1ba31c96b90078a8ffcb165 Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Mon, 30 Mar 2020 15:58:29 -0700
Subject: [PATCH 43/44] Fix for boxes datatype

---
 test/test_onnx.py        | 5 +++--
 torchvision/ops/boxes.py | 8 ++++----
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/test/test_onnx.py b/test/test_onnx.py
index 002d6df7836..4103ee2f8f8 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -101,8 +101,9 @@ def forward(self, boxes, scores):
         self.run_model(Module(), [(boxes, scores)])
 
     def test_clip_boxes_to_image(self):
-        boxes = torch.randint(10, (5, 4))
-        boxes[:, 2:] += boxes[:, :2] + torch.randint(500, (5, 2))
+        boxes = torch.randn(5, 4) * 500
+        boxes[:, 2:] += boxes[:, :2]
+        print(boxes)
         size = torch.randn(200, 300)
         size_2 = torch.randn(300, 400)
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index dfd1e9d815c..93039a4be77 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -115,10 +115,10 @@ def clip_boxes_to_image(boxes, size):
     height, width = size
 
     if torchvision._is_tracing():
-        boxes_x = torch.max(boxes_x.to(torch.float32), torch.tensor(0., device=boxes.device))
-        boxes_x = torch.min(boxes_x.to(torch.float32), torch.tensor(width, dtype=torch.float32, device=boxes.device))
-        boxes_y = torch.max(boxes_y.to(torch.float32), torch.tensor(0., device=boxes.device))
-        boxes_y = torch.min(boxes_y.to(torch.float32), torch.tensor(height, dtype=torch.float32, device=boxes.device))
+        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
     else:
         boxes_x = boxes_x.clamp(min=0, max=width)
         boxes_y = boxes_y.clamp(min=0, max=height)

From 35142fa352fe494c5fa61a6fc564182a5131794f Mon Sep 17 00:00:00 2001
From: neginraoof
Date: Mon, 30 Mar 2020 15:58:55 -0700
Subject: [PATCH 44/44] clean up

---
 test/test_onnx.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/test_onnx.py b/test/test_onnx.py
index 4103ee2f8f8..bf3fed371d6 100644
--- a/test/test_onnx.py
+++ b/test/test_onnx.py
@@ -103,7 +103,6 @@ def forward(self, boxes, scores):
     def test_clip_boxes_to_image(self):
         boxes = torch.randn(5, 4) * 500
         boxes[:, 2:] += boxes[:, :2]
-        print(boxes)
         size = torch.randn(200, 300)
         size_2 = torch.randn(300, 400)
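
These last two patches make the traced branch of clip_boxes_to_image dtype-agnostic (the old code forced everything through float32, so integer or double-precision boxes came back silently cast) and drop a leftover debug print. A sketch of the traced branch after the change, assuming the x/y slicing and final stack that surround it in boxes.py:

    import torch

    def clip_boxes_to_image_traced(boxes, size):
        # The clamp bounds share the boxes' dtype and device, so
        # torch.min/torch.max export cleanly and no implicit float32
        # cast enters the graph.
        height, width = size
        boxes_x = boxes[..., 0::2]
        boxes_y = boxes[..., 1::2]
        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
        return torch.stack((boxes_x, boxes_y), dim=boxes.dim()).reshape(boxes.shape)

    boxes = torch.randn(5, 4, dtype=torch.float64) * 500
    boxes[:, 2:] += boxes[:, :2]
    clipped = clip_boxes_to_image_traced(boxes, (200, 300))
    assert clipped.dtype == torch.float64  # no float32 round trip
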