Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Nov 2, 2020
1 parent 8380dce commit 3aa3f5c
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,34 @@ def compute_resize(attrs, inputs, out_type):
reg.register_injective_schedule("image.resize")

@script
def _resize_func(image_shape, size, height_axis, width_axis, channel_axis):
def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis):
out = output_tensor((4,), "int64")
out[0] = int64(image_shape[0])
out[batch_axis] = int64(image_shape[0])
out[height_axis] = int64(size[0])
out[width_axis] = int64(size[1])
out[channel_axis] = image_shape[channel_axis]
return out

@reg.register_shape_func("image.resize", False)
def resize_func(attrs, inputs, _):
def resize_shape_func(attrs, inputs, _):
"""
Shape function for resize op.
"""
layout = attrs.layout
height_axis = width_axis = channel_axis = 1
for i, letter in enumerate(layout):
if letter == "N":
batch_axis = i
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
if letter == "C":
channel_axis = i
size = get_const_tuple(attrs.size)
return [_resize_func(inputs[0], convert(size), convert(height_axis),
convert(width_axis), convert(channel_axis))]
return [_resize_shape_func(inputs[0], convert(size), convert(batch_axis),
convert(height_axis), convert(width_axis),
convert(channel_axis))]


@reg.register_compute("image.resize3d")
Expand Down

0 comments on commit 3aa3f5c

Please sign in to comment.