Skip to content

Commit

Permalink
Add new pad_value parameter with default value is 0 for space_to_batc…
Browse files Browse the repository at this point in the history
…h_nd and correct variable names
  • Loading branch information
BhushanIMG committed Oct 28, 2020
1 parent 59798e5 commit 00fa61d
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 27 deletions.
2 changes: 2 additions & 0 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1328,12 +1328,14 @@ struct CorrelationAttrs : public tvm::AttrsNode<CorrelationAttrs> {
struct SpaceToBatchNDAttrs : public tvm::AttrsNode<SpaceToBatchNDAttrs> {
Array<Integer> block_shape;
Array<Array<IndexExpr>> paddings;
double pad_value;

TVM_DECLARE_ATTRS(SpaceToBatchNDAttrs, "relay.attrs.SpaceToBatchNDAttrs") {
TVM_ATTR_FIELD(block_shape)
.set_default(Array<Integer>({1, 1}))
.describe("1-D containing block size for each spatial dimension.");
TVM_ATTR_FIELD(paddings).describe("2-D containing paddings for each spatial dimension.");
TVM_ATTR_FIELD(pad_value).set_default(0.0).describe("The value used for padding.");
}
}; // struct SpaceToBatchNDAttrs

Expand Down
32 changes: 18 additions & 14 deletions include/tvm/topi/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
const tvm::Array<Integer>& block_shape,
const tvm::Array<tvm::PrimExpr>& pad_before,
const tvm::Array<tvm::PrimExpr>& pad_after,
PrimExpr pad_value = PrimExpr(),
std::string name = "space_to_batch_nd",
std::string tag = kInjective) {
tvm::te::Tensor padded_t;
Expand All @@ -497,7 +498,10 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
}

// pad the input with paddings provided
padded_t = pad(data, pad_before_int32, pad_after_int32, make_const(DataType::Int(32), 0));
if (!pad_value.defined()) {
pad_value = tvm::tir::make_const(data->dtype, 0);
}
padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value);

auto input_shape = data->shape;
auto padded_shape = padded_t->shape;
Expand All @@ -507,12 +511,12 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
tvm::Array<Integer> axis;
tvm::Array<PrimExpr> o_shape;

size_t M = block_shape.size();
size_t num_block_dims = block_shape.size();
int batch = static_cast<int>(GetConstInt(input_shape[0]));
tvm::PrimExpr block_shape_prod(1);
r_shape.push_back(batch);

