Skip to content

Commit

Permalink
[Relay] Add space_to_batch_nd and batch_to_space_nd operators (#6477)
Browse files Browse the repository at this point in the history
* [Relay] Add space_to_batch_nd and batch_to_space_nd operators

* Correct python-format errors

* correct lint errors

* tflite frontend to use batch_to_space and space_to_batch operators

* Add new pad_value parameter (default value 0) to space_to_batch_nd and correct variable names

* Fix cppdocs - add documentation for pad_value
  • Loading branch information
BhushanIMG authored Nov 17, 2020
1 parent 6c01998 commit 7e90e7d
Show file tree
Hide file tree
Showing 18 changed files with 1,018 additions and 156 deletions.
28 changes: 28 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,34 @@ struct CorrelationAttrs : public tvm::AttrsNode<CorrelationAttrs> {
}
}; // struct CorrelationAttrs

/*!
 * \brief Attributes used in SpaceToBatchND operator
 *
 * Holds the block sizes, the paddings and the fill value that the
 * operator reads; each field is declared below via TVM_ATTR_FIELD.
 */
struct SpaceToBatchNDAttrs : public tvm::AttrsNode<SpaceToBatchNDAttrs> {
/*! \brief 1-D list with the block size for each spatial dimension; defaults to {1, 1}. */
Array<Integer> block_shape;
/*!
 * \brief 2-D list with the paddings for each spatial dimension; no default is declared.
 * NOTE(review): presumably one [before, after] pair per spatial dimension - confirm
 * against the operator implementation.
 */
Array<Array<IndexExpr>> paddings;
/*! \brief The value used for padding; defaults to 0.0. */
double pad_value;

TVM_DECLARE_ATTRS(SpaceToBatchNDAttrs, "relay.attrs.SpaceToBatchNDAttrs") {
TVM_ATTR_FIELD(block_shape)
.set_default(Array<Integer>({1, 1}))
.describe("1-D containing block size for each spatial dimension.");
// paddings has no set_default call, unlike the other two fields
TVM_ATTR_FIELD(paddings).describe("2-D containing paddings for each spatial dimension.");
TVM_ATTR_FIELD(pad_value).set_default(0.0).describe("The value used for padding.");
}
}; // struct SpaceToBatchNDAttrs

/*!
 * \brief Attributes used in BatchToSpaceND operator
 *
 * Holds the block sizes and the crop amounts that the operator reads;
 * each field is declared below via TVM_ATTR_FIELD.
 */
struct BatchToSpaceNDAttrs : public tvm::AttrsNode<BatchToSpaceNDAttrs> {
/*! \brief 1-D list with the block size for each spatial dimension; defaults to {1, 1}. */
Array<Integer> block_shape;
/*!
 * \brief 2-D list with the amount to crop from each spatial dimension; no default is declared.
 * NOTE(review): presumably one [begin, end] pair per spatial dimension - confirm
 * against the operator implementation.
 */
Array<Array<IndexExpr>> crops;

TVM_DECLARE_ATTRS(BatchToSpaceNDAttrs, "relay.attrs.BatchToSpaceNDAttrs") {
TVM_ATTR_FIELD(block_shape)
.set_default(Array<Integer>({1, 1}))
.describe("1-D containing block size for each spatial dimension.");
// crops has no set_default call, unlike block_shape
TVM_ATTR_FIELD(crops).describe("2-D containing amount to crop from spatial dimension.");
}
}; // struct BatchToSpaceNDAttrs

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_NN_H_
178 changes: 178 additions & 0 deletions include/tvm/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <tvm/tir/op.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/tags.h>
#include <tvm/topi/transform.h>

#include <algorithm>
#include <string>
Expand Down Expand Up @@ -459,6 +460,183 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t
return tvm::te::compute(output_shape, l, name, tag);
}

/*!
 * \brief Divide spatial dimensions of the input into a grid of blocks.
 *
 * The input is first padded along its spatial dimensions, then reshaped so
 * that every spatial dimension is split into (size / block, block),
 * transposed so that the block axes come first, and finally reshaped so that
 * the block axes are folded into the batch dimension.
 *
 * The batch size, the padded spatial sizes and the block sizes are read with
 * GetConstInt (presumably they must be compile-time constants - confirm
 * against GetConstInt).
 *
 * \param data The input tensor, laid out as [batch, spatial..., remaining...].
 * \param block_shape The size of the spatial block; every entry must be positive.
 * \param pad_before The padding size before each spatial dimension.
 * \param pad_after The padding size after each spatial dimension.
 * \param pad_value The value used for padding; when undefined, a zero of the
 *        input dtype is used.
 * \param name The name of the operation (currently not applied to the
 *        generated tensors).
 * \param tag The tag to mark the operation (currently not applied to the
 *        generated tensors).
 *
 * \return A Tensor whose op member is the space_to_batch_nd operation
 */
inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& pad_before,
                                         const tvm::Array<tvm::PrimExpr>& pad_after,
                                         PrimExpr pad_value = PrimExpr(),
                                         std::string name = "space_to_batch_nd",
                                         std::string tag = kInjective) {
  tvm::te::Tensor padded_t;
  CHECK_EQ(pad_before.size(), pad_after.size());
  CHECK_EQ(block_shape.size(), pad_before.size())
      << "Paddings must be provided for each spatial dimension";
  // the input needs a batch dimension in front of the spatial dimensions,
  // otherwise the shape lookups below would run past the end of the shape
  CHECK_GT(data->shape.size(), block_shape.size())
      << "Input must have a batch dimension and " << block_shape.size()
      << " spatial dimension(s), but its rank is " << data->shape.size();
  tvm::Array<tvm::PrimExpr> pad_before_int32;
  tvm::Array<tvm::PrimExpr> pad_after_int32;

  // pad size for batch dimension is 0
  pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), 0));
  // insert pad sizes given for spatial dimensions
  for (const auto& ele : pad_before) {
    pad_before_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }
  for (const auto& ele : pad_after) {
    pad_after_int32.push_back(tvm::cast(tvm::DataType::Int(32), ele));
  }

  // pad the input with paddings provided
  if (!pad_value.defined()) {
    pad_value = tvm::tir::make_const(data->dtype, 0);
  }
  padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);

  auto input_shape = data->shape;
  auto padded_shape = padded_t->shape;

  // infer shapes
  tvm::Array<PrimExpr> r_shape;  // shape with every spatial dim split into (quotient, block)
  tvm::Array<Integer> axis;      // permutation applied to r_shape
  tvm::Array<PrimExpr> o_shape;  // final output shape

  size_t num_block_dims = block_shape.size();
  int batch = static_cast<int>(GetConstInt(input_shape[0]));
  tvm::PrimExpr block_shape_prod(1);
  r_shape.push_back(batch);

  for (size_t i = 1; i <= num_block_dims; i++) {
    int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
    int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
    // reject non-positive blocks up front: the modulo below would otherwise
    // divide by zero
    CHECK_GT(block_size, 0) << "Block size for (" << i << ")th dimension must be positive, got "
                            << block_size;
    CHECK_EQ((padded_input % block_size), 0)
        << "(" << i
        << ")th "
           "Input dimension after padding ("
        << padded_input << ")"
        << " must be divisible by its block size (" << block_size << ")";

    r_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
    r_shape.push_back(block_shape[i - 1]);
    block_shape_prod *= block_shape[i - 1];
    axis.push_back(Integer(r_shape.size() - 1));  // index of block_shape[i - 1]
  }

  size_t n = axis.size();
  axis.push_back(0);  // batch is at index 0
  // index of (padded_shape[i] / block_shape[i - 1]) in r_shape:
  // it sits right before the matching block axis
  for (size_t i = 0; i < n; i++) {
    axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
  }
  o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
  }
  // append remaining shape
  for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
    r_shape.push_back(input_shape[i]);
    axis.push_back(Integer(r_shape.size() - 1));  // index of remaining shape in r_shape
    o_shape.push_back(input_shape[i]);
  }

  tvm::te::Tensor output = reshape(padded_t, r_shape);
  output = transpose(output, axis);
  output = reshape(output, o_shape);

  return output;
}

