Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated annotation to be a Union of Tensor and List #4416

Merged
merged 3 commits into from
Sep 15, 2021

Conversation

prabhat00155
Copy link
Contributor

Resolves #3737

import torch
from torchvision.ops import roi_align, roi_pool


def _make_rois2(img_size, num_imgs, dtype, num_rois=1000): 
    rois = torch.randint(0, img_size // 2, size=(num_rois, 4)).to(dtype) 
    rois[:, 0] = torch.randint(0, num_imgs, size=(num_rois,)) 
    rois[:, 2:] += rois[:, 1:2]  
    return [rois]


pool_size = 5
img_size = 10
n_channels = 2
num_imgs = 1
dtype = torch.float

x = torch.randint(50, 100, size=(num_imgs, n_channels, img_size, img_size)).to(dtype)
rois2 = _make_rois2(img_size, num_imgs, dtype)

f1 = torch.jit.script(roi_align)
f2 = torch.jit.script(roi_pool)

print(torch.mean((f1(x, rois2, pool_size) == roi_align(x, rois2, pool_size)).to(float)))
print(torch.mean((f2(x, rois2, pool_size) == roi_pool(x, rois2, pool_size)).to(float)))
tensor(1., dtype=torch.float64)
tensor(1., dtype=torch.float64)

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @prabhat00155. The current op tests will fail if you break JIT scriptability so feel free to merge once the CI turns green.

@prabhat00155 prabhat00155 merged commit b096271 into pytorch:main Sep 15, 2021
@prabhat00155 prabhat00155 deleted the prabhat00155/use_union branch September 15, 2021 16:43
facebook-github-bot pushed a commit that referenced this pull request Sep 30, 2021
Summary:
* Updated annotation to be a Union of Tensor and List

* Updated check_roi_boxes_shape.

Reviewed By: datumbox

Differential Revision: D31268033

fbshipit-source-id: 38f7c4c356862c8cd7785460a6a7a73647e9f519
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

roi_align and roi_pool cannot be scripted due to type annotation issue
3 participants