diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index b4e4ea370682b..09c727e749328 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -42,22 +42,24 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("image.resize") @script -def _resize_func(image_shape, size, height_axis, width_axis, channel_axis): +def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): out = output_tensor((4,), "int64") - out[0] = int64(image_shape[0]) + out[batch_axis] = int64(image_shape[0]) out[height_axis] = int64(size[0]) out[width_axis] = int64(size[1]) out[channel_axis] = image_shape[channel_axis] return out @reg.register_shape_func("image.resize", False) -def resize_func(attrs, inputs, _): +def resize_shape_func(attrs, inputs, _): """ Shape function for resize op. """ layout = attrs.layout height_axis = width_axis = channel_axis = 1 for i, letter in enumerate(layout): + if letter == "N": + batch_axis = i if letter == "H": height_axis = i if letter == "W": @@ -65,8 +67,9 @@ def resize_func(attrs, inputs, _): if letter == "C": channel_axis = i size = get_const_tuple(attrs.size) - return [_resize_func(inputs[0], convert(size), convert(height_axis), - convert(width_axis), convert(channel_axis))] + return [_resize_shape_func(inputs[0], convert(size), convert(batch_axis), + convert(height_axis), convert(width_axis), + convert(channel_axis))] @reg.register_compute("image.resize3d")