diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 1b09cf307554e..594ab2df7a8da 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -2570,46 +2570,12 @@ def convert_batch_to_space_nd(self, op): input_tensor_idx = input_tensor.tensor_idx in_expr = self.get_expr(input_tensor_idx) - input_shape = list(input_tensor.tensor.ShapeAsNumpy()) - batch = input_shape[0] - block_shape = list(self.get_tensor_value(input_tensors[1])) - M = len(block_shape) - - crops = list(self.get_tensor_value(input_tensors[2])) + crops = self.get_tensor_value(input_tensors[2]).tolist() - # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: - # Reshape input to reshaped of shape - shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:] - reshaped = _op.reshape(in_expr, newshape=shape1) - - # Permute dimensions of reshaped to produce permuted of shape - axes = ( - [M] - + [axis for i in range(M) for axis in [M + i + 1, i]] - + list(range(2 * M + 1, len(shape1))) - ) - permuted = _op.transpose(reshaped, axes=axes) - - # Reshape permuted to produce reshaped_permuted of shape - shape2 = [0] + [-3] * M + [-2] - reshaped_permuted = _op.reshape(permuted, newshape=shape2) - - # Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops - # to produce the output of shape: - reshaped_permuted_shape = _infer_shape(reshaped_permuted) - cropped = reshaped_permuted - for axis in range(1, M + 1): - crop = crops[axis - 1] - if (crop != [0, 0]).all(): - indices = _op.arange( - _expr.const(crop[0]), - _expr.const(reshaped_permuted_shape[axis] - crop[1]), - dtype="int32", - ) - cropped = _op.take(cropped, indices=indices, axis=axis) + out = _op.nn.batch_to_space_nd(in_expr, block_shape, crops) - return cropped + return out def convert_space_to_batch_nd(self, op): """space_to_batch_nd implementation.""" @@ -2620,51 +2586,12 @@ def convert_space_to_batch_nd(self, op): input_tensor_idx = input_tensor.tensor_idx in_expr = self.get_expr(input_tensor_idx) - input_shape = list(input_tensor.tensor.ShapeAsNumpy()) - batch = input_shape[0] - N = len(input_shape) - block_shape = list(self.get_tensor_value(input_tensors[1])) - M = len(block_shape) - - paddings = list(self.get_tensor_value(input_tensors[2])) + paddings = self.get_tensor_value(input_tensors[2]).tolist() - # From https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd: - # Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings - # to produce padded of shape padded_shape. - remaining_shape_length = N - M - 1 - padded_list = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length + out = _op.nn.space_to_batch_nd(in_expr, block_shape, paddings) - padded_shape = [] - for element in padded_list: - if isinstance(element, np.ndarray): - element = element.tolist() - - padded_shape.append(element) - - padded_shape = tuple(padded_shape) - padded = _op.nn.pad(in_expr, pad_width=tuple(padded_shape)) - - # Reshape padded to reshaped_padded of shape: - shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2] - reshaped_padded = _op.reshape(padded, newshape=shape1) - - # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape: - axes = ( - [2 * i + 2 for i in range(M)] - + [0] - + [2 * i + 1 for i in range(M)] - + list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) - ) - permuted_reshaped_padded = _op.transpose(reshaped_padded, axes=axes) - permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded) - - # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, - # producing an output tensor of shape: - shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :] - reshaped_permuted_reshaped_padded = _op.reshape(permuted_reshaped_padded, newshape=shape2) - - return reshaped_permuted_reshaped_padded + return out def convert_depth_to_space(self, op): """Convert TFLite DEPTH_TO_SPACE"""