Skip to content

Commit

Permalink
Run ONNX Node Tests on available targets (apache#8189)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Jun 17, 2021
1 parent b17dd06 commit 2279300
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4306,8 +4306,25 @@ def verify_eyelike(indata):
]


targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()]

target_skips = {
"cuda": [
"test_mod_mixed_sign_float16/",
"test_qlinearconv/",
"test_resize_upsample_sizes_nearest/",
]
}


@pytest.mark.parametrize("target", targets)
@pytest.mark.parametrize("test", onnx_test_folders)
def test_onnx_nodes(test):
def test_onnx_nodes(test, target):
if target in target_skips:
for failure in target_skips[target]:
if failure in test:
pytest.skip()
break
for failure in unsupported_onnx_tests:
if failure in test:
pytest.skip()
Expand All @@ -4333,12 +4350,14 @@ def test_onnx_nodes(test):
outputs.append(numpy_helper.to_array(new_tensor))
else:
raise ImportError(str(tensor) + " not labeled as an import or an output")
tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0))
if len(outputs) == 1:
tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)
else:
for output, val in zip(outputs, tvm_val):
tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol)

dev = tvm.device(target, 0)
tvm_val = get_tvm_output_with_vm(onnx_model, inputs, target, dev)
if len(outputs) == 1:
tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)
else:
for output, val in zip(outputs, tvm_val):
tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol)


def test_wrong_input():
Expand Down

0 comments on commit 2279300

Please sign in to comment.