ONNX export for variable input sizes #1840

Merged Mar 31, 2020 (61 commits).

Commits
e48f3f7 fixes and tests for variable input size (neginraoof, Jan 30, 2020)
8df5790 transform test fix (neginraoof, Jan 30, 2020)
03bc884 Fix comment (neginraoof, Jan 30, 2020)
9b6142f Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Jan 30, 2020)
e3d6040 Merge branch 'neraoof/variableInput' of https://github.com/neginraoof… (neginraoof, Jan 30, 2020)
37888e2 Dynamic shape for keypoint_rcnn (neginraoof, Jan 30, 2020)
84786b0 Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Feb 4, 2020)
ebcd45b Update test_onnx.py (neginraoof, Feb 4, 2020)
89404f3 Update rpn.py (neginraoof, Feb 5, 2020)
b6f7859 Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Feb 18, 2020)
60ba5e7 Fix for split on RPN (neginraoof, Feb 18, 2020)
10a9fcb Merge branch 'neraoof/variableInput' of github.com:neginraoof/vision … (neginraoof, Feb 18, 2020)
ea5cf6e Fixes for feedbacks (neginraoof, Feb 20, 2020)
dced83a flake8 (neginraoof, Feb 20, 2020)
da44102 topk fix (neginraoof, Feb 20, 2020)
cd79435 Fix build (neginraoof, Feb 20, 2020)
fbe4680 branch on tracing (neginraoof, Feb 20, 2020)
2b4ad07 fix for scalar tensor (neginraoof, Feb 20, 2020)
be0ae7e Fixes for script type annotations (neginraoof, Feb 25, 2020)
e6e3109 Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Feb 25, 2020)
829b58b Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Feb 25, 2020)
7999e55 Update rpn.py (neginraoof, Mar 3, 2020)
94b1ac6 clean up (neginraoof, Mar 3, 2020)
050e756 clean up (neginraoof, Mar 3, 2020)
a445d4a Update rpn.py (neginraoof, Mar 3, 2020)
b0c79bb Updated for feedback (neginraoof, Mar 3, 2020)
89c2a80 Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Mar 3, 2020)
9d2b533 Merge branch 'master' into neraoof/variableInput (neginraoof, Mar 13, 2020)
931d5eb Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Mar 16, 2020)
c90b6e0 Merge branch 'neraoof/variableInput' of github.com:neginraoof/vision … (neginraoof, Mar 16, 2020)
ba40ea1 Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Mar 19, 2020)
56b62d8 Merge branch 'master' of github.com:pytorch/vision into neraoof/varia… (fmassa, Mar 20, 2020)
228db38 Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Mar 24, 2020)
04ff430 Fix for comments (neginraoof, Mar 24, 2020)
d01a8ff Merge branch 'neraoof/variableInput' of github.com:neginraoof/vision … (neginraoof, Mar 24, 2020)
23bff59 revert to use tensor (neginraoof, Mar 24, 2020)
b9ff797 Added test for box clip (neginraoof, Mar 24, 2020)
2cecb1c Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Mar 26, 2020)
572e522 Fixes for feedback (neginraoof, Mar 26, 2020)
d0ff3f3 Fix for feedback (neginraoof, Mar 26, 2020)
a66381a ORT version revert (neginraoof, Mar 26, 2020)
8d4b1ee Update ort (neginraoof, Mar 27, 2020)
16cd2eb Update .travis.yml (neginraoof, Mar 27, 2020)
c9fc98c Update test_onnx.py (neginraoof, Mar 27, 2020)
f12e10c Update test_onnx.py (neginraoof, Mar 27, 2020)
a3a9cd6 Tensor sizes (neginraoof, Mar 27, 2020)
088d514 Fix for dynamic split (neginraoof, Mar 27, 2020)
ce5d781 Try disable tests (neginraoof, Mar 27, 2020)
e826afe pytest verbose (neginraoof, Mar 27, 2020)
8f9d8d5 revert one test (neginraoof, Mar 27, 2020)
faa16aa enable tests (neginraoof, Mar 27, 2020)
d8ec5c5 Update .travis.yml (neginraoof, Mar 27, 2020)
85daf17 Update .travis.yml (neginraoof, Mar 27, 2020)
52aaa67 Update .travis.yml (neginraoof, Mar 27, 2020)
c30cce7 Update test_onnx.py (neginraoof, Mar 29, 2020)
3032e8e Update .travis.yml (neginraoof, Mar 29, 2020)
04a3806 Passing device (neginraoof, Mar 29, 2020)
0c2b4b1 Merge branch 'master' of https://github.com/pytorch/vision into nerao… (neginraoof, Mar 30, 2020)
dbc2c9d Fixes for test (neginraoof, Mar 30, 2020)
c82321b Fix for boxes datatype (neginraoof, Mar 30, 2020)
35142fa clean up (neginraoof, Mar 30, 2020)
107 changes: 67 additions & 40 deletions test/test_onnx.py
@@ -28,14 +28,17 @@ class ONNXExporterTester(unittest.TestCase):
def setUpClass(cls):
torch.manual_seed(123)

def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True):
def run_model(self, model, inputs_list, tolerate_small_mismatch=False, do_constant_folding=True, dynamic_axes=None,
output_names=None, input_names=None):
model.eval()

onnx_io = io.BytesIO()
# onnx_io = '/home/neraoof/test/results/transform.onnx'
# export to onnx with the first input
torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version)

torch.onnx.export(model, inputs_list[0], onnx_io,
do_constant_folding=do_constant_folding, opset_version=_onnx_opset_version,
dynamic_axes=dynamic_axes, input_names=input_names, output_names=output_names)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
with torch.no_grad():
@@ -65,6 +68,7 @@ def to_numpy(tensor):
# compute onnxruntime output prediction
ort_inputs = dict((ort_session.get_inputs()[i].name, inpt) for i, inpt in enumerate(inputs))
ort_outs = ort_session.run(None, ort_inputs)

for i in range(0, len(outputs)):
try:
torch.testing.assert_allclose(outputs[i], ort_outs[i], rtol=1e-03, atol=1e-05)
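
For context, a minimal sketch of what the extended run_model signature enables (the toy model, opset, and names below are illustrative assumptions, not part of this test file): axes listed in dynamic_axes are exported as symbolic dimensions, so the same ONNX graph accepts inputs whose sizes differ from the ones seen during tracing.

    import io
    import torch

    model = torch.nn.Conv2d(3, 8, kernel_size=3)  # toy stand-in for the detection models
    model.eval()
    onnx_io = io.BytesIO()
    torch.onnx.export(model, torch.rand(1, 3, 100, 200), onnx_io,
                      opset_version=11,
                      input_names=["images_tensors"], output_names=["outputs"],
                      # mark every axis symbolic so other sizes are accepted at runtime
                      dynamic_axes={"images_tensors": [0, 1, 2, 3],
                                    "outputs": [0, 1, 2, 3]})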
@@ -123,9 +127,9 @@ def __init__(self_module):
def forward(self_module, images):
return self_module.transform(images)[0].tensors

input = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
input_test = [torch.rand(3, 100, 200), torch.rand(3, 200, 200)]
self.run_model(TransformModule(), [input, input_test])
input = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
input_test = torch.rand(3, 100, 200), torch.rand(3, 200, 200)
self.run_model(TransformModule(), [(input,), (input_test,)])

def _init_test_generalized_rcnn_transform(self):
min_size = 100
@@ -207,22 +211,28 @@ def get_features(self, images):

def test_rpn(self):
class RPNModule(torch.nn.Module):
def __init__(self_module, images):
def __init__(self_module):
super(RPNModule, self_module).__init__()
self_module.rpn = self._init_test_rpn()
self_module.images = ImageList(images, [i.shape[-2:] for i in images])

def forward(self_module, features):
return self_module.rpn(self_module.images, features)
def forward(self_module, images, features):
images = ImageList(images, [i.shape[-2:] for i in images])
return self_module.rpn(images, features)

images = torch.rand(2, 3, 600, 600)
features = self.get_features(images)
test_features = self.get_features(images)
images2 = torch.rand(2, 3, 1000, 1000)
Member: Maybe this is using too much memory and is making the tests segfault in CI?

test_features = self.get_features(images2)

model = RPNModule(images)
model = RPNModule()
model.eval()
model(features)
self.run_model(model, [(features,), (test_features,)], tolerate_small_mismatch=True)
model(images, features)

self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3],
"input3": [0, 1, 2, 3], "input4": [0, 1, 2, 3],
"input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

def test_multi_scale_roi_align(self):

@@ -251,32 +261,37 @@ def forward(self, input, boxes):

def test_roi_heads(self):
class RoiHeadsModule(torch.nn.Module):
def __init__(self_module, images):
def __init__(self_module):
super(RoiHeadsModule, self_module).__init__()
self_module.transform = self._init_test_generalized_rcnn_transform()
self_module.rpn = self._init_test_rpn()
self_module.roi_heads = self._init_test_roi_heads_faster_rcnn()
self_module.original_image_sizes = [img.shape[-2:] for img in images]
self_module.images = ImageList(images, [i.shape[-2:] for i in images])

def forward(self_module, features):
proposals, _ = self_module.rpn(self_module.images, features)
detections, _ = self_module.roi_heads(features, proposals, self_module.images.image_sizes)
def forward(self_module, images, features):
original_image_sizes = [img.shape[-2:] for img in images]
images = ImageList(images, [i.shape[-2:] for i in images])
proposals, _ = self_module.rpn(images, features)
detections, _ = self_module.roi_heads(features, proposals, images.image_sizes)
detections = self_module.transform.postprocess(detections,
self_module.images.image_sizes,
self_module.original_image_sizes)
images.image_sizes,
original_image_sizes)
return detections

images = torch.rand(2, 3, 600, 600)
features = self.get_features(images)
test_features = self.get_features(images)
images2 = torch.rand(2, 3, 1000, 1000)
test_features = self.get_features(images2)

model = RoiHeadsModule(images)
model = RoiHeadsModule()
model.eval()
model(features)
self.run_model(model, [(features,), (test_features,)])
model(images, features)

self.run_model(model, [(images, features), (images2, test_features)], tolerate_small_mismatch=True,
input_names=["input1", "input2", "input3", "input4", "input5", "input6"],
dynamic_axes={"input1": [0, 1, 2, 3], "input2": [0, 1, 2, 3], "input3": [0, 1, 2, 3],
"input4": [0, 1, 2, 3], "input5": [0, 1, 2, 3], "input6": [0, 1, 2, 3]})

def get_image_from_url(self, url):
def get_image_from_url(self, url, size=None):
import requests
import numpy
from PIL import Image
@@ -285,29 +300,35 @@ def get_image_from_url(self, url):

data = requests.get(url)
image = Image.open(BytesIO(data.content)).convert("RGB")
image = image.resize((300, 200), Image.BILINEAR)

if size is None:
size = (300, 200)
image = image.resize(size, Image.BILINEAR)

to_tensor = transforms.ToTensor()
return to_tensor(image)

def get_test_images(self):
image_url = "http://farm3.staticflickr.com/2469/3915380994_2e611b1779_z.jpg"
image = self.get_image_from_url(url=image_url)
image = self.get_image_from_url(url=image_url, size=(800, 1201))

image_url2 = "https://pytorch.org/tutorials/_static/img/tv_tutorial/tv_image05.png"
image2 = self.get_image_from_url(url=image_url2)
image2 = self.get_image_from_url(url=image_url2, size=(873, 800))

images = [image]
test_images = [image2]
return images, test_images

def test_faster_rcnn(self):
images, test_images = self.get_test_images()

model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True,
min_size=200,
max_size=300)
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)])
self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
tolerate_small_mismatch=True)
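
As a rough usage sketch of what this test exercises (the file name, single-image input, and output order are assumptions): once exported with dynamic_axes, the detector can be evaluated in ONNX Runtime on an image size never seen at export time.

    import numpy as np
    import onnxruntime

    ort_session = onnxruntime.InferenceSession("fasterrcnn.onnx")  # assumed file name
    image = np.random.rand(3, 640, 480).astype(np.float32)         # size unseen at export
    ort_inputs = {ort_session.get_inputs()[0].name: image}
    outputs = ort_session.run(None, ort_inputs)  # boxes, labels, scores for Faster R-CNN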

# Verify that paste_mask_in_image behaves the same in tracing.
# This test also compares both paste_masks_in_image and _onnx_paste_masks_in_image
@@ -346,10 +367,14 @@ def test_paste_mask_in_image(self):
def test_mask_rcnn(self):
images, test_images = self.get_test_images()

model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
model(images)
self.run_model(model, [(images,), (test_images,)])
self.run_model(model, [(images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs"],
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
tolerate_small_mismatch=True)

# Verify that heatmaps_to_keypoints behaves the same in tracing.
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
@@ -383,9 +408,7 @@ def test_keypoint_rcnn(self):
class KeyPointRCNN(torch.nn.Module):
def __init__(self):
super(KeyPointRCNN, self).__init__()
self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True,
min_size=200,
max_size=300)
self.model = models.detection.keypoint_rcnn.keypointrcnn_resnet50_fpn(pretrained=True)

def forward(self, images):
output = self.model(images)
@@ -398,7 +421,11 @@ def forward(self, images):
model = KeyPointRCNN()
model.eval()
model(test_images)
self.run_model(model, [(images,), (test_images,)])
self.run_model(model, [(images,), (test_images,)],
input_names=["images_tensors"],
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
tolerate_small_mismatch=True)


if __name__ == '__main__':
24 changes: 17 additions & 7 deletions torchvision/models/detection/rpn.py
@@ -17,8 +17,7 @@
@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
num_anchors = ob.shape[1].unsqueeze(0)
Member: So now we don't need to use operators.shape_as_tensor? Does ONNX now trace the .shape values from tensors and keep them as dynamic variables?

Contributor Author: Yes, now we are able to trace the shapes dynamically.

# TODO : remove cast to IntTensor/num_anchors.dtype when
# ONNX Runtime version is updated with ReduceMin int64 support
pre_nms_top_n = torch.min(torch.cat(
@@ -157,7 +156,8 @@ def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor])
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
strides = [[int(image_size[0] / g[0]), int(image_size[1] / g[1])] for g in grid_sizes]
strides = [[torch.tensor(image_size[0] / g[0], dtype=torch.int64),
torch.tensor(image_size[1] / g[1], dtype=torch.int64)] for g in grid_sizes]
Member: With this change, I wonder if lines 124-126 are still necessary?

            if torchvision._is_tracing():
                # required in ONNX export for mult operation with float32
                stride_width = torch.tensor(stride_width, dtype=torch.float32)
                stride_height = torch.tensor(stride_height, dtype=torch.float32)

Also, do we need to pass the device in the tensor constructor, so that it works for CUDA as well?

Contributor Author: Thanks a lot for pointing this out. The lines here are not needed anymore; I'll remove this part.

Member: Can you also pass the device to the tensor constructor?

dtype, device = feature_maps[0].dtype, feature_maps[0].device
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
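
A hedged sketch of the suggestion above (toy sizes; names are illustrative): computing each stride as an int64 tensor on the feature map's device keeps the value dynamic in the exported graph and correct under CUDA.

    import torch

    image_size = (800, 800)                   # image_list.tensors.shape[-2:]
    feature_map = torch.rand(1, 256, 50, 50)  # one feature-map level
    g = feature_map.shape[-2:]
    device = feature_map.device
    strides = [torch.tensor(image_size[0] / g[0], dtype=torch.int64, device=device),
               torch.tensor(image_size[1] / g[1], dtype=torch.int64, device=device)]
    # strides == [tensor(16), tensor(16)], on the same device as the features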
@@ -354,7 +354,17 @@ def _get_top_n_idx(self, objectness, num_anchors_per_level):
# type: (Tensor, List[int])
r = []
offset = 0
for ob in objectness.split(num_anchors_per_level, 1):
if torchvision._is_tracing():
Member: It now looks like this function is completely different (or almost) for ONNX and the other code-path. Also, would this become unnecessary when pytorch/pytorch#32493 gets merged?

Contributor Author: The PR is not addressing this issue; currently, split with dynamic sizes is not supported in tracing. I agree that the behavior is much different here. I can implement _get_top_n_idx_for_onnx. What do you think?

Contributor Author: I'm actually going to add support for split with dynamic input in the exporter. We won't need these extra code-paths anymore.

# Split's split_size is traced as constant in onnx exporting, use Gather
start_list = [torch.tensor(0)]
end_list = [num_anchors_per_level[0].clone()]
for cnt in num_anchors_per_level[1:]:
start_list.append(end_list[-1].clone())
end_list.append(end_list[-1] + cnt)
objectness_per_level = [objectness[:, s:e] for s, e in zip(start_list, end_list)]
else:
objectness_per_level = objectness.split(num_anchors_per_level, 1)
for ob in objectness_per_level:
if torchvision._is_tracing():
num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(ob, self.pre_nms_top_n())
else:
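
A small eager-mode sanity check (toy sizes assumed) that the Gather-style start/end slicing above selects the same columns as Tensor.split along dim 1:

    import torch

    objectness = torch.rand(2, 10)
    num_anchors_per_level = [torch.tensor(3), torch.tensor(3), torch.tensor(4)]

    start_list = [torch.tensor(0)]
    end_list = [num_anchors_per_level[0].clone()]
    for cnt in num_anchors_per_level[1:]:
        start_list.append(end_list[-1].clone())
        end_list.append(end_list[-1] + cnt)
    by_gather = [objectness[:, s:e] for s, e in zip(start_list, end_list)]

    by_split = objectness.split([3, 3, 4], dim=1)
    assert all(torch.equal(a, b) for a, b in zip(by_gather, by_split))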
@@ -372,14 +382,12 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
# do not backprop through objectness
objectness = objectness.detach()
objectness = objectness.reshape(num_images, -1)

levels = [
torch.full((n,), idx, dtype=torch.int64, device=device)
for idx, n in enumerate(num_anchors_per_level)
]
levels = torch.cat(levels, 0)
levels = levels.reshape(1, -1).expand_as(objectness)

# select top_n boxes independently per level before applying nms
top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

Expand Down Expand Up @@ -466,7 +474,9 @@ def forward(self, images, features, targets=None):
anchors = self.anchor_generator(images, features)

num_images = len(anchors)
num_anchors_per_level = [o[0].numel() for o in objectness]
num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors]
Contributor Author: This line will be replaced by:

# num_anchors_per_level = [torch.prod(shape_as_tensor(o[0])) for o in objectness]

Member: Isn't numel() dynamic on the input shape?

Contributor Author: No, the output is an integer.

Member: Isn't this something that we might want to change on the ONNX side? Because shape in PyTorch currently returns a Size object, which is a tuple[int] under the hood, and using shape works in ONNX.

Contributor Author: I don't think it could be fixed in the exporter. Numel is different than shape since it is traced as a constant in the IR graph. I tested this with a small repro, but I haven't looked deeper.
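
A hedged repro sketch of the distinction discussed in this thread (toy module; the constant-folding of numel() is as reported here for the PyTorch version this PR targets and may differ in later releases):

    import torch

    class Probe(torch.nn.Module):
        def forward(self, x):
            a = torch.zeros(x.numel())                 # per the thread: folded to a constant
            b = torch.zeros(x.shape[0] * x.shape[1])   # shape math is recorded via aten::size
            return a, b

    traced = torch.jit.trace(Probe(), torch.rand(2, 3))
    print(traced.graph)  # inspect which size inputs are prim::Constant vs computed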


objectness, pred_bbox_deltas = \
concat_box_prediction_layers(objectness, pred_bbox_deltas)
# apply pred_bbox_deltas to anchors to obtain the decoded proposals
3 changes: 2 additions & 1 deletion torchvision/models/detection/transform.py
@@ -90,7 +90,8 @@ def resize(self, image, target):
if max_size * scale_factor > self.max_size:
scale_factor = self.max_size / max_size
image = torch.nn.functional.interpolate(
image[None], scale_factor=scale_factor, mode='bilinear', align_corners=False)[0]
image[None], scale_factor=scale_factor, mode='bilinear',
align_corners=False, recompute_scale_factor=False)[0]

if target is None:
return image, target
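
For illustration, a minimal sketch (toy tensor) of the flag added above: with recompute_scale_factor=False, interpolate uses the given scale_factor directly instead of re-deriving it from the rounded output size, which keeps the scale dynamic when the graph is exported.

    import torch

    x = torch.rand(1, 3, 100, 151)
    y = torch.nn.functional.interpolate(x, scale_factor=0.5, mode='bilinear',
                                        align_corners=False,
                                        recompute_scale_factor=False)
    print(y.shape)  # torch.Size([1, 3, 50, 75])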
16 changes: 14 additions & 2 deletions torchvision/ops/boxes.py
@@ -3,6 +3,7 @@
import torch
from torch.jit.annotations import Tuple
from torch import Tensor
import torchvision


def nms(boxes, scores, iou_threshold):
@@ -112,8 +113,19 @@ def clip_boxes_to_image(boxes, size):
boxes_x = boxes[..., 0::2]
boxes_y = boxes[..., 1::2]
height, width = size
boxes_x = boxes_x.clamp(min=0, max=width)
boxes_y = boxes_y.clamp(min=0, max=height)

if torchvision._is_tracing():
height = height.to(torch.float32)
width = width.to(torch.float32)
Member: I believe this will break torchscript compatibility, because it would imagine that height and width are int, which do not have a .to method. Would it be possible to make clamp work with dynamic values? In PyTorch, .clamp already supports tensor values for min/max.

Contributor Author (neginraoof, Feb 4, 2020): I believe clamp is traced with constants for min/max:

Tensor clamp(const Tensor & self, c10::optional min, c10::optional max)

Would this cause a failure with the branch: if torchvision._is_tracing():


boxes_x = torch.max(boxes_x, torch.tensor(0.))
boxes_x = torch.min(boxes_x, width)
boxes_y = torch.max(boxes_y, torch.tensor(0.))
boxes_y = torch.min(boxes_y, height)
else:
boxes_x = boxes_x.clamp(min=0, max=width)
Member: From your previous comment ("I believe clamp is traced with constants for min/max; would this cause a failure with the branch if torchvision._is_tracing()"): can we make it trace with tensors as well if tensors are passed?

Contributor Author (neginraoof, Mar 3, 2020): Here we have the same issue. The traced graph for clamp shows min/max as constants even when a tensor is passed:

%7 : int = prim::Constant[value=2]()
%8 : None = prim::Constant()
%9 : Float(3, 4) = aten::clamp(%0, %7, %8)

I haven't looked deeper. The PR to address this: pytorch/pytorch#32587. Maybe we can have this branch until that PR is merged. I'm not sure if changing it to tensors would help much. Let me know what you think. Thanks.

boxes_y = boxes_y.clamp(min=0, max=height)

clipped_boxes = torch.stack((boxes_x, boxes_y), dim=dim)
return clipped_boxes.reshape(boxes.shape)
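
A quick eager sketch (toy values) of the tracing-friendly branch above: torch.max/torch.min with tensor operands clip exactly like clamp, but keep the bounds as graph values rather than baked-in scalars.

    import torch

    boxes_x = torch.tensor([[-5., 50., 300.]])
    width = torch.tensor(200.)

    clipped = torch.min(torch.max(boxes_x, torch.tensor(0.)), width)
    assert torch.equal(clipped, boxes_x.clamp(min=0, max=200))  # tensor([[0., 50., 200.]])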
