
[ONNX] Fix export of images for KeypointRCNN #2272

Merged
3 commits merged into pytorch:master on Jun 4, 2020

Conversation

KsenijaS (Contributor)

Add support in KeypointRCNN for:
exporting on an image with detections -> running on an image with no detections
exporting on an image with no detections -> running on an image with detections
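
A minimal sketch of the two scenarios (model call, image contents, and file name are illustrative, not from this PR; assumes torchvision's ONNX-exportable detection models with opset 11):

    import torch
    import torchvision

    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=True)
    model.eval()

    # Illustrative inputs: one image that should yield keypoint detections,
    # and one that likely yields none.
    image_with_person = [torch.rand(3, 300, 400)]
    blank_image = [torch.zeros(3, 300, 400)]

    # Trace/export on the "has detections" image ...
    torch.onnx.export(model, (image_with_person,), "keypoint_rcnn.onnx", opset_version=11)

    # ... the exported graph must then also handle the "no detections" case,
    # where the keypoint head sees empty [0, C, H, W] tensors (and vice versa
    # when exporting on the blank image instead).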

@KsenijaS (Contributor, author)

@neginraoof please review

@fmassa (Member) left a comment


I have a few questions, but overall looks good to me, thanks!

torchvision/models/detection/keypoint_rcnn.py (resolved review thread)
    - x = misc_nn_ops.interpolate(
    -     x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False
    - )
    + return torch.nn.functional.interpolate(
    +     x, scale_factor=float(self.up_scale), mode="bilinear",
    +     align_corners=False, recompute_scale_factor=False
    + )
@fmassa (Member)


Isn't up_scale an integer (equal to 2)? As such, is there a difference between recompute_scale_factor=True and False? Or is it because the output_size in this case gets registered as a constant?

@neginraoof (Contributor) commented Jun 1, 2020


With recompute_scale_factor, when the input has a 0-sized dimension, the inferred scale_factor will be 0. Although this is expected, the ONNX spec does not allow a Resize scale <= 0.

Just a note that I think the default value for recompute_scale_factor is changing to False soon.
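
For context, an eager-mode sketch of the empty-batch case (shapes are illustrative, assuming a PyTorch build that supports zero-sized batches here):

    import torch
    import torch.nn.functional as F

    # Empty-batch feature map, as produced when an image has no detections.
    x = torch.randn(0, 17, 28, 28)

    # Eager mode keeps the batch at 0 and still doubles the spatial dims;
    # with recompute_scale_factor=False the scale stays a constant (2.0)
    # instead of being re-derived from input/output sizes at export time.
    out = F.interpolate(x, scale_factor=2.0, mode="bilinear",
                        align_corners=False, recompute_scale_factor=False)
    print(out.shape)  # expected: torch.Size([0, 17, 56, 56])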

@fmassa (Member)


> With recompute_scale_factor, when the input has a 0-sized dimension, the inferred scale_factor will be 0.

This looks to me like a bug in inferring the scale_factor, as the batch dimension shouldn't affect the other sizes of the tensor (which are still valid).

Note that we use tensors of shape [0, C, H, W], so it should be fully possible to infer the correct scale factor, I believe.

@neginraoof (Contributor) commented Jun 3, 2020


I looked more into this. It looks like this is an error in inferring scale_factor in ONNX Runtime.
Upsample_bicubic2d exports to ONNX and runs in ONNX Runtime if you have an input like torch.randn(1, 2, 4, 6, requires_grad=True).
The ONNX graph also looks ok.
But the same model fails for an input like torch.randn(0, 2, 4, 6, requires_grad=True).
The reason is that ONNX Runtime infers the scale across all dimensions:

    for (size_t i = 0, end = input_dims.size(); i < end; ++i) {
      scales[i] = static_cast<float>(output_dims[i]) / static_cast<float>(input_dims[i]);
    }

For a 4D input, output_dims is the inferred output shape (e.g. [n, c, h*scale, w*scale]) and input_dims is the actual input shape (e.g. [n, c, h, w]).
This is a bug in edge cases where n = 0.
We will create an issue for ORT to address this.
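
A minimal repro sketch of that failure mode (module, file name, and opset are illustrative, not taken from this thread):

    import torch

    class Upsample(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.interpolate(
                x, scale_factor=2.0, mode="bicubic", align_corners=False)

    # Export traced on a regular batch; the resulting graph runs fine in
    # ONNX Runtime for inputs like torch.randn(1, 2, 4, 6).
    torch.onnx.export(Upsample(), torch.randn(1, 2, 4, 6), "upsample.onnx",
                      opset_version=11, input_names=["x"],
                      dynamic_axes={"x": {0: "batch"}})

    # Feeding the same graph a zero-sized batch, e.g. a (0, 2, 4, 6) array,
    # makes ONNX Runtime infer the Resize scale from output/input dims for the
    # batch axis as well, which yields an invalid (<= 0) scale and fails.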

But we still need Ksenija's update to switch to nn.functional.interpolate.
And here is the PR I sent out to switch the default recompute_scale_factor to False:
pytorch/pytorch#39453

@KsenijaS (Contributor, author) commented Jun 3, 2020

@fmassa can this PR be merged? Thanks!

@fmassa (Member) commented Jun 3, 2020

@KsenijaS I've made a comment at the top, not sure if you agree with it?

@fmassa (Member) left a comment


Sounds good, thanks for the explanation and the fix!

@fmassa merged commit de52437 into pytorch:master on Jun 4, 2020
@neginraoof (Contributor)

@fmassa Thanks a lot!
