Skip to content

Commit

Permalink
add dynamic get valid count test, including empty size tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
masa committed Jan 20, 2021
1 parent 6c70ed2 commit a6c7403
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,51 @@ def test_any_topk():
verify_any_topk(any_dims(1), 0, (0,), "float32", ret_type="both")


def verify_any_get_valid_counts(num_anchor_real, dtype, targets=None):
mod = tvm.IRModule()
batch_size = 1
num_anchor = relay.Any()
data = relay.var("data", shape=(batch_size, num_anchor, 5), dtype=dtype)
np_data = np.random.uniform(size=(batch_size, num_anchor_real, 5)).astype(dtype)

np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=np_data.shape).astype(dtype)
np_out3 = np.zeros(shape=(batch_size, num_anchor_real))
score_threshold = 0.95

for i in range(batch_size):
np_out1[i] = 0
inter_idx = 0
for j in range(num_anchor_real):
score = np_data[i, j, 0]
if score > score_threshold:
for k in range(5):
np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1
np_out3[i, inter_idx] = j
inter_idx += 1
if j >= np_out1[i]:
for k in range(5):
np_out2[i, j, k] = -1.0
np_out3[i, j] = -1

z = relay.vision.get_valid_counts(data, score_threshold, 0, score_index=0)

mod["main"] = relay.Function([data], z.astuple())

check_result([np_data], mod, [np_out1, np_out2, np_out3], targets=targets)


@tvm.testing.uses_gpu
def test_any_get_valid_counts():
verify_any_get_valid_counts(10, "float32")
# opencl seems to have issues with empty size buffer
# Check failed: err_code == CL_SUCCESS == false: OpenCL Error,
# code=-61: CL_INVALID_BUFFER_SIZE
targets = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0)), ("nvptx", tvm.gpu(0))]
verify_any_get_valid_counts(0, "float32", targets=targets)


@tvm.testing.uses_gpu
def test_fused_ops():
x = relay.var("x", shape=(relay.Any(), relay.Any()), dtype="float32")
Expand Down

0 comments on commit a6c7403

Please sign in to comment.