Skip to content

Commit

Permalink
add dynamic gather_nd test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 29, 2021
1 parent 3a9fe5d commit aaa6211
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm.relay.testing import run_infer_type as infer_type

from utils.assert_diagnostic import DiagnosticTesting
from utils import ref_funcs


def int32(val):
Expand Down Expand Up @@ -1703,5 +1704,29 @@ def verify_all_class_non_max_suppression(
)


@tvm.testing.uses_gpu
def test_gather_nd():
def verify_gather_nd(data_shape, indices_shape, data_shape_np, indices_shape_np, batch_dims=0):
x = relay.var("x", relay.TensorType(data_shape, "float32"))
y = relay.var("y", relay.TensorType(indices_shape, "int32"))
z = relay.gather_nd(x, y, batch_dims, indices_shape[0])

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], z)

data_np = np.random.uniform(size=data_shape_np).astype("float32")
indices_np = np.random.randint(low=0, high=2, size=indices_shape_np, dtype="int32")

ref_res = ref_funcs.gather_nd(data_np, indices_np, batch_dims)
check_result([data_np, indices_np], mod, [ref_res])

verify_gather_nd((2, 2), (2, relay.Any()), (2, 2), (2, 3))
verify_gather_nd((relay.Any(), 2), (2, relay.Any()), (2, 2), (2, 3))
verify_gather_nd((relay.Any(), 2), (1, relay.Any()), (10, 2), (1, 10), 1)
verify_gather_nd(
(relay.Any(), 2, 2, 3, 4), (3, relay.Any(), relay.Any()), (3, 2, 2, 3, 4), (3, 3, 2), 2
)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit aaa6211

Please sign in to comment.