Skip to content

Commit

Permalink
Generalized reshape_like operator (apache#11928)
Browse files Browse the repository at this point in the history
* first commit

* fix documentation

* changed static_cast<bool>(end) to end.has_value(); fixed documentation issues

* change begin from int to optional

* test None as lhs
  • Loading branch information
sbodenstein authored and szha committed Aug 11, 2018
1 parent 17a4e2b commit 5b77036
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 15 deletions.
28 changes: 28 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,34 @@ void HardSigmoidBackward(const nnvm::NodeAttrs& attrs,
});
}

/*!
 * \brief Parameters for the generalized reshape_like operator.
 *
 * Each pair (lhs_begin, lhs_end) / (rhs_begin, rhs_end) selects a half-open
 * slice [begin, end) over the dimensions of the lhs / rhs input.  All four
 * fields default to "None" (an empty dmlc::optional<int>); downstream shape
 * inference treats a missing begin as 0 and a missing end as ndim.
 * Negative indices count from the end, Python-style.
 */
struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
// Optional slice bounds over the lhs and rhs dimension ranges.
dmlc::optional<int> lhs_begin, rhs_begin, lhs_end, rhs_end;
DMLC_DECLARE_PARAMETER(ReshapeLikeParam) {
DMLC_DECLARE_FIELD(lhs_begin)
.set_default(dmlc::optional<int>())
.describe(
"Defaults to 0. "
"The beginning index along which the lhs dimensions are to be "
"reshaped. Supports negative indices.");
DMLC_DECLARE_FIELD(lhs_end)
.set_default(dmlc::optional<int>())
.describe("Defaults to None. "
"The ending index along which the lhs dimensions are to be "
"used for reshaping. Supports negative indices.");
DMLC_DECLARE_FIELD(rhs_begin)
.set_default(dmlc::optional<int>())
.describe("Defaults to 0. "
"The beginning index along which the rhs dimensions are to "
"be used for "
"reshaping. Supports negative indices.");
DMLC_DECLARE_FIELD(rhs_end)
.set_default(dmlc::optional<int>())
.describe("Defaults to None. "
"The ending index along which the rhs dimensions are to be "
"used for reshaping. Supports negative indices.");
}
};

/*! \brief Unary compute */
#define MXNET_OPERATOR_REGISTER_UNARY(__name$) \
NNVM_REGISTER_OP(__name$) \
Expand Down
118 changes: 103 additions & 15 deletions src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -350,10 +350,109 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
.add_argument("lhs", "NDArray-or-Symbol", "First input.")
.add_argument("rhs", "NDArray-or-Symbol", "Second input.");

/*!
 * \brief Canonicalize an optional [begin, end) dimension slice into concrete,
 *        non-negative indices over a shape with `ndims` dimensions.
 *
 * A missing begin defaults to 0 and a missing end defaults to ndims; negative
 * values count back from ndims (Python-style).  Fatal CHECKs reject slices
 * that are empty, out of bounds, or still negative after adjustment.
 *
 * \param ndims  number of dimensions in the shape being sliced
 * \param side   "lhs" or "rhs", used only in error messages
 * \param begin  optional slice start (may be negative or unset)
 * \param end    optional slice end (may be negative or unset)
 * \param cbegin output: canonical slice start, 0 <= *cbegin < *cend
 * \param cend   output: canonical slice end, *cbegin < *cend <= ndims
 */
void ReshapeLikeRangeCanonicalize(int ndims, const char *side,
                                  const dmlc::optional<int> &begin,
                                  const dmlc::optional<int> &end, int *cbegin,
                                  int *cend) {
  int low = begin.has_value() ? begin.value() : 0;
  if (low < 0) low += ndims;

  int high = end.has_value() ? end.value() : ndims;
  // Only an explicit end can be negative; the ndims default never is.
  if (end.has_value() && high < 0) high += ndims;

  *cbegin = low;
  *cend = high;

  CHECK(*cend <= ndims) << "Invalid end for " << side << "_end=" << end
                        << " as dimension number is " << ndims;
  CHECK((*cbegin < *cend)) << "Invalid begin, end, get " << side
                           << "_begin=" << begin << ", " << side
                           << "_end=" << end;

  CHECK(*cend >= 0) << "Invalid end for " << side << "_end=" << end;
  CHECK(*cbegin >= 0) << "Invalid begin for " << side << "_begin=" << begin;
}

