From 925d0577146e45675cd6f0ca924e6fb29e05e6da Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Tue, 3 Nov 2020 20:57:50 -0800 Subject: [PATCH] Register shape functions for some image related ops (#6373) * debugging * added three shape funcs * fix lint * address comment * resolve conflicts * resolve conflicts * resolve conflicts * resolve conflicts * resolve conflicts --- python/tvm/relay/op/image/_image.py | 76 +++++++++++++++++++++++++ tests/python/relay/test_any.py | 88 +++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index c0cdf64c621a..ee8a5b3883b1 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -42,6 +42,45 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("image.resize") +@script +def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): + out = output_tensor((4,), "int64") + out[batch_axis] = int64(image_shape[0]) + out[height_axis] = int64(size[0]) + out[width_axis] = int64(size[1]) + out[channel_axis] = image_shape[channel_axis] + return out + + +@reg.register_shape_func("image.resize", False) +def resize_shape_func(attrs, inputs, _): + """ + Shape function for resize op. + """ + layout = attrs.layout + height_axis = width_axis = channel_axis = 1 + for i, letter in enumerate(layout): + if letter == "N": + batch_axis = i + if letter == "H": + height_axis = i + if letter == "W": + width_axis = i + if letter == "C": + channel_axis = i + size = get_const_tuple(attrs.size) + return [ + _resize_shape_func( + inputs[0], + convert(size), + convert(batch_axis), + convert(height_axis), + convert(width_axis), + convert(channel_axis), + ) + ] + + @reg.register_compute("image.resize3d") def compute_resize3d(attrs, inputs, out_type): size = attrs.size @@ -134,6 +173,25 @@ def compute_affine_grid(attrs, inputs, out_dtype): reg.register_injective_schedule("image.affine_grid") +@script +def _affine_grid_func(data, target_shape): + out = output_tensor((4,), "int64") + out[0] = int64(data[0]) + out[1] = int64(2) + out[2] = int64(target_shape[0]) + out[3] = int64(target_shape[1]) + return out + + +@reg.register_shape_func("image.affine_grid", False) +def affine_grid_func(attrs, inputs, _): + """ + Shape function for affine_grid op. + """ + target_shape = get_const_tuple(attrs.target_shape) + return [_affine_grid_func(inputs[0], convert(target_shape))] + + # grid_sample @reg.register_compute("image.grid_sample") def compute_grid_sample(attrs, inputs, out_dtype): @@ -143,3 +201,21 @@ def compute_grid_sample(attrs, inputs, out_dtype): reg.register_injective_schedule("image.grid_sample") + + +@script +def _grid_sample_func(data, grid): + out = output_tensor((4,), "int64") + out[0] = int64(data[0]) + out[1] = int64(data[1]) + out[2] = int64(grid[2]) + out[3] = int64(grid[3]) + return out + + +@reg.register_shape_func("image.grid_sample", False) +def grid_sample_func(attrs, inputs, _): + """ + Shape function for grid_sample op. + """ + return [_grid_sample_func(inputs[0], inputs[1])] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8784b97a31fa..546973704fea 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1121,6 +1121,94 @@ def test_any_ndarray_size(): verify_any_ndarray_size((1, 2, 3, 4)) +def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + if layout == "NHWC": + size = (data_shape[1] * scale, data_shape[2] * scale) + else: + size = (data_shape[2] * scale, data_shape[3] * scale) + y = relay.image.resize(data, size, layout) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_resize(): + verify_any_resize( + data_shape=(relay.Any(), 4, 4, 4), + scale=2, + layout="NHWC", + static_data_shape=(1, 4, 4, 4), + ref_out_shape=(1, 8, 8, 4), + ) + verify_any_resize( + data_shape=(relay.Any(), 8, 17, 20), + scale=3, + layout="NCHW", + static_data_shape=(2, 8, 17, 20), + ref_out_shape=(2, 8, 51, 60), + ) + + +def verify_any_grid_sample(data_shape, grid_shape, static_data_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data = relay.var("data", shape=data_shape, dtype=dtype) + grid = relay.var("grid", shape=grid_shape, dtype=dtype) + y = relay.image.grid_sample(data, grid) + mod["main"] = relay.Function([data, grid], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + grid_np = np.random.uniform(size=grid_shape).astype(dtype) + check_result([data_np, grid_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_grid_sample(): + verify_any_grid_sample( + data_shape=(relay.Any(), 4, 16, 32), + grid_shape=(4, 2, 8, 8), + static_data_shape=(4, 4, 16, 32), + ref_out_shape=(4, 4, 8, 8), + ) + verify_any_grid_sample( + data_shape=(relay.Any(), 4, 16, 32), + grid_shape=(4, 2, 32, 32), + static_data_shape=(4, 4, 16, 32), + ref_out_shape=(4, 4, 32, 32), + ) + + +def verify_any_affine_grid(num_batch, static_num_batch, target_shape, ref_out_shape): + mod = tvm.IRModule() + dtype = "float32" + data_shape = (num_batch, 2, 3) + static_data_shape = (static_num_batch, 2, 3) + data = relay.var("data", shape=data_shape, dtype=dtype) + y = relay.image.affine_grid(data, target_shape) + mod["main"] = relay.Function([data], y) + data_np = np.random.uniform(size=static_data_shape).astype(dtype) + check_result([data_np], mod, ref_out_shape, assert_shape=True) + + +@tvm.testing.uses_gpu +def test_any_affine_grid(): + verify_any_affine_grid( + num_batch=relay.Any(), + static_num_batch=1, + target_shape=(16, 32), + ref_out_shape=(1, 2, 16, 32), + ) + verify_any_affine_grid( + num_batch=relay.Any(), + static_num_batch=8, + target_shape=(32, 32), + ref_out_shape=(8, 2, 32, 32), + ) + + def test_any_consecutive_broadcast(): dtype = "float32" data0 = relay.var("data0", shape=any_dims(2), dtype=dtype)