/*!
 * \brief Reshape the batch dimension into spatial dimensions.
 *
 * The batch dimension is split into (block..., batch / prod(block)), the
 * block axes are interleaved with their spatial dimensions by a transpose,
 * each (spatial, block) pair is merged by a reshape, and the merged spatial
 * dimensions are finally cropped with a strided slice.
 *
 * The batch size, the block sizes and the crop sizes are read with
 * GetConstInt (presumably they must be compile-time constants - confirm
 * against GetConstInt).
 *
 * \param data The input tensor, laid out as [batch, spatial..., remaining...].
 * \param block_shape The size of the spatial block; every entry must be
 *        positive and the batch size must be divisible by their product.
 * \param crop_begin_list The begin crop size for each spatial dimension.
 * \param crop_end_list The end crop size for each spatial dimension.
 * \param name The name of the operation (currently not applied to the
 *        generated tensors).
 * \param tag The tag to mark the operation (currently not applied to the
 *        generated tensors).
 *
 * \return A Tensor whose op member is the batch_to_space_nd operation
 */
inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
                                         const tvm::Array<Integer>& block_shape,
                                         const tvm::Array<tvm::PrimExpr>& crop_begin_list,
                                         const tvm::Array<tvm::PrimExpr>& crop_end_list,
                                         std::string name = "batch_to_space_nd",
                                         std::string tag = kInjective) {
  // Construct shapes for reshape and transpose operation
  Array<PrimExpr> in_shape = data->shape;
  Array<PrimExpr> r_shape;
  Array<Integer> axis;
  size_t num_block_dims = block_shape.size();
  size_t num_input_dims = in_shape.size();

  // validate the arguments before any of them is indexed
  CHECK_EQ(crop_begin_list.size(), crop_end_list.size());
  CHECK_EQ(num_block_dims, crop_begin_list.size())
      << "Crops must be provided for each spatial dimension";
  CHECK_GT(num_input_dims, num_block_dims)
      << "Input must have a batch dimension and " << num_block_dims
      << " spatial dimension(s), but its rank is " << num_input_dims;

  tvm::PrimExpr block_shape_prod(1);
  int batch = static_cast<int>(GetConstInt(in_shape[0]));

  for (size_t i = 0; i < num_block_dims; i++) {
    int block_size = static_cast<int>(GetConstInt(block_shape[i]));
    CHECK_GT(block_size, 0) << "Block size for (" << (i + 1)
                            << ")th dimension must be positive, got " << block_size;
    r_shape.push_back(block_shape[i]);
    block_shape_prod *= block_shape[i];
  }
  int block_prod = static_cast<int>(GetConstInt(block_shape_prod));
  CHECK_EQ((batch % block_prod), 0) << "Batch size (" << batch
                                    << ") must be divisible by the product of block sizes ("
                                    << block_prod << ")";
  axis.push_back(Integer(r_shape.size()));  // axis of (batch / block_shape_prod)
  r_shape.push_back(batch / block_shape_prod);

  for (size_t i = 1; i < num_input_dims; i++) {
    axis.push_back(Integer(r_shape.size()));  // axis of in_shape[i]
    // Only spatial dimensions are paired with a block axis. Guarding on the
    // dimension index (rather than on the number of axes collected so far)
    // keeps the batch axis from being appended a second time when the input
    // has two or more trailing non-spatial dimensions.
    if (i <= num_block_dims) {
      axis.push_back(Integer(r_shape.size() - (num_block_dims + 1)));  // axis of block_shape[i - 1]
    }
    r_shape.push_back(in_shape[i]);
  }

  // shape after each (spatial, block) pair has been merged
  Array<PrimExpr> r_p_shape;
  r_p_shape.push_back(batch / block_shape_prod);
  for (size_t i = 1; i <= num_block_dims; i++) {
    r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
  }
  for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
    r_p_shape.push_back(in_shape[i]);
  }

  tvm::te::Tensor out;
  out = reshape(data, r_shape);
  out = transpose(out, axis);
  out = reshape(out, r_p_shape);

  // Crop the start and end of dimensions of out
  Array<Integer> begin_idx, end_idx, strides;
  for (size_t i = 0; i < r_p_shape.size(); ++i) {
    strides.push_back(Integer(1));
    if (i > 0 && i <= num_block_dims) {
      // prepare begin and end index for spatial dimensions
      int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
      int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
      int out_i = static_cast<int>(GetConstInt(r_p_shape[i]));
      CHECK_GT(out_i, (begin_i + end_i))
          << "Incorrect crop sizes for (" << i << ")th dim, can not crop more than"
          << " output size" << out_i << " vs " << (begin_i + end_i);
      begin_idx.push_back(begin_i);
      end_idx.push_back(out_i - end_i);
    } else {
      // ignore the batch and remaining dimension
      begin_idx.push_back(Integer(0));
      end_idx.push_back(static_cast<int>(GetConstInt(r_p_shape[i])));
    }
  }

  out = strided_slice(out, begin_idx, end_idx, strides);
  return out;
}
} // namespace topi
} // namespace tvm
#endif // TVM_TOPI_NN_H_
88 changes: 11 additions & 77 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2060,8 +2060,6 @@ def _impl(inputs, attr, params, mod):

