From 9dc160ebb4eeb7b178b82fe3f0140e5fd17f82c8 Mon Sep 17 00:00:00 2001 From: mbrookhart Date: Thu, 21 Jan 2021 13:28:48 -0700 Subject: [PATCH] add a shape function and dynamic test for round --- python/tvm/relay/op/_tensor.py | 1 + tests/python/relay/test_any.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 6fc423371325..7728d6e3efa4 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -235,6 +235,7 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("cast", False, elemwise_shape_func) register_shape_func("cast_like", False, elemwise_shape_func) +register_shape_func("round", False, elemwise_shape_func) register_shape_func("zeros", False, no_data_full_shape_func) register_shape_func("zeros_like", False, elemwise_shape_func) register_shape_func("ones", False, no_data_full_shape_func) diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index a537782355d2..0b575d120e8f 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -120,6 +120,7 @@ def test_any_elemwise(): verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt) verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative) verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp) + verify_any_elemwise((relay.Any(),), (3,), relay.round, np.round) @tvm.testing.uses_gpu