/*!
 * \brief Resolve the four optional slice bounds in `param` into concrete
 *        indices over the lhs and rhs shapes.
 *
 * The lhs slice is canonicalized before the rhs one so that, when both are
 * invalid, the fatal CHECK fires on the lhs side first.
 */
void GetReshapeLikeParams(const ReshapeLikeParam &param, const TShape &lshape,
                          const TShape &rshape, int *lhs_begin, int *lhs_end,
                          int *rhs_begin, int *rhs_end) {
  const int lndims = static_cast<int>(lshape.ndim());
  const int rndims = static_cast<int>(rshape.ndim());
  ReshapeLikeRangeCanonicalize(lndims, "lhs", param.lhs_begin, param.lhs_end,
                               lhs_begin, lhs_end);
  ReshapeLikeRangeCanonicalize(rndims, "rhs", param.rhs_begin, param.rhs_end,
                               rhs_begin, rhs_end);
}

/*!
 * \brief Shape inference for reshape_like.
 *
 * The output shape keeps the lhs dimensions outside [lhs_begin, lhs_end) and
 * replaces the sliced lhs range with the rhs dimensions in
 * [rhs_begin, rhs_end).  The total element count must be unchanged.
 *
 * Fix: the size-mismatch error message was missing a space
 * ("...shape (30,7)to new shape..."); now reads "... to new shape ...".
 */
bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs &attrs,
                             std::vector<TShape> *in_attrs,
                             std::vector<TShape> *out_attrs) {
  const ReshapeLikeParam &param = nnvm::get<ReshapeLikeParam>(attrs.parsed);
  const TShape &lshape = (*in_attrs)[0];
  const TShape &rshape = (*in_attrs)[1];
  int lhs_begin, lhs_end, rhs_begin, rhs_end;
  GetReshapeLikeParams(param, lshape, rshape, &lhs_begin, &lhs_end, &rhs_begin,
                       &rhs_end);

  // Output rank: lhs rank with the sliced lhs span swapped for the rhs span.
  int lhsrank = static_cast<int>(lshape.ndim());
  int orank = lhsrank + (rhs_end - rhs_begin) - (lhs_end - lhs_begin);
  TShape oshape(orank);

  // lhs dimensions before the slice are copied unchanged.
  for (int i = 0; i < lhs_begin; ++i)
    oshape[i] = lshape[i];

  // The sliced lhs range is replaced by the sliced rhs dimensions.
  int opos = lhs_begin;
  for (int i = rhs_begin; i < rhs_end; ++i) {
    oshape[opos] = rshape[i];
    opos += 1;
  }

  // lhs dimensions after the slice are copied unchanged.
  for (int i = lhs_end; i < lhsrank; ++i) {
    oshape[opos] = lshape[i];
    opos += 1;
  }

  // Reshaping may never change the number of elements.
  CHECK_EQ((*in_attrs)[0].Size(), oshape.Size())
      << "Cannot reshape lhs with shape " << (*in_attrs)[0] << " to new "
      << "shape " << oshape << " because they have different "
      << "size.";
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  return true;
}

DMLC_REGISTER_PARAMETER(ReshapeLikeParam);
NNVM_REGISTER_OP(reshape_like)
.describe("Reshape lhs to have the same shape as rhs.")
.describe(R"code(Reshape some or all dimensions of `lhs` to have the same shape as some or all dimensions of `rhs`.
Returns a **view** of the `lhs` array with a new shape without altering any data.
Example::
x = [1, 2, 3, 4, 5, 6]
y = [[0, -4], [3, 2], [2, 2]]
reshape_like(x, y) = [[1, 2], [3, 4], [5, 6]]
More precise control over how dimensions are inherited is achieved by specifying \
slices over the `lhs` and `rhs` array dimensions. Only the sliced `lhs` dimensions \
are reshaped to the `rhs` sliced dimensions, with the non-sliced `lhs` dimensions staying the same.
Examples::
- lhs shape = (30,7), rhs shape = (15,2,4), lhs_begin=0, lhs_end=1, rhs_begin=0, rhs_end=2, output shape = (15,2,7)
- lhs shape = (3, 5), rhs shape = (1,15,4), lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2, output shape = (15)
Negative indices are supported, and `None` can be used for either `lhs_end` or `rhs_end` to indicate the end of the range.
Example::
- lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3)
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_attr_parser(ParamParser<ReshapeLikeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) { return std::vector<std::string>{"lhs", "rhs"}; })
.set_attr<nnvm::FInplaceOption>(
Expand All @@ -365,19 +464,7 @@ NNVM_REGISTER_OP(reshape_like)
.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
[](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 1); })
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.set_attr<nnvm::FInferShape>("FInferShape",
[](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
if ((*in_attrs)[0].ndim()) {
CHECK_EQ((*in_attrs)[0].Size(), (*in_attrs)[1].Size())
<< "Cannot reshape lhs with shape " << (*in_attrs)[0] << "to rhs "
<< "with shape " << (*in_attrs)[1] << " because they have different "
<< "size.";
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[1]);
return true;
})
.set_attr<nnvm::FInferShape>("FInferShape", ReshapeLikeShapeCompute)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FGradient>(
"FGradient", [](const nnvm::NodePtr& n,
Expand Down Expand Up @@ -438,7 +525,8 @@ Example::
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
return out_attrs->at(0) != -1;
})
.add_argument("data", "NDArray-or-Symbol", "Input Array.");
.add_argument("data", "NDArray-or-Symbol", "Input Array.")
.add_arguments(ReshapeLikeParam::__FIELDS__());

void SizeComputeCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
53 changes: 53 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,59 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
assert_allclose(exe.grad_arrays[0].asnumpy(), out_grad_npy.reshape((5, 4, 3, 7)))


@with_seed()
def test_reshape_like():
    """Test reshape_like: the sliced API (lhs/rhs begin/end) and the legacy
    two-argument form.  Checks inferred shape, forward output, and gradient.

    Fixes: removed a leftover commented-out duplicate of the test-case loop,
    and deduplicated the failure-message expression that was repeated three
    times.
    """
    def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shape):
        # Single failure message shared by all assertions for this case.
        case_msg = 'LHS Shape = %s, RHS Shape = %s, lhs_begin = %s, lhs_end = %s, rhs_begin= %s, rhs_end= %s'\
            % (str(lhs_shape), str(rhs_shape), str(lbeg), str(lend), str(rbeg), str(rend))

        lhs = mx.sym.Variable("lhs")
        rhs = mx.sym.Variable("rhs")
        net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend)
        # Round-trip through JSON to verify the new parameters serialize.
        js = net.tojson()
        net = mx.sym.load_json(js)
        _, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape)

        assert output_shape[0] == dst_shape, case_msg

        lhs_npy = np.random.rand(*lhs_shape)
        rhs_npy = np.random.rand(*rhs_shape)
        grad_npy = np.random.rand(*dst_shape)

        exe = net.simple_bind(default_context(), lhs=lhs_shape, rhs=rhs_shape)
        exe.arg_dict['lhs'][:] = lhs_npy
        exe.arg_dict['rhs'][:] = rhs_npy
        exe.forward(is_train=True)
        # Forward: output is the lhs data viewed with the destination shape.
        assert np.square(exe.outputs[0].asnumpy() - lhs_npy.reshape(dst_shape)).mean() < 1E-7, case_msg
        exe.backward(out_grads=mx.nd.array(grad_npy))
        # Backward: lhs gradient is the out-grad reshaped back to lhs_shape.
        assert np.square(exe.grad_dict['lhs'].asnumpy() - grad_npy.reshape(lhs_shape)).mean() < 1E-7, case_msg

    # New API: slice lhs/rhs dimensions with begin/end (None uses defaults,
    # negative indices count from the end).
    test_cases = [
        [(30,), (15,2,4), 0, None, 0, 2, (15,2)],
        [(30,), (15,2,4), None, 1, None, 2, (15,2)],
        [(30,7), (15,2,4), 0, 1, 0, 2, (15,2,7)],
        [(3,5), (1,15,4), 0, 2, 1, 2, (15,)],
        [(3,5), (1,15,4), 0, None, 1, -1, (15,)],
        [(30,12), (4,2,2,3), -1, None, 1, None, (30,2,2,3)],
        [(1,1,7,3,1,1), (81,1,1,21), 1, -1, 1, None, (1,1,1,21,1)]
    ]
    for test_case in test_cases:
        test_reshape_like_new(*test_case)

    # Old API: no begin/end arguments, lhs is reshaped to the full rhs shape.
    lhs = mx.sym.Variable("lhs")
    rhs = mx.sym.Variable("rhs")
    net = mx.sym.reshape_like(lhs, rhs)
    js = net.tojson()
    net = mx.sym.load_json(js)
    _, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2))
    assert(output_shape[0] == (30,20,2))


@with_seed()
def test_reduce():
sample_num = 500
Expand Down

0 comments on commit 5b77036

Please sign in to comment.