diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 0c37a941fb67..e09a6cccddbf 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -476,6 +476,34 @@ void HardSigmoidBackward(const nnvm::NodeAttrs& attrs,
   });
 }
 
+struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
+  dmlc::optional<int> lhs_begin, rhs_begin, lhs_end, rhs_end;
+  DMLC_DECLARE_PARAMETER(ReshapeLikeParam) {
+    DMLC_DECLARE_FIELD(lhs_begin)
+        .set_default(dmlc::optional<int>())
+        .describe(
+            "Defaults to 0. "
+            "The beginning index along which the lhs dimensions are to be "
+            "reshaped. Supports negative indices.");
+    DMLC_DECLARE_FIELD(lhs_end)
+        .set_default(dmlc::optional<int>())
+        .describe("Defaults to None. "
+                  "The ending index along which the lhs dimensions are to be "
+                  "used for reshaping. Supports negative indices.");
+    DMLC_DECLARE_FIELD(rhs_begin)
+        .set_default(dmlc::optional<int>())
+        .describe("Defaults to 0. "
+                  "The beginning index along which the rhs dimensions are to "
+                  "be used for "
+                  "reshaping. Supports negative indices.");
+    DMLC_DECLARE_FIELD(rhs_end)
+        .set_default(dmlc::optional<int>())
+        .describe("Defaults to None. "
+                  "The ending index along which the rhs dimensions are to be "
+                  "used for reshaping. Supports negative indices.");
+  }
+};
+
 /*!
 \brief Unary compute */
 #define MXNET_OPERATOR_REGISTER_UNARY(__name$)                 \
   NNVM_REGISTER_OP(__name$)                                    \
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 929bc7426d5f..f7f21f9076a6 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -350,10 +350,109 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
 .add_argument("lhs", "NDArray-or-Symbol", "First input.")
 .add_argument("rhs", "NDArray-or-Symbol", "Second input.");
 
+void ReshapeLikeRangeCanonicalize(int ndims, const char *side,
+                                  const dmlc::optional<int> &begin,
+                                  const dmlc::optional<int> &end, int *cbegin,
+                                  int *cend) {
+  *cbegin = begin.has_value() ? begin.value() : 0;
+  if (*cbegin < 0)
+    *cbegin += ndims;
+
+  if (!end.has_value()) {
+    *cend = ndims;
+  } else {
+    *cend = end.value();
+    if (*cend < 0) {
+      *cend += ndims;
+    }
+  }
+  CHECK(*cend <= ndims) << "Invalid end for " << side << "_end=" << end
+                        << " as dimension number is " << ndims;
+  CHECK((*cbegin < *cend)) << "Invalid begin, end, get " << side
+                           << "_begin=" << begin << ", " << side
+                           << "_end=" << end;
+
+  CHECK(*cend >= 0) << "Invalid end for " << side << "_end=" << end;
+  CHECK(*cbegin >= 0) << "Invalid begin for " << side << "_begin=" << begin;
+}
+
+void GetReshapeLikeParams(const ReshapeLikeParam &param, const TShape &lshape,
+                          const TShape &rshape, int *lhs_begin, int *lhs_end,
+                          int *rhs_begin, int *rhs_end) {
+  // LHS params
+  ReshapeLikeRangeCanonicalize(lshape.ndim(), "lhs", param.lhs_begin,
+                               param.lhs_end, lhs_begin, lhs_end);
+  // RHS params
+  ReshapeLikeRangeCanonicalize(rshape.ndim(), "rhs", param.rhs_begin,
+                               param.rhs_end, rhs_begin, rhs_end);
+}
+
+bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs &attrs,
+                             std::vector<TShape> *in_attrs,
+                             std::vector<TShape> *out_attrs) {
+  const ReshapeLikeParam &param = nnvm::get<ReshapeLikeParam>(attrs.parsed);
+  const TShape &lshape = (*in_attrs)[0];
+  const TShape &rshape = (*in_attrs)[1];
+  int lhs_begin, lhs_end,
rhs_begin, rhs_end;
+  GetReshapeLikeParams(param, lshape, rshape, &lhs_begin, &lhs_end, &rhs_begin,
+                       &rhs_end);
+
+  int lhsrank = static_cast<int>(lshape.ndim());
+  int orank = lhsrank + (rhs_end - rhs_begin) - (lhs_end - lhs_begin);
+  TShape oshape(orank);
+
+  for (int i = 0; i < lhs_begin; ++i)
+    oshape[i] = lshape[i];
+
+  int opos = lhs_begin;
+  for (int i = rhs_begin; i < rhs_end; ++i) {
+    oshape[opos] = rshape[i];
+    opos += 1;
+  }
+
+  for (int i = lhs_end; i < lhsrank; ++i) {
+    oshape[opos] = lshape[i];
+    opos += 1;
+  }
+
+  CHECK_EQ((*in_attrs)[0].Size(), oshape.Size())
+      << "Cannot reshape lhs with shape " << (*in_attrs)[0] << "to new "
+      << "shape " << oshape << " because they have different "
+      << "size.";
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
+  return true;
+}
+
+DMLC_REGISTER_PARAMETER(ReshapeLikeParam);
 NNVM_REGISTER_OP(reshape_like)
