Skip to content

Commit

Permalink
resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Nov 2, 2020
1 parent d0b14b0 commit 6a6fda5
Showing 1 changed file with 79 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,85 @@ def test_any_ndarray_size():
verify_any_ndarray_size((1, 2, 3, 4))


def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shape):
mod = tvm.IRModule()
dtype = "float32"
data = relay.var('data', shape=data_shape, dtype=dtype)
if layout == "NHWC":
size = (data_shape[1] * scale, data_shape[2] * scale)
else:
size = (data_shape[2] * scale, data_shape[3] * scale)
y = relay.image.resize(data, size, layout)
mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
check_result([data_np], mod, ref_out_shape, assert_shape=True)

@tvm.testing.uses_gpu
def test_any_resize():
verify_any_resize(
data_shape=(relay.Any(), 4, 4, 4),
scale=2,
layout='NHWC',
static_data_shape=(1, 4, 4, 4),
ref_out_shape=(1, 8, 8, 4))
verify_any_resize(
data_shape=(relay.Any(), 8, 17, 20),
scale=3,
layout='NCHW',
static_data_shape=(2, 8, 17, 20),
ref_out_shape=(2, 8, 51, 60))


def verify_any_grid_sample(data_shape, grid_shape, static_data_shape, ref_out_shape):
mod = tvm.IRModule()
dtype = "float32"
data = relay.var('data', shape=data_shape, dtype=dtype)
grid = relay.var('grid', shape=grid_shape, dtype=dtype)
y = relay.image.grid_sample(data, grid)
mod["main"] = relay.Function([data, grid], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
grid_np = np.random.uniform(size=grid_shape).astype(dtype)
check_result([data_np, grid_np], mod, ref_out_shape, assert_shape=True)

@tvm.testing.uses_gpu
def test_any_grid_sample():
verify_any_grid_sample(
data_shape=(relay.Any(), 4, 16, 32),
grid_shape=(4, 2, 8, 8),
static_data_shape=(4, 4, 16, 32),
ref_out_shape=(4, 4, 8, 8))
verify_any_grid_sample(
data_shape=(relay.Any(), 4, 16, 32),
grid_shape=(4, 2, 32, 32),
static_data_shape=(4, 4, 16, 32),
ref_out_shape=(4, 4, 32, 32))


def verify_any_affine_grid(num_batch, static_num_batch, target_shape, ref_out_shape):
mod = tvm.IRModule()
dtype = "float32"
data_shape = (num_batch, 2, 3)
static_data_shape = (static_num_batch, 2, 3)
data = relay.var('data', shape=data_shape, dtype=dtype)
y = relay.image.affine_grid(data, target_shape)
mod["main"] = relay.Function([data], y)
data_np = np.random.uniform(size=static_data_shape).astype(dtype)
check_result([data_np], mod, ref_out_shape, assert_shape=True)

@tvm.testing.uses_gpu
def test_any_affine_grid():
verify_any_affine_grid(
num_batch=relay.Any(),
static_num_batch=1,
target_shape=(16, 32),
ref_out_shape=(1, 2, 16, 32))
verify_any_affine_grid(
num_batch=relay.Any(),
static_num_batch=8,
target_shape=(32, 32),
ref_out_shape=(8, 2, 32, 32))


def test_any_consecutive_broadcast():
dtype = "float32"
data0 = relay.var("data0", shape=any_dims(2), dtype=dtype)
Expand Down

0 comments on commit 6a6fda5

Please sign in to comment.