def _space_to_batch_nd():
def _impl(inputs, attr, params, mod):
input_node = inputs[0]
input_shape = _infer_shape(input_node, mod)
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
Expand All @@ -2075,48 +2073,18 @@ def _impl(inputs, attr, params, mod):
if len(paddings.shape) == 1:
paddings = np.expand_dims(paddings, axis=0)
paddings = paddings.tolist()
N = len(input_shape)
M = len(block_shape)
batch = input_shape[0]
remaining_shape_length = N - M - 1
paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d:
# Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
# to produce padded of shape padded_shape.
padded = tvm.relay.nn.pad(input_node, pad_width=paddings)
# Reshape padded to reshaped_padded of shape:
# [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...,
# padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape
shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
reshaped_padded = tvm.relay.reshape(padded, newshape=shape1)
# Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
# block_shape + [batch] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
axes = (
[2 * i + 2 for i in range(M)]
+ [0]
+ [2 * i + 1 for i in range(M)]
+ list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length))
)
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_shape(permuted_reshaped_padded, mod)
# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1 :]
reshaped_permuted_reshaped_padded = tvm.relay.reshape(
permuted_reshaped_padded, newshape=shape2
)
return reshaped_permuted_reshaped_padded

attr["block_shape"] = block_shape
attr["paddings"] = paddings
out = AttrCvt("space_to_batch_nd", ignores=["Tblock_shape", "Tpaddings"])([inputs[0]], attr)

return out

return _impl


def _batch_to_space_nd():
def _impl(inputs, attr, params, mod):
input_node = inputs[0]
input_shape = _infer_shape(input_node, mod)
try:
block_shape = _get_list_param(params, inputs[1])
except (IndexError, KeyError, AttributeError):
Expand All @@ -2130,46 +2098,12 @@ def _impl(inputs, attr, params, mod):
if len(crops.shape) == 1:
crops = np.expand_dims(crops, axis=0)
crops = crops.tolist()
M = len(block_shape)
batch = input_shape[0]
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d:
# Reshape input to reshaped of shape:
# [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape),
# input_shape[1], ..., input_shape[N-1]]
shape1 = block_shape + [batch // np.prod(block_shape)] + list(input_shape[1:])
reshaped = tvm.relay.reshape(input_node, newshape=shape1)
# Permute dimensions of reshaped to produce permuted of shape
# [batch / prod(block_shape), input_shape[1], block_shape[0], ...,
# input_shape[M], block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
axes = (
[M]
+ [axis for i in range(M) for axis in [M + i + 1, i]]
+ list(range(2 * M + 1, len(shape1)))
)
permuted = tvm.relay.transpose(reshaped, axes=axes)
# Reshape permuted to produce reshaped_permuted of shape
# [batch / prod(block_shape), input_shape[1] * block_shape[0], ...,
# input_shape[M] * block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]]
shape2 = [0] + [-3] * M + [-2]
reshaped_permuted = tvm.relay.reshape(permuted, newshape=shape2)
# Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops
# to produce the output of shape:
# [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1],
# ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1],
# input_shape[M+1], ..., input_shape[N-1]]
reshaped_permuted_shape = _infer_shape(reshaped_permuted, mod)
cropped = reshaped_permuted
for axis in range(1, M + 1):
crop = crops[axis - 1]
if crop != [0, 0]:
indices = tvm.relay.arange(
_expr.const(crop[0]),
_expr.const(reshaped_permuted_shape[axis] - crop[1]),
dtype="int32",
)
cropped = tvm.relay.take(cropped, indices=indices, axis=axis)

return cropped
attr["block_shape"] = block_shape
attr["crops"] = crops
out = AttrCvt("batch_to_space_nd", ignores=["Tblock_shape", "Tcrops"])([inputs[0]], attr)

return out

return _impl

Expand Down
Loading

0 comments on commit 7e90e7d

Please sign in to comment.