-.describe("Reshape lhs to have the same shape as rhs.")
+.describe(R"code(Reshape some or all dimensions of `lhs` to have the same shape as some or all dimensions of `rhs`.
+
+Returns a **view** of the `lhs` array with a new shape without altering any data.
+
+Example::
+
+  x = [1, 2, 3, 4, 5, 6]
+  y = [[0, -4], [3, 2], [2, 2]]
+  reshape_like(x, y) = [[1, 2], [3, 4], [5, 6]]
+
+More precise control over how dimensions are inherited is achieved by specifying \
+slices over the `lhs` and `rhs` array dimensions. Only the sliced `lhs` dimensions \
+are reshaped to the `rhs` sliced dimensions, with the non-sliced `lhs` dimensions staying the same.
+
+  Examples::
+
+  - lhs shape = (30,7), rhs shape = (15,2,4), lhs_begin=0, lhs_end=1, rhs_begin=0, rhs_end=2, output shape = (15,2,7)
+  - lhs shape = (3, 5), rhs shape = (1,15,4), lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2, output shape = (15)
+
+Negative indices are supported, and `None` can be used for either `lhs_end` or `rhs_end` to indicate the end of the range.
+
+  Example::
+
+  - lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3)
+
+)code" ADD_FILELINE)
 .set_num_inputs(2)
+.set_attr_parser(ParamParser<ReshapeLikeParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) { return std::vector<std::string>{"lhs", "rhs"}; })
 .set_attr(
@@ -365,19 +464,7 @@ NNVM_REGISTER_OP(reshape_like)
 .set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs",
     [](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 1); })
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
-.set_attr<nnvm::FInferShape>("FInferShape",
-  [](const nnvm::NodeAttrs& attrs,
-     std::vector<TShape> *in_attrs,
-     std::vector<TShape> *out_attrs) {
-    if ((*in_attrs)[0].ndim()) {
-      CHECK_EQ((*in_attrs)[0].Size(), (*in_attrs)[1].Size())
-        << "Cannot reshape lhs with shape " << (*in_attrs)[0] << "to rhs "
-        << "with shape " << (*in_attrs)[1] << " because they have different "
-        << "size.";
-    }
-    SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[1]);
-    return true;
-  })
+.set_attr<nnvm::FInferShape>("FInferShape", ReshapeLikeShapeCompute)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<nnvm::FGradient>(
     "FGradient", [](const nnvm::NodePtr& n,
@@ -438,7 +525,8 @@ Example::
     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
     return out_attrs->at(0) != -1;
   })
