
ONNX export for variable input sizes #1840

Merged
merged 61 commits into pytorch:master on Mar 31, 2020

Conversation

neginraoof (Contributor)

Fixes and tests for exporting an ONNX model that supports variable input sizes.

@@ -466,7 +474,9 @@ def forward(self, images, features, targets=None):
anchors = self.anchor_generator(images, features)

num_images = len(anchors)
num_anchors_per_level = [o[0].numel() for o in objectness]
num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
Contributor Author:
This line will be replaced by:
# num_anchors_per_level = [torch.prod(shape_as_tensor(o[0])) for o in objectness]

Member:

Isn't numel() dynamic on the input shape?

Contributor Author:

No, the output is an integer.

Member:

Isn't this something that we might want to change on the ONNX side? Because shape in PyTorch currently returns a Size object, which is a tuple[int] under the hood, and using shape works in ONNX.

Contributor Author:

I don't think it can be fixed in the exporter. numel() is different from shape since it is traced as a constant in the IR graph. I tested this with a small repro, but I haven't looked deeper.
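For illustration, a minimal sketch of that kind of repro (module and values are mine, not from the PR): numel() is evaluated to a Python int at trace time and recorded as a constant, while shape_as_tensor stays a graph op and follows the input.

```python
import torch
from torch.onnx import operators


class Repro(torch.nn.Module):
    def forward(self, x):
        # numel() returns a Python int, so the traced graph stores it as a constant
        n_const = torch.tensor(x.numel())
        # shape_as_tensor emits a Shape node, so the product follows the actual input
        n_dyn = torch.prod(operators.shape_as_tensor(x))
        return n_const, n_dyn


traced = torch.jit.trace(Repro(), torch.zeros(2, 3))
print(traced(torch.zeros(4, 5)))  # first value stays 6, second becomes 20
```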

Member
@fmassa fmassa left a comment

I did a first pass, thanks for the PR!

Tests are failing for torchscript, so this can't be merged as is.

Let me know if you need help to address the torchscript issues.

Comment on lines 159 to 160
strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64),
torch.tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes]
Member:

With this change, I wonder if lines 124-126 are still necessary?

            if torchvision._is_tracing():
                # required in ONNX export for mult operation with float32
                stride_width = torch.tensor(stride_width, dtype=torch.float32)
                stride_height = torch.tensor(stride_height, dtype=torch.float32)

Also, do we need to pass the device in the tensor constructor, so that it works for CUDA as well?

Contributor Author:

Thanks a lot for pointing this out. The lines here are not needed anymore. I'll remove this part.

Member:

can you also pass the device to the tensor constructor?
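As a sketch of what passing the dtype and device would look like (sizes and device below are made up; in the real module the device would come from the feature maps):

```python
import torch

device = torch.device("cpu")          # in the real module: the feature maps' device
image_size = (800, 1216)              # example padded image size
grid_sizes = [(100, 152), (50, 76)]   # example per-level feature-map sizes

strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
            torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)]
           for g in grid_sizes]
print(strides)  # [[tensor(8), tensor(8)], [tensor(16), tensor(16)]]
```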

@@ -354,7 +354,17 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
# type: (Tensor, List[int])
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
if torchvision._is_tracing():
Member:

It now looks like this function is completely different (or almost) for ONNX and the other code-path.

Also, would this become unnecessary when pytorch/pytorch#32493 gets merged?

Contributor Author:

That PR does not address this issue; currently, split with dynamic sizes is not supported in tracing.
I agree that the behavior is quite different here. I can implement _get_top_n_idx_for_onnx.
What do you think?

Contributor Author:

I'm actually going to add support for split with dynamic input in the exporter. We won't need these extra code-paths anymore.
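For context, the limitation under discussion is that split() takes its chunk sizes as Python ints, so sizes derived from tensor shapes are frozen into the trace; in eager mode the call in question looks like this (sizes are invented):

```python
import torch

objectness = torch.randn(2, 15)       # (num_images, total number of anchors)
num_anchors_per_level = [3, 5, 7]     # example per-level anchor counts
for ob in objectness.split(num_anchors_per_level, 1):
    print(ob.shape)                   # (2, 3), (2, 5), (2, 7)
```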

Comment on lines 118 to 119
height = height.to(torch.float32)
width = width.to(torch.float32)
Member:

I believe this will break torchscript compatibility, because it would infer that height and width are int, which does not have a .to method.

Would it be possible to make clamp work with dynamic values? In PyTorch, .clamp already supports tensor values for min / max.

Contributor Author
@neginraoof neginraoof Feb 4, 2020

I believe clamp is traced with constants for min/max:
Tensor clamp(const Tensor & self, c10::optional<Scalar> min, c10::optional<Scalar> max)

Would this cause a failure with the branch:
if torchvision._is_tracing():

@@ -17,8 +17,7 @@
@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
num_anchors = ob.shape[1].unsqueeze(0)
Member:

So now we don't need to use operators.shape_as_tensor? Does ONNX now trace the .shape values from tensors and keep them as dynamic variables?

Contributor Author:

Yes, now we are able to trace the shapes dynamically.

@neginraoof (Contributor Author):

@dxxz Thanks for the comment.
So far, you should not see any mismatches when running the model with the same input size; please let us know if you do.
However, our work for running the model with dynamic input sizes is still in progress and is pending a couple of PRs:
pytorch/pytorch#32587
pytorch/pytorch#33161
I'll update this PR after these changes go in.

@neginraoof (Contributor Author):

@dxxz
Thanks. I just updated this part of the code. Please try it out and see; I was able to successfully run this with a new input size.

@neginraoof (Contributor Author):

@dxxz Yes, I'm aware of this issue as well. We cannot fix this from the ONNX export side.
Another idea here would be to provide min_size and max_size as inputs to the model, so that we can calculate scale_factor dynamically in this part of the code:

@neginraoof (Contributor Author):

@lara-hdr
This PR is ready for review. Thanks!
cc @fmassa
I've prepared this PR for review. However, this PR does not fix the issue with export for dynamic input sizes.
The problem we are currently facing in the transform (resize) module is related to:
pytorch/pytorch#33443

In this part of the code:

image = torch.nn.functional.interpolate(

The scale_factor cannot be traced dynamically since it is calculated from the model state parameters (min_size and max_size). Setting these parameters sets the scale_factor as a constant in the traced model. However, for variable input sizes, the scale_factor in the interpolate op may vary.
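A rough, self-contained illustration of the problem (the helper approximates the resize logic in GeneralizedRCNNTransform and is not the exact code): the scale factor comes out as a plain Python float computed from the concrete image shape and min_size/max_size, so tracing freezes it into the interpolate call.

```python
import torch
import torch.nn.functional as F


def compute_scale_factor(image, min_size=800, max_size=1333):
    # scale the short side to min_size, but keep the long side within max_size
    h, w = image.shape[-2:]
    scale = min_size / float(min(h, w))
    if scale * max(h, w) > max_size:
        scale = max_size / float(max(h, w))
    return scale


image = torch.zeros(3, 400, 600)
scale = compute_scale_factor(image)   # a Python float, hence a constant in the trace
resized = F.interpolate(image[None], scale_factor=scale,
                        mode="bilinear", align_corners=False)[0]
print(scale, resized.shape)           # 2.0 torch.Size([3, 800, 1200])
```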

@codecov-io

codecov-io commented Feb 25, 2020

Codecov Report

Merging #1840 into master will decrease coverage by <.01%.
The diff coverage is 0%.


@@            Coverage Diff            @@
##           master   #1840      +/-   ##
=========================================
- Coverage    0.48%   0.48%   -0.01%     
=========================================
  Files          92      92              
  Lines        7449    7450       +1     
  Branches     1135    1135              
=========================================
  Hits           36      36              
- Misses       7400    7401       +1     
  Partials       13      13
Impacted Files Coverage Δ
torchvision/ops/boxes.py 0% <0%> (ø) ⬆️
torchvision/models/detection/roi_heads.py 0% <0%> (ø) ⬆️
torchvision/models/detection/transform.py 0% <0%> (ø) ⬆️
torchvision/models/detection/rpn.py 0% <0%> (ø) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Contributor
@lara-hdr lara-hdr left a comment

Thanks for the PR Negin!


if torchvision._is_tracing():
boxes_x = torch.max(boxes_x, torch.tensor(0.))
boxes_x = torch.min(boxes_x, torch.tensor(width, dtype=torch.float32))
Contributor:

would it be better to use scalar_tensor() here as well?

# type: (List[List[int]], List[List[int]])
key = str(grid_sizes + strides)
# type: (List[List[int]], List[List[Tensor]])
key = str(str(grid_sizes) + str(strides))
Contributor:

Just curious, what is this change for?

Contributor Author:

strides has been changed to a list of tensors, while grid_sizes is a list of ints, so we cannot concatenate the two and then cast to a string.

Member
@fmassa fmassa left a comment

Thanks for the PR @neginraoof !

I have a few more questions. Also, can you explain which cases this PR can fix in its current state?

@@ -157,7 +153,8 @@ def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor])
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
strides = [[torch.scalar_tensor(image_size[0] / g[0], dtype=torch.int64),
Member:

Is scalar_tensor different from tensor with a scalar value?

Contributor Author:

The only difference is that it only holds scalars. We want to be consistent in using scalar tensors for all dynamically traced scalars.

Member:

Given that scalar_tensor is not part of the public API of PyTorch (and it's not documented), would it give different results / be a problem if we use tensor instead?


@@ -193,7 +195,7 @@ def __repr__(self):

def resize_keypoints(keypoints, original_size, new_size):
# type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratios = [torch.scalar_tensor(s) / torch.scalar_tensor(s_orig) for s, s_orig in zip(new_size, original_size)]
Member:

Can you clarify the difference between scalar_tensor and tensor with a scalar value? IIRC scalar_tensor is an internal API and not documented.

Contributor Author:

I don't see any major difference between the two, and I don't see any issues on the export side when using tensor. I'm trying to keep these changes consistent within the code, and it looks like scalar_tensor was mainly used for dynamic export of scalars.

Member:

if there are no major differences, can we instead use tensor, given that scalar_tensor is an internal API?
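For what it's worth, the two constructors give equivalent 0-dim tensors for a plain scalar (the value below is arbitrary), which is presumably why the later commits switch back to torch.tensor without a change in behavior:

```python
import torch

a = torch.scalar_tensor(5.0, dtype=torch.float32)
b = torch.tensor(5.0, dtype=torch.float32)
print(a.shape, b.shape, torch.equal(a, b))  # torch.Size([]) torch.Size([]) True
```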

Member:

Same as here; maybe passing the float32 dtype would be better.

boxes_y = torch.max(boxes_y, torch.tensor(0.))
boxes_y = torch.min(boxes_y, torch.tensor(height, dtype=torch.float32))
else:
boxes_x = boxes_x.clamp(min=0, max=width)
Member:

From your previous comment,

I believe clamp is traced with constants for min/max:
Tensor clamp(const Tensor & self, c10::optional<Scalar> min, c10::optional<Scalar> max)

Would this cause a failure with branch:
if torchvision._is_tracing():

Can we make it trace with tensors as well if tensors are passed?

Contributor Author
@neginraoof neginraoof Mar 3, 2020

Here we have the same issue. The traced graph for clamp shows min/max as constants even when a tensor is passed:
%7 : int = prim::Constant[value=2]()
%8 : None = prim::Constant()
%9 : Float(3, 4) = aten::clamp(%0, %7, %8)
I haven't looked deeper.
The PR to address this: pytorch/pytorch#32587
Maybe we can keep this branch until that PR is merged.
I'm not sure if changing it to tensors would help much. Let me know what you think. Thanks.
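To make the two code paths concrete, here is a minimal eager-mode sketch (values are illustrative) of the approach taken in this PR: on the tracing path, clamp is replaced by element-wise max/min against tensors so the bounds stay dynamic in the exported graph.

```python
import torch

boxes_x = torch.tensor([-5.0, 10.0, 900.0])
width = 800

# tracing/ONNX path: bounds passed as tensors
clipped_trace = torch.min(torch.max(boxes_x, torch.tensor(0.0)),
                          torch.tensor(width, dtype=torch.float32))

# eager/TorchScript path: clamp with scalar bounds
clipped_eager = boxes_x.clamp(min=0, max=width)

print(torch.equal(clipped_trace, clipped_eager))  # True
```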

@@ -193,7 +194,7 @@ def __repr__(self):

def resize_keypoints(keypoints, original_size, new_size):
# type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratios = [torch.tensor(s) / torch.tensor(s_orig) for s, s_orig in zip(new_size, original_size)]
Member
@fmassa fmassa Mar 26, 2020

can we pass the dtype=torch.float32 here and the device, just to make sure that we don't perform integer division?

@@ -208,7 +209,7 @@ def resize_keypoints(keypoints, original_size, new_size):

def resize_boxes(boxes, original_size, new_size):
# type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratios = [torch.tensor(s, dtype=torch.float32) / torch.tensor(s_orig, dtype=torch.float32) for s, s_orig in zip(new_size, original_size)]
Member:

Can we have the device passed here as well?
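A sketch of the ratio computation with the suggested explicit dtype and device (the helper name is mine, not from the PR):

```python
import torch


def resize_ratios(original_size, new_size, device):
    # per-dimension resize ratios, computed in float32 on the target device
    return [torch.tensor(s, dtype=torch.float32, device=device) /
            torch.tensor(s_orig, dtype=torch.float32, device=device)
            for s, s_orig in zip(new_size, original_size)]


print(resize_ratios([400, 600], [800, 1200], torch.device("cpu")))  # [tensor(2.), tensor(2.)]
```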

Member
@fmassa fmassa left a comment

Thanks a lot for all your work @neginraoof !

The only thing I would ask for is an extra test case for test_clip_boxes_to_image, to validate that it works for any input size after the model has been exported.

Apart from that, I have a comment on casting the boxes to float32, as I think it could be improved, but otherwise this looks good to merge, thanks a lot!
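For reference, the behavior such a test would pin down, shown here in eager mode with made-up boxes and sizes: clipping has to follow whichever image size is passed in, not the size seen at export time.

```python
import torch
from torchvision.ops.boxes import clip_boxes_to_image

boxes = torch.tensor([[-10.0, -10.0, 500.0, 500.0]])
print(clip_boxes_to_image(boxes, (300, 400)))  # clipped to height 300, width 400
print(clip_boxes_to_image(boxes, (600, 700)))  # same boxes; only the negative corners get clipped
```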

boxes_y = boxes_y.clamp(min=0, max=height)

if torchvision._is_tracing():
boxes_x = torch.max(boxes_x.to(torch.float32), torch.tensor(0., device=boxes.device))
Member:

just wondering, do we need to cast the boxes to float32 here, or can we just use the boxes_x dtype while constructing the scalar tensor?

Contributor Author
@neginraoof neginraoof Mar 30, 2020

When using boxes.dtype, I'm actually seeing an error from ONNX Runtime:
[ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (17) of operator (Max) in node (Max_21) is invalid.

Will need to look into the issue.
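The cast in the diff above is the workaround: per the error, ONNX Runtime rejected int64 inputs to Max in this graph, so the boxes are converted to float32 before the element-wise bounds are applied; roughly (values are illustrative):

```python
import torch

boxes_x = torch.tensor([-3, 10, 900], dtype=torch.int64)    # int64 boxes would trip Max in ORT
boxes_x = torch.max(boxes_x.to(torch.float32), torch.tensor(0.0))
boxes_x = torch.min(boxes_x, torch.tensor(800.0))            # 800. standing in for the image width
print(boxes_x)  # tensor([  0.,  10., 800.])
```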

@neginraoof (Contributor Author):

The test failure is in torchvision/csrc/cpu/DeformConv_cpu.cpp; it is unrelated to this PR.

Member
@fmassa fmassa left a comment

Thanks a lot @neginraoof

@fmassa fmassa merged commit 986d242 into pytorch:master Mar 31, 2020
fmassa added a commit to fmassa/vision-1 that referenced this pull request Jun 9, 2020
* fixes and tests for variable input size

* transform test fix

* Fix comment

* Dynamic shape for keypoint_rcnn

* Update test_onnx.py

* Update rpn.py

* Fix for split on RPN

* Fixes for feedbacks

* flake8

* topk fix

* Fix build

* branch on tracing

* fix for scalar tensor

* Fixes for script type annotations

* Update rpn.py

* clean up

* clean up

* Update rpn.py

* Updated for feedback

* Fix for comments

* revert to use tensor

* Added test for box clip

* Fixes for feedback

* Fix for feedback

* ORT version revert

* Update ort

* Update .travis.yml

* Update test_onnx.py

* Update test_onnx.py

* Tensor sizes

* Fix for dynamic split

* Try disable tests

* pytest verbose

* revert one test

* enable tests

* Update .travis.yml

* Update .travis.yml

* Update .travis.yml

* Update test_onnx.py

* Update .travis.yml

* Passing device

* Fixes for test

* Fix for boxes datatype

* clean up

Co-authored-by: Francisco Massa <fvsmassa@gmail.com>