Skip to content

Commit

Permalink
tflite frontend to use batch_to_space and space_to_batch operators
Browse files Browse the repository at this point in the history
  • Loading branch information
BhushanIMG committed Oct 1, 2020
1 parent d0f15c9 commit ba53bb9
Showing 1 changed file with 6 additions and 79 deletions.
85 changes: 6 additions & 79 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2570,46 +2570,12 @@ def convert_batch_to_space_nd(self, op):
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)

input_shape = list(input_tensor.tensor.ShapeAsNumpy())
batch = input_shape[0]

block_shape = list(self.get_tensor_value(input_tensors[1]))
M = len(block_shape)

crops = list(self.get_tensor_value(input_tensors[2]))
crops = self.get_tensor_value(input_tensors[2]).tolist()

# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
# Reshape input to reshaped of shape
shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:]
reshaped = _op.reshape(in_expr, newshape=shape1)

# Permute dimensions of reshaped to produce permuted of shape
axes = (
[M]
+ [axis for i in range(M) for axis in [M + i + 1, i]]
+ list(range(2 * M + 1, len(shape1)))
)
permuted = _op.transpose(reshaped, axes=axes)

# Reshape permuted to produce reshaped_permuted of shape
shape2 = [0] + [-3] * M + [-2]
reshaped_permuted = _op.reshape(permuted, newshape=shape2)

# Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops
# to produce the output of shape:
reshaped_permuted_shape = _infer_shape(reshaped_permuted)
cropped = reshaped_permuted
for axis in range(1, M + 1):
crop = crops[axis - 1]
if (crop != [0, 0]).all():
indices = _op.arange(
_expr.const(crop[0]),
_expr.const(reshaped_permuted_shape[axis] - crop[1]),
dtype="int32",
)
cropped = _op.take(cropped, indices=indices, axis=axis)
out = _op.nn.batch_to_space_nd(in_expr, block_shape, crops)

return cropped
return out

def convert_space_to_batch_nd(self, op):
"""space_to_batch_nd implementation."""
Expand All @@ -2620,51 +2586,12 @@ def convert_space_to_batch_nd(self, op):
input_tensor_idx = input_tensor.tensor_idx
in_expr = self.get_expr(input_tensor_idx)

input_shape = list(input_tensor.tensor.ShapeAsNumpy())
batch = input_shape[0]
N = len(input_shape)

block_shape = list(self.get_tensor_value(input_tensors[1]))
M = len(block_shape)

paddings = list(self.get_tensor_value(input_tensors[2]))
paddings = self.get_tensor_value(input_tensors[2]).tolist()

# From https://www.tensorflow.org/api_docs/python/tf/space_to_batch_nd:
# Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
# to produce padded of shape padded_shape.
remaining_shape_length = N - M - 1
padded_list = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
out = _op.nn.space_to_batch_nd(in_expr, block_shape, paddings)

padded_shape = []
for element in padded_list:
if isinstance(element, np.ndarray):
element = element.tolist()

padded_shape.append(element)

padded_shape = tuple(padded_shape)
padded = _op.nn.pad(in_expr, pad_width=tuple(padded_shape))

# Reshape padded to reshaped_padded of shape:
shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
reshaped_padded = _op.reshape(padded, newshape=shape1)

# Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
axes = (
[2 * i + 2 for i in range(M)]
+ [0]
+ [2 * i + 1 for i in range(M)]
+ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
)
permuted_reshaped_padded = _op.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded)

# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :]
reshaped_permuted_reshaped_padded = _op.reshape(permuted_reshaped_padded, newshape=shape2)

return reshaped_permuted_reshaped_padded
return out

def convert_depth_to_space(self, op):
"""Convert TFLite DEPTH_TO_SPACE"""
Expand Down

0 comments on commit ba53bb9

Please sign in to comment.