ONNX export for variable input sizes (pytorch#1840)
* fixes and tests for variable input size

* transform test fix

* Fix comment

* Dynamic shape for keypoint_rcnn

* Update test_onnx.py

* Update rpn.py

* Fix for split on RPN

* Fixes for feedback

* flake8

* topk fix

* Fix build

* branch on tracing

* fix for scalar tensor

* Fixes for script type annotations

* Update rpn.py

* clean up

* clean up

* Update rpn.py

* Updated for feedback

* Fix for comments

* revert to use tensor

* Added test for box clip

* Fixes for feedback

* Fix for feedback

* ORT version revert

* Update ort

* Update .travis.yml

* Update test_onnx.py

* Update test_onnx.py

* Tensor sizes

* Fix for dynamic split

* Try disable tests

* pytest verbose

* revert one test

* enable tests

* Update .travis.yml

* Update .travis.yml

* Update .travis.yml

* Update test_onnx.py

* Update .travis.yml

* Passing device

* Fixes for test

* Fix for boxes datatype

* clean up

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
neginraoof and fmassa committed Jun 9, 2020
1 parent 2c7d6e1 commit 1ff8f90
Showing 5 changed files with 108 additions and 68 deletions.
124 changes: 81 additions & 43 deletions test/test_onnx.py
@@ -28,14 +28,15 @@ class ONNXExporterTester(unittest.TestCase):
def setUpClass(cls):
torch.manual_seed(123)

-    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True):
+    def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
+                  output_names=None, input_names=None):
model.eval()

onnx_io = io.BytesIO()
# export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io,
-                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version)
-
+                          do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
+                          dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
with torch.no_grad():
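
In outline, run_model now exports once with the first sample, declaring which axes are dynamic, then replays every sample through ONNX Runtime and compares against eager PyTorch. A minimal sketch of that pattern for models whose outputs are plain tensors, assuming onnxruntime is installed and opset 11 (the value of _onnx_opset_version here); export_and_check is a hypothetical helper, not part of the test suite:

    import io

    import onnxruntime
    import torch

    def export_and_check(model, inputs_list, dynamic_axes=None, input_names=None, output_names=None):
        # Export once with the first sample; dynamic_axes marks the dimensions
        # that may differ between export-time and run-time inputs.
        model.eval()
        onnx_io = io.BytesIO()
        torch.onnx.export(model, inputs_list[0], onnx_io, opset_version=11,
                          do_constant_folding=True, dynamic_axes=dynamic_axes,
                          input_names=input_names, output_names=output_names)
        # Re-run every sample through ONNX Runtime, comparing against eager PyTorch.
        session = onnxruntime.InferenceSession(onnx_io.getvalue())
        for sample in inputs_list:
            with torch.no_grad():
                expected = model(*sample)
            if isinstance(expected, torch.Tensor):
                expected = (expected,)
            ort_inputs = {i.name: t.numpy() for i, t in zip(session.get_inputs(), sample)}
            for torch_out, ort_out in zip(expected, session.run(None, ort_inputs)):
                torch.testing.assert_allclose(torch_out, torch.from_numpy(ort_out),
                                              rtol=1e-3, atol=1e-5)

A sample whose shape differs from the export-time input will only pass this check if the differing dimensions are listed in dynamic_axes, which is exactly what the tests below exercise.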
@@ -99,6 +100,21 @@ def forward(self, boxes, scores):

self.run_model(Module(), [(boxes, scores)])

+    def test_clip_boxes_to_image(self):
+        boxes = torch.randn(5, 4) * 500
+        boxes[:, 2:] += boxes[:, :2]
+        size = torch.randn(200, 300)
+
+        size_2 = torch.randn(300, 400)
+
+        class Module(torch.nn.Module):
+            def forward(self, boxes, size):
+                return ops.boxes.clip_boxes_to_image(boxes, size.shape)
+
+        self.run_model(Module(), [(boxes, size), (boxes, size_2)],
+                       input_names=["boxes", "size"],
+                       dynamic_axes={"size": [0, 1]})

def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
@@ -123,9 +139,9 @@ def __init__(self_module):
def forward(self_module, images):
return self_module.transform(images)[0].tensors

-        input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
-        input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
-        self.run_model(TransformModule(), [input, input_test])
+        input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
+        input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
+        self.run_model(TransformModule(), [(input,), (input_test,)])

def _init_test_generalized_rcnn_transform(self):
min_size = 100
@@ -207,22 +223,28 @@ def get_features(self, images):

def test_rpn(self):
class RPNModule(torch.nn.Module):
-            def __init__(self_module, images):
+            def __init__(self_module):
super(RPNModule, self_module).__init__()
self_module.rpn = self._init_test_rpn()
-                self_module.images = ImageList(images, [i.shape[-2:] for i in images])

-            def forward(self_module, features):
-                return self_module.rpn(self_module.images, features)
+            def forward(self_module, images, features):
+                images = ImageList(images, [i.shape[-2:] for i in images])
+                return self_module.rpn(images, features)

-        images = torch.rand(2, 3, 600, 600)
+        images = torch.rand(2, 3, 150, 150)
features = self.get_features(images)
-        test_features = self.get_features(images)
+        images2 = torch.rand(2, 3, 80, 80)
+        test_features = self.get_features(images2)

-        model = RPNModule(images)
+        model = RPNModule()
model.eval()
-        model(features)
-        self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
+        model(images, features)
+
+        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
+                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
+                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3],
+                                     "input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3],
+                                     "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

def test_multi_scale_roi_align(self):

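For orientation, the six dynamic inputs declared in the RPN test above are the images tensor plus the feature maps returned by get_features; with an FPN backbone there are five pyramid levels, so the flattened ONNX graph sees six input tensors, each marked dynamic in all four (N, C, H, W) dimensions. A hypothetical illustration; the level names and spatial sizes below are assumptions, not values from the test:

    from collections import OrderedDict

    import torch

    # Five FPN levels at strides 4, 8, 16, 32, 64 for a 2 x 3 x 150 x 150 batch
    # (illustrative sizes only).
    features = OrderedDict(
        (name, torch.rand(2, 256, size, size))
        for name, size in zip(["0", "1", "2", "3", "pool"], [38, 19, 10, 5, 3])
    )
    # One fully dynamic axes entry per flattened ONNX input (images + 5 levels).
    dynamic_axes = {"input%d" % i: [0, 1, 2, 3] for i in range(1, 7)}
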
@@ -251,63 +273,73 @@ def forward(self, input, boxes):

def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module):
-            def __init__(self_module, images):
+            def __init__(self_module):
super(RoiHeadsModule, self_module).__init__()
self_module.transform = self._init_test_generalized_rcnn_transform()
self_module.rpn = self._init_test_rpn()
self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
-                self_module.original_image_sizes = [img.shape[-2:] for img in images]
-                self_module.images = ImageList(images, [i.shape[-2:] for i in images])

-            def forward(self_module, features):
-                proposals, _ = self_module.rpn(self_module.images, features)
-                detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes)
+            def forward(self_module, images, features):
+                original_image_sizes = [img.shape[-2:] for img in images]
+                images = ImageList(images, [i.shape[-2:] for i in images])
+                proposals, _ = self_module.rpn(images, features)
+                detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
detections = self_module.transform.postprocess(detections,
-                                                               self_module.images.image_sizes,
-                                                               self_module.original_image_sizes)
+                                                               images.image_sizes,
+                                                               original_image_sizes)
return detections

-        images = torch.rand(2, 3, 600, 600)
+        images = torch.rand(2, 3, 100, 100)
features = self.get_features(images)
-        test_features = self.get_features(images)
+        images2 = torch.rand(2, 3, 150, 150)
+        test_features = self.get_features(images2)

-        model = RoiHeadsModule(images)
+        model = RoiHeadsModule()
model.eval()
-        model(features)
-        self.run_model(model, [(features,), (test_features,)])
+        model(images, features)
+
+        self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
+                       input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
+                       dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
+                                     "input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

-    def get_image_from_url(self, url):
+    def get_image_from_url(self, url, size=None):
import requests
import numpy
from PIL import Image
from io import BytesIO
from torchvision import transforms

data = requests.get(url)
image = Image.open(BytesIO(data.content)).convert("RGB")
-        image = image.resize((300, 200), Image.BILINEAR)
+
+        if size is None:
+            size = (300, 200)
+        image = image.resize(size, Image.BILINEAR)

to_tensor = transforms.ToTensor()
return to_tensor(image)

def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
-        image = self.get_image_from_url(url=image_url)
+        image = self.get_image_from_url(url=image_url, size=(200, 300))

image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
-        image2 = self.get_image_from_url(url=image_url2)
+        image2 = self.get_image_from_url(url=image_url2, size=(250, 200))

images = [image]
test_images = [image2]
return images, test_images

def test_faster_rcnn(self):
images, test_images = self.get_test_images()

-        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
-                                                                     min_size=200,
-                                                                     max_size=300)
+        model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval()
model(images)
-        self.run_model(model, [(images,), (test_images,)])
+        self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"],
+                       output_names=["outputs"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)

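Outside the test harness, the export this test performs looks roughly as follows; a sketch under the assumptions of opset 11, pretrained weights being downloadable, and an arbitrary output path:

    import torch
    from torchvision import models

    model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
    model.eval()
    dummy = [torch.rand(3, 200, 300)]  # one 3-channel image; sizes may vary at run time
    torch.onnx.export(model, (dummy,), "faster_rcnn.onnx", opset_version=11,
                      do_constant_folding=True,
                      input_names=["images_tensors"], output_names=["outputs"],
                      dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]})
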
# Verify that paste_mask_in_image behaves the same in tracing.
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
@@ -350,7 +382,11 @@ def test_mask_rcnn(self):
model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
model(images)
-        self.run_model(model, [(images,), (test_images,)])
+        self.run_model(model, [(images,), (test_images,)],
+                       input_names=["images_tensors"],
+                       output_names=["outputs"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)

# Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
@@ -384,9 +420,7 @@ def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module):
def __init__(self):
super(KeyPointRCNN, self).__init__()
-                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True,
-                                                                                      min_size=200,
-                                                                                      max_size=300)
+                self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)

def forward(self, images):
output = self.model(images)
@@ -398,8 +432,12 @@ def forward(self, images):
images, test_images = self.get_test_images()
model = KeyPointRCNN()
model.eval()
-        model(test_images)
-        self.run_model(model, [(images,), (test_images,)])
+        model(images)
+        self.run_model(model, [(images,), (test_images,)],
+                       input_names=["images_tensors"],
+                       output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
+                       dynamic_axes={"images_tensors": [0, 1, 2, 3]},
+                       tolerate_small_mismatch=True)


if __name__ == '__main__':
13 changes: 3 additions & 10 deletions torchvision/models/detection/roi_heads.py
@@ -679,20 +679,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
device = class_logits.device
num_classes = class_logits.shape[-1]

-        boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
+        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
pred_boxes = self.box_coder.decode(box_regression, proposals)

pred_scores = F.softmax(class_logits, -1)

# split boxes and scores per image
-        if len(boxes_per_image) == 1:
-            # TODO : remove this when ONNX support dynamic split sizes
-            # and just assign to pred_boxes instead of pred_boxes_list
-            pred_boxes_list = [pred_boxes]
-            pred_scores_list = [pred_scores]
-        else:
-            pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
-            pred_scores_list = pred_scores.split(boxes_per_image, 0)
+        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
+        pred_scores_list = pred_scores.split(boxes_per_image, 0)

all_boxes = []
all_scores = []
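The deleted branch existed only because ONNX previously could not express a split whose sizes are computed at run time (the TODO above); with dynamic split sizes supported, the per-image split can run unconditionally. A minimal sketch of the pattern, with illustrative shapes:

    import torch

    pred_boxes = torch.rand(300, 91, 4)                   # predictions for two images, concatenated
    proposals = [torch.rand(200, 4), torch.rand(100, 4)]  # per-image proposals
    boxes_per_image = [p.shape[0] for p in proposals]     # [200, 100], derived from shapes
    pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
    assert [b.shape[0] for b in pred_boxes_list] == boxes_per_image
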
17 changes: 7 additions & 10 deletions torchvision/models/detection/rpn.py
@@ -118,7 +118,7 @@ def num_anchors_per_location(self):
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
+        # type: (List[List[int]], List[List[Tensor]])
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
@@ -128,10 +128,6 @@ def grid_anchors(self, grid_sizes, strides):
):
grid_height, grid_width = size
stride_height, stride_width = stride
-            if torchvision._is_tracing():
-                # required in ONNX export for mult operation with float32
-                stride_width = torch.tensor(stride_width, dtype=torch.float32)
-                stride_height = torch.tensor(stride_height, dtype=torch.float32)
device = base_anchors.device

# For output anchor, compute [x_center, y_center, x_center, y_center]
@@ -155,8 +151,8 @@
return anchors

def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[int]])
-        key = str(grid_sizes + strides)
+        # type: (List[List[int]], List[List[Tensor]])
+        key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
@@ -167,9 +163,9 @@ def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor])
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
-        strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
-
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
+        strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
+                    torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
@@ -484,7 +480,8 @@ def forward(self, images, features, targets=None):
anchors = self.anchor_generator(images, features)

num_images = len(anchors)
-        num_anchors_per_level = [o[0].numel() for o in objectness]
+        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
+        num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
objectness, pred_bbox_deltas = \
concat_box_prediction_layers(objectness, pred_bbox_deltas)
# apply pred_bbox_deltas to anchors to obtain the decoded proposals
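The stride change follows the same principle as the rest of the PR: quantities derived from input sizes stay tensors, so the exported graph recomputes them instead of baking in the values seen at export time. A standalone sketch of the new computation, with illustrative sizes:

    import torch

    image_size = (600, 800)               # image_list.tensors.shape[-2:]
    grid_sizes = [(75, 100), (38, 50)]    # feature_map.shape[-2:] per level
    device = torch.device("cpu")
    strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
                torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)]
               for g in grid_sizes]
    # Each stride is a 0-dim int64 tensor, e.g. tensor(8) for the first level.
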
9 changes: 6 additions & 3 deletions torchvision/models/detection/transform.py
@@ -90,7 +90,8 @@ def resize(self, image, target):
if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size
image = torch.nn.functional.interpolate(
-            image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
+            image[None], scale_factor=scale_factor, mode='bilinear',
+            align_corners=False)[0]

if target is None:
return image, target
@@ -184,7 +185,8 @@ def postprocess(self, result, image_shapes, original_image_sizes):

def resize_keypoints(keypoints, original_size, new_size):
# type: (Tensor, List[int], List[int])
-    ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
+    ratios = [torch.tensor(s, dtype=torch.float32, device=keypoints.device) / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
+              for s, s_orig in zip(new_size, original_size)]
ratio_h, ratio_w = ratios
resized_data = keypoints.clone()
if torch._C._get_tracing_state():
@@ -199,7 +201,8 @@ def resize_keypoints(keypoints, original_size, new_size):

def resize_boxes(boxes, original_size, new_size):
# type: (Tensor, List[int], List[int])
-    ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
+    ratios = [torch.tensor(s, dtype=torch.float32, device=boxes.device) / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
+              for s, s_orig in zip(new_size, original_size)]
ratio_height, ratio_width = ratios
xmin, ymin, xmax, ymax = boxes.unbind(1)

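Same idea in resize_boxes and resize_keypoints: float(s) / float(s_orig) would be recorded as a constant under tracing, so the ratios are built as float32 tensors on the input's device. A minimal sketch with made-up sizes; box_resize_ratios is a hypothetical name for the inlined list comprehension:

    import torch

    def box_resize_ratios(boxes, original_size, new_size):
        # Tensor division keeps the scaling symbolic in a traced graph.
        return [torch.tensor(s, dtype=torch.float32, device=boxes.device) /
                torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
                for s, s_orig in zip(new_size, original_size)]

    boxes = torch.tensor([[10., 20., 50., 80.]])
    ratio_h, ratio_w = box_resize_ratios(boxes, original_size=(200, 300), new_size=(400, 600))
    # ratio_h == tensor(2.), ratio_w == tensor(2.)
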
13 changes: 11 additions & 2 deletions torchvision/ops/boxes.py
@@ -3,6 +3,7 @@
import torch
from torch.jit.annotations import Tuple
from torch import Tensor
+import torchvision


def nms(boxes, scores, iou_threshold):
@@ -112,8 +113,16 @@ def clip_boxes_to_image(boxes, size):
boxes_x = boxes[..., 0::2]
boxes_y = boxes[..., 1::2]
height, width = size
-    boxes_x = boxes_x.clamp(min=0, max=width)
-    boxes_y = boxes_y.clamp(min=0, max=height)
+
+    if torchvision._is_tracing():
+        boxes_x = torch.max(boxes_x, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.max(boxes_y, torch.tensor(0, dtype=boxes.dtype, device=boxes.device))
+        boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=boxes.dtype, device=boxes.device))
+    else:
+        boxes_x = boxes_x.clamp(min=0, max=width)
+        boxes_y = boxes_y.clamp(min=0, max=height)

clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
return clipped_boxes.reshape(boxes.shape)

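The two branches above are behaviorally equivalent; torch.max/torch.min is used under tracing because exporting clamp with bounds that are not compile-time constants was not supported at the time. A usage sketch:

    import torch
    from torchvision.ops.boxes import clip_boxes_to_image

    boxes = torch.tensor([[-10., -10., 350., 250.]])        # (x1, y1, x2, y2)
    clipped = clip_boxes_to_image(boxes, size=(200, 300))   # size is (height, width)
    # clipped == tensor([[0., 0., 300., 200.]])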
