Skip to content

Commit

Permalink
updated mask rcnn test to verify outputs and also run cuda target
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 24, 2020
1 parent 25a8d00 commit 2c8efae
Showing 1 changed file with 27 additions and 36 deletions.
63 changes: 27 additions & 36 deletions tests/python/frontend/pytorch/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tvm

import tvm.testing
from tvm import relay
from tvm.runtime.vm import VirtualMachine
from tvm.contrib.download import download
Expand Down Expand Up @@ -94,46 +95,36 @@ def test_detection_models():
download(img_url, img)

input_shape = (1, 3, in_size, in_size)
target = "llvm"

input_name = "input0"
shape_list = [(input_name, input_shape)]
score_threshold = 0.9

scripted_model = generate_jit_model(1)
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

with tvm.transform.PassContext(opt_level=3, disabled_pass=["FoldScaleAxis"]):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.cpu()
vm = VirtualMachine(vm_exec, ctx)
data = process_image(img)
pt_res = scripted_model(data)
data = data.detach().numpy()
vm.set_input("main", **{input_name: data})
tvm_res = vm.run()

# Note: due to accumulated numerical error, we can't directly compare results
# with pytorch output. Some boxes might have a quite tiny difference in score
# and the order can become different. We just measure how many valid boxes
# there are for input image.
pt_scores = pt_res[1].detach().numpy().tolist()
tvm_scores = tvm_res[1].asnumpy().tolist()
num_pt_valid_scores = num_tvm_valid_scores = 0

for score in pt_scores:
if score >= score_threshold:
num_pt_valid_scores += 1
else:
break

for score in tvm_scores:
if score >= score_threshold:
num_tvm_valid_scores += 1
else:
break

assert num_pt_valid_scores == num_tvm_valid_scores, (
"Output mismatch: Under score threshold {}, Pytorch has {} valid "
"boxes while TVM has {}.".format(score_threshold, num_pt_valid_scores, num_tvm_valid_scores)
)
data_np = data.detach().numpy()

with torch.no_grad():
pt_res = scripted_model(data)

for target in ["llvm", "cuda"]:
with tvm.transform.PassContext(opt_level=3):
vm_exec = relay.vm.compile(mod, target=target, params=params)

ctx = tvm.context(target, 0)
vm = VirtualMachine(vm_exec, ctx)

vm.set_input("main", **{input_name: data_np})
tvm_res = vm.run()

# Bounding boxes
tvm.testing.assert_allclose(
pt_res[0].cpu().numpy(), tvm_res[0].asnumpy(), rtol=1e-5, atol=1e-5
)
# Scores
tvm.testing.assert_allclose(
pt_res[1].cpu().numpy(), tvm_res[1].asnumpy(), rtol=1e-5, atol=1e-5
)
# Class ids
np.testing.assert_equal(pt_res[2].cpu().numpy(), tvm_res[2].asnumpy())

0 comments on commit 2c8efae

Please sign in to comment.