-.add_argument("data", "NDArray-or-Symbol", "Input Array.");
+.add_argument("data", "NDArray-or-Symbol", "Input Array.")
+.add_arguments(ReshapeLikeParam::__FIELDS__());
 
 void SizeComputeCPU(const nnvm::NodeAttrs& attrs,
                     const OpContext& ctx,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 11180ebbc5d4..d228703f35bc 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -2114,6 +2114,59 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
         assert_allclose(exe.grad_arrays[0].asnumpy(), out_grad_npy.reshape((5, 4, 3, 7)))
 
+@with_seed()
+def test_reshape_like():
+    def test_reshape_like_new(lhs_shape, rhs_shape, lbeg, lend, rbeg, rend, dst_shape):
+        lhs =
mx.sym.Variable("lhs")
+        rhs = mx.sym.Variable("rhs")
+        net = mx.sym.reshape_like(lhs, rhs, lhs_begin=lbeg, lhs_end=lend, rhs_begin=rbeg, rhs_end=rend)
+        js = net.tojson()
+        net = mx.sym.load_json(js)
+        _, output_shape, __ = net.infer_shape(lhs=lhs_shape, rhs=rhs_shape)
+
+        assert output_shape[0] == dst_shape, \
+            'LHS Shape = %s, RHS Shape = %s, lhs_begin = %s, lhs_end = %s, rhs_begin= %s, rhs_end= %s'\
+            %(str(lhs_shape), str(rhs_shape), str(lbeg), str(lend), str(rbeg), str(rend))
+
+        lhs_npy = np.random.rand(*lhs_shape)
+        rhs_npy = np.random.rand(*rhs_shape)
+        grad_npy = np.random.rand(*dst_shape)
+
+        exe = net.simple_bind(default_context(), lhs=lhs_shape, rhs=rhs_shape)
+        exe.arg_dict['lhs'][:] = lhs_npy
+        exe.arg_dict['rhs'][:] = rhs_npy
+        exe.forward(is_train=True)
+        assert np.square(exe.outputs[0].asnumpy() - lhs_npy.reshape(dst_shape)).mean() < 1E-7, \
+            'LHS Shape = %s, RHS Shape = %s, lhs_begin = %s, lhs_end = %s, rhs_begin= %s, rhs_end= %s'\
+            %(str(lhs_shape), str(rhs_shape), str(lbeg), str(lend), str(rbeg), str(rend))
+        exe.backward(out_grads=mx.nd.array(grad_npy))
+        assert np.square(exe.grad_dict['lhs'].asnumpy() - grad_npy.reshape(lhs_shape)).mean() < 1E-7, \
+            'LHS Shape = %s, RHS Shape = %s, lhs_begin = %s, lhs_end = %s, rhs_begin= %s, rhs_end= %s'\
+            %(str(lhs_shape), str(rhs_shape), str(lbeg), str(lend), str(rbeg), str(rend))
+    # Test new api (Using shape)
+    test_cases = [
+        [(30,), (15,2,4), 0, None, 0, 2, (15,2)],
+        [(30,), (15,2,4), None, 1, None, 2, (15,2)],
+        [(30,7), (15,2,4), 0, 1, 0, 2, (15,2,7)],
+        [(3,5), (1,15,4), 0, 2, 1, 2, (15,)],
+        [(3,5), (1,15,4), 0, None, 1, -1, (15,)],
+        [(30,12), (4,2,2,3), -1, None, 1, None, (30,2,2,3)],
+        [(1,1,7,3,1,1), (81,1,1,21), 1, -1, 1, None, (1,1,1,21,1)]
+    ]
+    # for test_case in test_cases:
+    for test_case in test_cases:
+        test_reshape_like_new(*test_case)
+
+    # Test old api
+    lhs = mx.sym.Variable("lhs")
+    rhs = mx.sym.Variable("rhs")
+    net = mx.sym.reshape_like(lhs, rhs)
+    js = net.tojson()
+    net
= mx.sym.load_json(js)
+    _, output_shape, __ = net.infer_shape(lhs=(40, 30), rhs=(30,20,2))
+    assert(output_shape[0] == (30,20,2))
+
+
 @with_seed()
 def test_reduce():
     sample_num = 500