for (size_t i = 1; i <= M; i++) {
for (size_t i = 1; i <= num_block_dims; i++) {
int padded_input = static_cast<int>(GetConstInt(padded_shape[i]));
int block_size = static_cast<int>(GetConstInt(block_shape[i - 1]));
CHECK_EQ((padded_input % block_size), 0)
Expand All @@ -535,11 +539,11 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data,
axis.push_back(static_cast<int>(GetConstInt(axis[i] - 1)));
}
o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod);
for (size_t i = 1; i <= M; i++) {
for (size_t i = 1; i <= num_block_dims; i++) {
o_shape.push_back(div(padded_shape[i], block_shape[i - 1]));
}
// append remaining shape
for (size_t i = M + 1; i < input_shape.size(); i++) {
for (size_t i = num_block_dims + 1; i < input_shape.size(); i++) {
r_shape.push_back(input_shape[i]);
axis.push_back(Integer(r_shape.size() - 1)); // index of remaining shape in r_shape
o_shape.push_back(input_shape[i]);
Expand Down Expand Up @@ -574,32 +578,32 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
Array<PrimExpr> in_shape = data->shape;
Array<PrimExpr> r_shape;
Array<Integer> axis;
size_t M = block_shape.size();
size_t N = in_shape.size();
size_t num_block_dims = block_shape.size();
size_t num_input_dims = in_shape.size();
tvm::PrimExpr block_shape_prod(1);
int batch = static_cast<int>(GetConstInt(in_shape[0]));

for (size_t i = 0; i < M; i++) {
for (size_t i = 0; i < num_block_dims; i++) {
r_shape.push_back(block_shape[i]);
block_shape_prod *= block_shape[i];
}
axis.push_back(Integer(r_shape.size())); // axis of (batch / block_shape_prod)
r_shape.push_back(batch / block_shape_prod);

for (size_t i = 1; i < N; i++) {
for (size_t i = 1; i < num_input_dims; i++) {
axis.push_back(Integer(r_shape.size())); // axis of in_shape[i]
if (axis.size() < (M + N)) {
axis.push_back(Integer(r_shape.size() - (M + 1))); // axis of block_shape[i]
if (axis.size() < (num_block_dims + num_input_dims)) {
axis.push_back(Integer(r_shape.size() - (num_block_dims + 1))); // axis of block_shape[i]
}
r_shape.push_back(in_shape[i]);
}

Array<PrimExpr> r_p_shape;
r_p_shape.push_back(batch / block_shape_prod);
for (size_t i = 1; i <= M; i++) {
for (size_t i = 1; i <= num_block_dims; i++) {
r_p_shape.push_back(in_shape[i] * block_shape[i - 1]);
}
for (size_t i = M + 1; i < N; i++) {
for (size_t i = num_block_dims + 1; i < num_input_dims; i++) {
r_p_shape.push_back(in_shape[i]);
}

Expand All @@ -612,7 +616,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data,
Array<Integer> begin_idx, end_idx, strides;
for (size_t i = 0; i < r_p_shape.size(); ++i) {
strides.push_back(Integer(1));
if (i > 0 && i <= M) {
if (i > 0 && i <= num_block_dims) {
// prepare begin and end index for spatial dimensions
int begin_i = static_cast<int>(GetConstInt(crop_begin_list[i - 1]));
int end_i = static_cast<int>(GetConstInt(crop_end_list[i - 1]));
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3158,7 +3158,7 @@ def correlation(
)


def space_to_batch_nd(data, block_shape, paddings):
def space_to_batch_nd(data, block_shape, paddings, pad_value=0):
r"""Divide spatial dimensions of the data into a grid of blocks
and interleave them into batch dim.
Expand All @@ -3175,6 +3175,9 @@ def space_to_batch_nd(data, block_shape, paddings):
2-D of shape [M, 2] where M is number of spatial dims, specifies
[before, after] paddings for each spatial dimension.
pad_value : float, or relay.Expr, optional, default=0
The value used for padding.
Returns
-------
result : relay.Expr
Expand All @@ -3184,7 +3187,7 @@ def space_to_batch_nd(data, block_shape, paddings):
remaining_shape]
"""

return _make.space_to_batch_nd(data, block_shape, paddings)
return _make.space_to_batch_nd(data, block_shape, paddings, pad_value)


def batch_to_space_nd(data, block_shape, crops):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/topi/nn/space_to_batch_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import cpp


def space_to_batch_nd(data, block_shape, pad_before, pad_after):
def space_to_batch_nd(data, block_shape, pad_before, pad_after, pad_value=0.0):
"""Perform batch to space transformation on the data
Parameters
Expand All @@ -41,9 +41,12 @@ def space_to_batch_nd(data, block_shape, pad_before, pad_after):
list of shape [M] where M is number of spatial dims, specifies
zero-padding size after each spatial dimension.
pad_value : float, optional
The value used for padding.
Returns
-------
output : tvm.te.Tensor
"""

return cpp.nn.space_to_batch_nd(data, block_shape, pad_before, pad_after)
return cpp.nn.space_to_batch_nd(data, block_shape, pad_before, pad_after, pad_value)
7 changes: 5 additions & 2 deletions python/tvm/topi/testing/space_to_batch_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np


def space_to_batch_nd_python(data, block_shape, pad_before, pad_after):
def space_to_batch_nd_python(data, block_shape, pad_before, pad_after, pad_value=0):
"""Space to Batch operator in python for NHWC layout.
Parameters
Expand All @@ -40,6 +40,9 @@ def space_to_batch_nd_python(data, block_shape, pad_before, pad_after):
list of shape [M] where M is number of spatial dims, specifies
zero-padding size after each spatial dimension.
pad_value : float, optional
the value used for padding. Defaults to 0.
Returns
-------
s2b_out : np.ndarray
Expand All @@ -56,7 +59,7 @@ def space_to_batch_nd_python(data, block_shape, pad_before, pad_after):
# Add the paddings for batch and remaining dims
paddings = map(list, zip(pad_before, pad_after))
paddings = [[0, 0]] + list(paddings) + [[0, 0]] * (data.ndim - 1 - M)
padded_data = np.pad(data, paddings, mode="constant")
padded_data = np.pad(data, paddings, mode="constant", constant_values=pad_value)
padded_shape = padded_data.shape

# Get the reshape shape and transpose axes
Expand Down
10 changes: 7 additions & 3 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1149,10 +1149,12 @@ RELAY_REGISTER_OP("nn.space_to_depth")
// used by frontend FFI
TVM_REGISTER_NODE_TYPE(SpaceToBatchNDAttrs);

Expr MakeSpaceToBatchND(Expr data, Array<Integer> block_shape, Array<Array<IndexExpr>> paddings) {
Expr MakeSpaceToBatchND(Expr data, Array<Integer> block_shape, Array<Array<IndexExpr>> paddings,
double pad_value) {
auto attrs = make_object<SpaceToBatchNDAttrs>();
attrs->block_shape = std::move(block_shape);
attrs->paddings = std::move(paddings);
attrs->pad_value = pad_value;
static const Op& op = Op::Get("nn.space_to_batch_nd");
return Call(op, {data}, Attrs(attrs), {});
}
Expand Down Expand Up @@ -1225,8 +1227,10 @@ Array<te::Tensor> SpaceToBatchNDCompute(const Attrs& attrs, const Array<te::Tens
for (size_t i = 0; i < paddings.size(); ++i) {
pad_after.push_back(paddings[i][1]);
}

return Array<te::Tensor>{topi::space_to_batch_nd(inputs[0], b_shape, pad_before, pad_after)};
const auto* out_ttype = out_type.as<TensorTypeNode>();
return Array<te::Tensor>{
topi::space_to_batch_nd(inputs[0], b_shape, pad_before, pad_after,
tvm::tir::make_const(out_ttype->dtype, param->pad_value))};
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_batch_nd").set_body_typed(MakeSpaceToBatchND);
Expand Down
2 changes: 1 addition & 1 deletion src/topi/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ TVM_REGISTER_GLOBAL("topi.nn.pad").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.nn.space_to_batch_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = space_to_batch_nd(args[0], args[1], args[2], args[3]);
*rv = space_to_batch_nd(args[0], args[1], args[2], args[3], args[4]);
});

TVM_REGISTER_GLOBAL("topi.nn.batch_to_space_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down
8 changes: 5 additions & 3 deletions tests/python/topi/python/test_topi_space_to_batch_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import tvm.topi.testing


def verify_space_to_batch_nd(input_shape, block_shape, pad_before, pad_after):
def verify_space_to_batch_nd(input_shape, block_shape, pad_before, pad_after, pad_value=0):
out_shape = []
out_shape.append(int((input_shape[0] * np.prod(block_shape))))
for i in range(1, len(block_shape) + 1):
Expand All @@ -36,9 +36,11 @@ def verify_space_to_batch_nd(input_shape, block_shape, pad_before, pad_after):
dtype = A.dtype
a_np = np.random.uniform(size=input_shape).astype(dtype)

B = topi.nn.space_to_batch_nd(A, block_shape, pad_before, pad_after)
B = topi.nn.space_to_batch_nd(A, block_shape, pad_before, pad_after, pad_value)

b_np = tvm.topi.testing.space_to_batch_nd_python(a_np, block_shape, pad_before, pad_after)
b_np = tvm.topi.testing.space_to_batch_nd_python(
a_np, block_shape, pad_before, pad_after, pad_value
)

def check_device(device, ctx):
print("Running on target: %s" % device)
Expand Down

0 comments on commit 00fa61d

Please sign in to comment.