Register shape functions for some image related ops (apache#6373)
* debugging

* added three shape funcs

* fix lint

* address comment

* resolve conflicts

Laurawly authored and Trevor Morris committed Dec 4, 2020
1 parent 8b1ccc9 commit 925d057
Showing 2 changed files with 164 additions and 0 deletions.
76 changes: 76 additions & 0 deletions python/tvm/relay/op/image/_image.py
@@ -42,6 +42,45 @@ def compute_resize(attrs, inputs, out_type):
reg.register_injective_schedule("image.resize")


@script
def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis):
    out = output_tensor((4,), "int64")
    out[batch_axis] = int64(image_shape[0])
    out[height_axis] = int64(size[0])
    out[width_axis] = int64(size[1])
    out[channel_axis] = image_shape[channel_axis]
    return out


@reg.register_shape_func("image.resize", False)
def resize_shape_func(attrs, inputs, _):
"""
Shape function for resize op.
"""
layout = attrs.layout
height_axis = width_axis = channel_axis = 1
for i, letter in enumerate(layout):
if letter == "N":
batch_axis = i
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
if letter == "C":
channel_axis = i
size = get_const_tuple(attrs.size)
return [
_resize_shape_func(
inputs[0],
convert(size),
convert(batch_axis),
convert(height_axis),
convert(width_axis),
convert(channel_axis),
)
]
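
For readers not fluent in TVM hybrid script, here is the same rule as a plain-Python sketch: batch and channels pass through from the input image, while height and width come from the size attribute. resize_out_shape is a hypothetical helper written only for illustration; it is not part of TVM.

def resize_out_shape(image_shape, size, layout="NCHW"):
    # Hypothetical mirror of _resize_shape_func above (illustration only).
    out = [0, 0, 0, 0]
    out[layout.index("N")] = image_shape[0]  # batch passes through
    out[layout.index("H")] = size[0]  # height taken from attrs.size
    out[layout.index("W")] = size[1]  # width taken from attrs.size
    out[layout.index("C")] = image_shape[layout.index("C")]  # channels pass through
    return tuple(out)

# These agree with the test cases added below:
assert resize_out_shape((2, 8, 17, 20), (51, 60), "NCHW") == (2, 8, 51, 60)
assert resize_out_shape((1, 4, 4, 4), (8, 8), "NHWC") == (1, 8, 8, 4)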


@reg.register_compute("image.resize3d")
def compute_resize3d(attrs, inputs, out_type):
    size = attrs.size
@@ -134,6 +173,25 @@ def compute_affine_grid(attrs, inputs, out_dtype):
reg.register_injective_schedule("image.affine_grid")


@script
def _affine_grid_func(data, target_shape):
    out = output_tensor((4,), "int64")
    out[0] = int64(data[0])
    out[1] = int64(2)
    out[2] = int64(target_shape[0])
    out[3] = int64(target_shape[1])
    return out


@reg.register_shape_func("image.affine_grid", False)
def affine_grid_func(attrs, inputs, _):
"""
Shape function for affine_grid op.
"""
target_shape = get_const_tuple(attrs.target_shape)
return [_affine_grid_func(inputs[0], convert(target_shape))]
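
As a quick check on the rule above: affine_grid takes a batch of (N, 2, 3) affine matrices and produces a sampling grid of shape (N, 2, target_h, target_w), so only the batch dimension depends on the input. A hypothetical plain-Python version, for illustration only:

def affine_grid_out_shape(data_shape, target_shape):
    # data_shape is (N, 2, 3); the grid is (N, 2, target_h, target_w).
    return (data_shape[0], 2, target_shape[0], target_shape[1])

assert affine_grid_out_shape((8, 2, 3), (32, 32)) == (8, 2, 32, 32)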


# grid_sample
@reg.register_compute("image.grid_sample")
def compute_grid_sample(attrs, inputs, out_dtype):
@@ -143,3 +201,21 @@ def compute_grid_sample(attrs, inputs, out_dtype):


reg.register_injective_schedule("image.grid_sample")


@script
def _grid_sample_func(data, grid):
    out = output_tensor((4,), "int64")
    out[0] = int64(data[0])
    out[1] = int64(data[1])
    out[2] = int64(grid[2])
    out[3] = int64(grid[3])
    return out


@reg.register_shape_func("image.grid_sample", False)
def grid_sample_func(attrs, inputs, _):
"""
Shape function for grid_sample op.
"""
return [_grid_sample_func(inputs[0], inputs[1])]
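
The grid_sample rule combines both inputs: batch and channel sizes come from data (NCHW), while the output height and width come from the sampling grid, whose shape is (N, 2, H_out, W_out). A hypothetical plain-Python mirror, for illustration only:

def grid_sample_out_shape(data_shape, grid_shape):
    # (N, C) from data; (H_out, W_out) from the grid's last two axes.
    return (data_shape[0], data_shape[1], grid_shape[2], grid_shape[3])

assert grid_sample_out_shape((4, 4, 16, 32), (4, 2, 8, 8)) == (4, 4, 8, 8)
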
88 changes: 88 additions & 0 deletions tests/python/relay/test_any.py
@@ -1121,6 +1121,94 @@ def test_any_ndarray_size():
    verify_any_ndarray_size((1, 2, 3, 4))


def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shape):
    mod = tvm.IRModule()
    dtype = "float32"
    data = relay.var("data", shape=data_shape, dtype=dtype)
    if layout == "NHWC":
        size = (data_shape[1] * scale, data_shape[2] * scale)
    else:
        size = (data_shape[2] * scale, data_shape[3] * scale)
    y = relay.image.resize(data, size, layout)
    mod["main"] = relay.Function([data], y)
    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
    check_result([data_np], mod, ref_out_shape, assert_shape=True)


@tvm.testing.uses_gpu
def test_any_resize():
    verify_any_resize(
        data_shape=(relay.Any(), 4, 4, 4),
        scale=2,
        layout="NHWC",
        static_data_shape=(1, 4, 4, 4),
        ref_out_shape=(1, 8, 8, 4),
    )
    verify_any_resize(
        data_shape=(relay.Any(), 8, 17, 20),
        scale=3,
        layout="NCHW",
        static_data_shape=(2, 8, 17, 20),
        ref_out_shape=(2, 8, 51, 60),
    )


def verify_any_grid_sample(data_shape, grid_shape, static_data_shape, ref_out_shape):
    mod = tvm.IRModule()
    dtype = "float32"
    data = relay.var("data", shape=data_shape, dtype=dtype)
    grid = relay.var("grid", shape=grid_shape, dtype=dtype)
    y = relay.image.grid_sample(data, grid)
    mod["main"] = relay.Function([data, grid], y)
    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
    grid_np = np.random.uniform(size=grid_shape).astype(dtype)
    check_result([data_np, grid_np], mod, ref_out_shape, assert_shape=True)


@tvm.testing.uses_gpu
def test_any_grid_sample():
    verify_any_grid_sample(
        data_shape=(relay.Any(), 4, 16, 32),
        grid_shape=(4, 2, 8, 8),
        static_data_shape=(4, 4, 16, 32),
        ref_out_shape=(4, 4, 8, 8),
    )
    verify_any_grid_sample(
        data_shape=(relay.Any(), 4, 16, 32),
        grid_shape=(4, 2, 32, 32),
        static_data_shape=(4, 4, 16, 32),
        ref_out_shape=(4, 4, 32, 32),
    )


def verify_any_affine_grid(num_batch, static_num_batch, target_shape, ref_out_shape):
    mod = tvm.IRModule()
    dtype = "float32"
    data_shape = (num_batch, 2, 3)
    static_data_shape = (static_num_batch, 2, 3)
    data = relay.var("data", shape=data_shape, dtype=dtype)
    y = relay.image.affine_grid(data, target_shape)
    mod["main"] = relay.Function([data], y)
    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
    check_result([data_np], mod, ref_out_shape, assert_shape=True)


@tvm.testing.uses_gpu
def test_any_affine_grid():
    verify_any_affine_grid(
        num_batch=relay.Any(),
        static_num_batch=1,
        target_shape=(16, 32),
        ref_out_shape=(1, 2, 16, 32),
    )
    verify_any_affine_grid(
        num_batch=relay.Any(),
        static_num_batch=8,
        target_shape=(32, 32),
        ref_out_shape=(8, 2, 32, 32),
    )


def test_any_consecutive_broadcast():
dtype = "float32"
data0 = relay.var("data0", shape=any_dims(2), dtype=dtype)
