ONNX export for variable input sizes #1840

Merged Mar 31, 2020 · 61 commits
Changes from 58 commits

Commits
e48f3f7
fixes and tests for variable input size
neginraoof Jan 30, 2020
8df5790
transform test fix
neginraoof Jan 30, 2020
03bc884
Fix comment
neginraoof Jan 30, 2020
9b6142f
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Jan 30, 2020
e3d6040
Merge branch 'neraoof/variableInput' of https://github.com/neginraoof…
neginraoof Jan 30, 2020
37888e2
Dynamic shape for keypoint_rcnn
neginraoof Jan 30, 2020
84786b0
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Feb 4, 2020
ebcd45b
Update test_onnx.py
neginraoof Feb 4, 2020
89404f3
Update rpn.py
neginraoof Feb 5, 2020
b6f7859
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Feb 18, 2020
60ba5e7
Fix for split on RPN
neginraoof Feb 18, 2020
10a9fcb
Merge branch 'neraoof/variableInput' of github.com:neginraoof/vision …
neginraoof Feb 18, 2020
ea5cf6e
Fixes for feedbacks
neginraoof Feb 20, 2020
dced83a
flake8
neginraoof Feb 20, 2020
da44102
topk fix
neginraoof Feb 20, 2020
cd79435
Fix build
neginraoof Feb 20, 2020
fbe4680
branch on tracing
neginraoof Feb 20, 2020
2b4ad07
fix for scalar tensor
neginraoof Feb 20, 2020
be0ae7e
Fixes for script type annotations
neginraoof Feb 25, 2020
e6e3109
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Feb 25, 2020
829b58b
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Feb 25, 2020
7999e55
Update rpn.py
neginraoof Mar 3, 2020
94b1ac6
clean up
neginraoof Mar 3, 2020
050e756
clean up
neginraoof Mar 3, 2020
a445d4a
Update rpn.py
neginraoof Mar 3, 2020
b0c79bb
Updated for feedback
neginraoof Mar 3, 2020
89c2a80
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Mar 3, 2020
9d2b533
Merge branch 'master' into neraoof/variableInput
neginraoof Mar 13, 2020
931d5eb
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Mar 16, 2020
c90b6e0
Merge branch 'neraoof/variableInput' of github.com:neginraoof/vision …
neginraoof Mar 16, 2020
ba40ea1
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Mar 19, 2020
56b62d8
Merge branch 'master' of github.com:pytorch/vision into neraoof/varia…
fmassa Mar 20, 2020
228db38
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Mar 24, 2020
04ff430
Fix for comments
neginraoof Mar 24, 2020
d01a8ff
Merge branch 'neraoof/variableInput' of github.com:neginraoof/vision …
neginraoof Mar 24, 2020
23bff59
revert to use tensor
neginraoof Mar 24, 2020
b9ff797
Added test for box clip
neginraoof Mar 24, 2020
2cecb1c
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Mar 26, 2020
572e522
Fixes for feedback
neginraoof Mar 26, 2020
d0ff3f3
Fix for feedback
neginraoof Mar 26, 2020
a66381a
ORT version revert
neginraoof Mar 26, 2020
8d4b1ee
Update ort
neginraoof Mar 27, 2020
16cd2eb
Update .travis.yml
neginraoof Mar 27, 2020
c9fc98c
Update test_onnx.py
neginraoof Mar 27, 2020
f12e10c
Update test_onnx.py
neginraoof Mar 27, 2020
a3a9cd6
Tensor sizes
neginraoof Mar 27, 2020
088d514
Fix for dynamic split
neginraoof Mar 27, 2020
ce5d781
Try disable tests
neginraoof Mar 27, 2020
e826afe
pytest verbose
neginraoof Mar 27, 2020
8f9d8d5
revert one test
neginraoof Mar 27, 2020
faa16aa
enable tests
neginraoof Mar 27, 2020
d8ec5c5
Update .travis.yml
neginraoof Mar 27, 2020
85daf17
Update .travis.yml
neginraoof Mar 27, 2020
52aaa67
Update .travis.yml
neginraoof Mar 27, 2020
c30cce7
Update test_onnx.py
neginraoof Mar 29, 2020
3032e8e
Update .travis.yml
neginraoof Mar 29, 2020
04a3806
Passing device
neginraoof Mar 29, 2020
0c2b4b1
Merge branch 'master' of https://github.com/pytorch/vision into nerao…
neginraoof Mar 30, 2020
dbc2c9d
Fixes for test
neginraoof Mar 30, 2020
c82321b
Fix for boxes datatype
neginraoof Mar 30, 2020
35142fa
clean up
neginraoof Mar 30, 2020
120 changes: 77 additions & 43 deletions test/test_onnx.py
@@ -28,14 +28,15 @@ class ONNXExporterTester(unittest.TestCase):
def setUpClass(cls):
torch.manual_seed(123)

def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True):
def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
output_names=None, input_names=None):
model.eval()

onnx_io = io.BytesIO()
# export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version)

do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
with torch.no_grad():
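As background for the new keyword arguments (a hypothetical, standalone sketch, not code from this PR): this is roughly what run_model now forwards to torch.onnx.export when a test requests dynamic axes. The model, shapes, and opset below are illustrative assumptions.

import io
import torch

model = torch.nn.Conv2d(3, 8, kernel_size=3)   # stand-in model for illustration
example = torch.rand(1, 3, 100, 200)            # shape used only to trace the export

onnx_io = io.BytesIO()
torch.onnx.export(
    model, example, onnx_io,
    opset_version=11,                            # the test suite pins its own _onnx_opset_version
    input_names=["images_tensors"],
    output_names=["outputs"],
    # mark every axis of the named tensors as dynamic so inputs of other
    # sizes can be fed to the exported model at run time
    dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
)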
@@ -99,6 +100,17 @@ def forward(self, boxes, scores):

self.run_model(Module(), [(boxes, scores)])

def test_clip_boxes_to_image(self):
boxes = torch.randint(10, (5, 4))
boxes[:, 2:] += torch.randint(500, (5, 2))
size = torch.randn(200, 300)

class Module(torch.nn.Module):
def forward(self, boxes, size):
return ops.boxes.clip_boxes_to_image(boxes, size.shape)

self.run_model(Module(), [(boxes, size)])

def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
@@ -123,9 +135,9 @@ def __init__(self_module):
def forward(self_module, images):
return self_module.transform(images)[0].tensors

input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
self.run_model(TransformModule(), [input, input_test])
input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
self.run_model(TransformModule(), [(input,), (input_test,)])

def _init_test_generalized_rcnn_transform(self):
min_size = 100
@@ -207,22 +219,28 @@ def get_features(self, images):

def test_rpn(self):
class RPNModule(torch.nn.Module):
def __init__(self_module, images):
def __init__(self_module):
super(RPNModule, self_module).__init__()
self_module.rpn = self._init_test_rpn()
self_module.images = ImageList(images, [i.shape[-2:] for i in images])

def forward(self_module, features):
return self_module.rpn(self_module.images, features)
def forward(self_module, images, features):
images = ImageList(images, [i.shape[-2:] for i in images])
return self_module.rpn(images, features)

images = torch.rand(2, 3, 600, 600)
images = torch.rand(2, 3, 150, 150)
features = self.get_features(images)
test_features = self.get_features(images)
images2 = torch.rand(2, 3, 80, 80)
test_features = self.get_features(images2)

model = RPNModule(images)
model = RPNModule()
model.eval()
model(features)
self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
model(images, features)

self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3],
"input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3],
"input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

def test_multi_scale_roi_align(self):

@@ -251,63 +269,73 @@ def forward(self, input, boxes):

def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module):
def __init__(self_module, images):
def __init__(self_module):
super(RoiHeadsModule, self_module).__init__()
self_module.transform = self._init_test_generalized_rcnn_transform()
self_module.rpn = self._init_test_rpn()
self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
self_module.original_image_sizes = [img.shape[-2:] for img in images]
self_module.images = ImageList(images, [i.shape[-2:] for i in images])

def forward(self_module, features):
proposals, _ = self_module.rpn(self_module.images, features)
detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes)
def forward(self_module, images, features):
original_image_sizes = [img.shape[-2:] for img in images]
images = ImageList(images, [i.shape[-2:] for i in images])
proposals, _ = self_module.rpn(images, features)
detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
detections = self_module.transform.postprocess(detections,
self_module.images.image_sizes,
self_module.original_image_sizes)
images.image_sizes,
original_image_sizes)
return detections

images = torch.rand(2, 3, 600, 600)
images = torch.rand(2, 3, 100, 100)
features = self.get_features(images)
test_features = self.get_features(images)
images2 = torch.rand(2, 3, 150, 150)
test_features = self.get_features(images2)

model = RoiHeadsModule(images)
model = RoiHeadsModule()
model.eval()
model(features)
self.run_model(model, [(features,), (test_features,)])
model(images, features)

self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
"input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

def get_image_from_url(self, url):
def get_image_from_url(self, url, size=None):
import requests
import numpy
from PIL import Image
from io import BytesIO
from torchvision import transforms

data = requests.get(url)
image = Image.open(BytesIO(data.content)).convert("RGB")
image = image.resize((300, 200), Image.BILINEAR)

if size is None:
size = (300, 200)
image = image.resize(size, Image.BILINEAR)

to_tensor = transforms.ToTensor()
return to_tensor(image)

def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url)
image = self.get_image_from_url(url=image_url, size=(200, 300))

image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2)
image2 = self.get_image_from_url(url=image_url2, size=(250, 200))

images = [image]
test_images = [image2]
return images, test_images

def test_faster_rcnn(self):
images, test_images = self.get_test_images()

model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
min_size=200,
max_size=300)
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)])
self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
tolerate_small_mismatch=True)

# Verify that paste_mask_in_image behaves the same in tracing.
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
@@ -350,7 +378,11 @@ def test_mask_rcnn(self):
model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)])
self.run_model(model, [(images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
tolerate_small_mismatch=True)

# Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
@@ -385,9 +417,7 @@ def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module):
def __init__(self):
super(KeyPointRCNN, self).__init__()
self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True,
min_size=200,
max_size=300)
self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)

def forward(self, images):
output = self.model(images)
@@ -399,8 +429,12 @@ def forward(self, images):
images, test_images = self.get_test_images()
model = KeyPointRCNN()
model.eval()
model(test_images)
self.run_model(model, [(images,), (test_images,)])
model(images)
self.run_model(model, [(images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True)


if __name__ == '__main__':
13 changes: 3 additions & 10 deletions torchvision/models/detection/roi_heads.py
@@ -679,20 +679,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
device = class_logits.device
num_classes = class_logits.shape[-1]

boxes_per_image = [len(boxes_in_image) for boxes_in_image in proposals]
boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
pred_boxes = self.box_coder.decode(box_regression, proposals)

pred_scores = F.softmax(class_logits, -1)

# split boxes and scores per image
if len(boxes_per_image) == 1:
# TODO : remove this when ONNX support dynamic split sizes
# and just assign to pred_boxes instead of pred_boxes_list
pred_boxes_list = [pred_boxes]
pred_scores_list = [pred_scores]
else:
pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
pred_scores_list = pred_scores.split(boxes_per_image, 0)
pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
pred_scores_list = pred_scores.split(boxes_per_image, 0)

all_boxes = []
all_scores = []
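For context, a small sketch of the split that the simplified postprocess_detections path now always performs (hypothetical shapes; not code from the PR). Taking the per-image counts from shape[0] keeps them usable as split sizes when the model is traced for ONNX export.

import torch

proposals = [torch.rand(5, 4), torch.rand(3, 4)]        # two images with 5 and 3 proposals
pred_boxes = torch.rand(8, 4)                           # stacked predictions for both images
boxes_per_image = [p.shape[0] for p in proposals]       # [5, 3]
pred_boxes_list = pred_boxes.split(boxes_per_image, 0)  # one (N_i, 4) tensor per image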
17 changes: 7 additions & 10 deletions torchvision/models/detection/rpn.py
@@ -116,7 +116,7 @@ def num_anchors_per_location(self):
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[int]])
# type: (List[List[int]], List[List[Tensor]])
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
@@ -126,10 +126,6 @@ def grid_anchors(self, grid_sizes, strides):
):
grid_height, grid_width = size
stride_height, stride_width = stride
if torchvision._is_tracing():
# required in ONNX export for mult operation with float32
stride_width = torch.tensor(stride_width, dtype=torch.float32)
stride_height = torch.tensor(stride_height, dtype=torch.float32)
device = base_anchors.device

# For output anchor, compute [x_center, y_center, x_center, y_center]
@@ -153,8 +149,8 @@
return anchors

def cached_grid_anchors(self, grid_sizes, strides):
Member: Why should only strides be a tensor, and not grid_sizes? Is it because grid_sizes is returned by shape and in ONNX this becomes a tensor?

neginraoof (Contributor, Author), Mar 24, 2020: This is because of line 164, where I'm changing int() to torch.tensor(... dtype=int64) for strides. This was not required for grid_sizes. There's a follow-up item: find out why int() and float() result in constants in the graph.
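To illustrate the point about int() and float() (a hypothetical, standalone sketch, not code from the PR; the exact traced-graph behavior depends on the PyTorch version):

import torch

def strides_as_ints(image_size, grid_size):
    # int() forces the value to a plain Python number, so a traced/exported
    # graph records the stride computed from the example input as a constant
    return [int(image_size[0] / grid_size[0]),
            int(image_size[1] / grid_size[1])]

def strides_as_tensors(image_size, grid_size, device):
    # wrapping the same computation in torch.tensor keeps it attached to the
    # traced shape values, so the stride follows the actual input size
    return [torch.tensor(image_size[0] / grid_size[0], dtype=torch.int64, device=device),
            torch.tensor(image_size[1] / grid_size[1], dtype=torch.int64, device=device)]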

# type: (List[List[int]], List[List[int]])
key = str(grid_sizes + strides)
# type: (List[List[int]], List[List[Tensor]])
key = str(grid_sizes) + str(strides)
neginraoof (Contributor, Author): @lara-hdr Maybe the comment was deleted after the update. grid_sizes is a list of ints and strides is a list of tensors, so we cannot concatenate them and then cast the result to a string. But casting each to a string and concatenating the strings gives the same result.
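A tiny illustration of the cache-key construction being discussed (hypothetical values, not code from the PR):

import torch

grid_sizes = [[25, 25], [13, 13]]                   # List[List[int]]
strides = [[torch.tensor(4), torch.tensor(4)],
           [torch.tensor(8), torch.tensor(8)]]      # List[List[Tensor]]
# grid_sizes + strides would mix element types, which TorchScript's typed
# lists do not allow, so the key concatenates the string forms instead:
key = str(grid_sizes) + str(strides)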

if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
@@ -165,9 +161,9 @@ def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor])
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]

dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
@@ -482,7 +478,8 @@ def forward(self, images, features, targets=None):
anchors = self.anchor_generator(images, features)

num_images = len(anchors)
num_anchors_per_level = [o[0].numel() for o in objectness]
num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
neginraoof (Contributor, Author): This line will be replaced by:
# num_anchors_per_level = [torch.prod(shape_as_tensor(o[0])) for o in objectness]

Member: Isn't numel() dynamic on the input shape?

neginraoof (Contributor, Author): No, the output is an integer.

Member: Isn't this something that we might want to change on the ONNX side? shape in PyTorch currently returns a Size object, which is a tuple[int] under the hood, and using shape works in ONNX.

neginraoof (Contributor, Author): I don't think it can be fixed in the exporter. numel is different from shape, since it is traced as a constant in the IR graph. I tested this with a small repro, but I haven't looked deeper.
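A minimal sketch of the kind of repro described above (hypothetical, not from the PR). Per the discussion, numel() is expected to appear as a constant in the trace, while products of shape accesses stay tied to the input:

import torch

class UsesNumel(torch.nn.Module):
    def forward(self, x):
        # numel() returns a plain Python int during tracing, so the size
        # below is expected to be baked in as a constant
        return torch.zeros(x.numel())

class UsesShape(torch.nn.Module):
    def forward(self, x):
        # shape accesses are recorded in the trace, so the size is expected
        # to be recomputed for inputs of other shapes
        s = x.shape
        return torch.zeros(s[0] * s[1] * s[2])

print(torch.jit.trace(UsesNumel(), torch.rand(2, 3, 4)).graph)
print(torch.jit.trace(UsesShape(), torch.rand(2, 3, 4)).graph)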

objectness, pred_bbox_deltas = \
concat_box_prediction_layers(objectness, pred_bbox_deltas)
# apply pred_bbox_deltas to anchors to obtain the decoded proposals
9 changes: 6 additions & 3 deletions torchvision/models/detection/transform.py
@@ -90,7 +90,8 @@ def resize(self, image, target):
if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
image[None], scale_factor=scale_factor, mode='bilinear',
align_corners=False)[0]

if target is None:
return image, target
@@ -193,7 +194,8 @@ def __repr__(self):

def resize_keypoints(keypoints, original_size, new_size):
# type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratios = [torch.tensor(s, dtype=torch.float32, device=keypoints.device) / torch.tensor(s_orig, dtype=torch.float32, device=keypoints.device)
for s, s_orig in zip(new_size, original_size)]
ratio_h, ratio_w = ratios
resized_data = keypoints.clone()
if torch._C._get_tracing_state():
@@ -208,7 +210,8 @@ def resize_keypoints(keypoints, original_size, new_size):

def resize_boxes(boxes, original_size, new_size):
# type: (Tensor, List[int], List[int])
ratios = [float(s) / float(s_orig) for s, s_orig in zip(new_size, original_size)]
ratios = [torch.tensor(s, dtype=torch.float32, device=boxes.device) / torch.tensor(s_orig, dtype=torch.float32, device=boxes.device)
for s, s_orig in zip(new_size, original_size)]
ratio_height, ratio_width = ratios
xmin, ymin, xmax, ymax = boxes.unbind(1)

13 changes: 11 additions & 2 deletions torchvision/ops/boxes.py
@@ -3,6 +3,7 @@
import torch
from torch.jit.annotations import Tuple
from torch import Tensor
import torchvision


def nms(boxes, scores, iou_threshold):
@@ -112,8 +113,16 @@ def clip_boxes_to_image(boxes, size):
boxes_x = boxes[..., 0::2]
boxes_y = boxes[..., 1::2]
height, width = size
boxes_x = boxes_x.clamp(min=0, max=width)
boxes_y = boxes_y.clamp(min=0, max=height)

if torchvision._is_tracing():
boxes_x = torch.max(boxes_x.to(torch.float32), torch.tensor(0., device=boxes.device))
Member: Just wondering, do we need to cast the boxes to float32 here, or can we just use the boxes_x dtype while constructing the scalar tensor?

neginraoof (Contributor, Author), Mar 30, 2020: When using boxes.dtype, I'm actually seeing an error from ONNX Runtime:
[ONNXRuntimeError] : 10 : INVALID_GRAPH : This is an invalid model. Type Error: Type 'tensor(int64)' of input parameter (17) of operator (Max) in node (Max_21) is invalid.
Will need to look into the issue.

boxes_x = torch.min(boxes_x.to(torch.float32), torch.tensor(width, dtype=torch.float32, device=boxes.device))
boxes_y = torch.max(boxes_y.to(torch.float32), torch.tensor(0., device=boxes.device))
boxes_y = torch.min(boxes_y.to(torch.float32), torch.tensor(height, dtype=torch.float32, device=boxes.device))
else:
boxes_x = boxes_x.clamp(min=0, max=width)
Member: From your previous comment:
"I believe clamp is traced with constants for min/max:
Tensor clamp(const Tensor & self, c10::optional min, c10::optional max)"
Would this cause a failure with the branch if torchvision._is_tracing():? Can we make it trace with tensors as well if tensors are passed?

neginraoof (Contributor, Author), Mar 3, 2020: Here we have the same issue. The traced graph for clamp shows min/max as constants even when a tensor is passed:
%7 : int = prim::Constant[value=2]()
%8 : None = prim::Constant()
%9 : Float(3, 4) = aten::clamp(%0, %7, %8)
I haven't looked deeper. The PR to address this: pytorch/pytorch#32587. Maybe we can keep this branch until that PR is merged. I'm not sure if changing it to tensors would help much. Let me know what you think. Thanks.
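For illustration (a hypothetical standalone sketch, not code from the PR), the pattern this tracing branch relies on: clamp with Python scalars bakes the bounds into the trace, whereas torch.max/torch.min against tensor bounds keeps the comparison in the graph.

import torch

def clip_with_clamp(boxes_x, width):
    # width is a Python number here; clamp records min/max as constants in
    # the traced graph, so an exported model keeps the export-time width
    return boxes_x.clamp(min=0, max=width)

def clip_with_min_max(boxes_x, width):
    # width is a 0-dim tensor here (e.g. taken from a shape while tracing
    # for ONNX export), so the comparison stays dynamic in the graph
    zero = torch.tensor(0., dtype=torch.float32, device=boxes_x.device)
    return torch.min(torch.max(boxes_x.to(torch.float32), zero), width.to(torch.float32))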

boxes_y = boxes_y.clamp(min=0, max=height)

clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
return clipped_boxes.reshape(boxes.shape)
