-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Generalized reshape_like operator #11928
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -476,6 +476,31 @@ void HardSigmoidBackward(const nnvm::NodeAttrs& attrs, | |
}); | ||
} | ||
|
||
/*!
 * \brief Parameters of the reshape_like operator.
 *
 * Describes which range of `lhs` dimensions is reshaped and which range of
 * `rhs` dimensions supplies the new extents. The begin indices default to 0
 * and the end indices default to None (an empty optional).
 */
struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
  int lhs_begin, rhs_begin;
  dmlc::optional<int> lhs_end, rhs_end;
  DMLC_DECLARE_PARAMETER(ReshapeLikeParam) {
    DMLC_DECLARE_FIELD(lhs_begin).set_default(0).describe(
        "Defaults to 0. "
        "The beginning index along which the lhs dimensions are to be "
        "reshaped. Supports negative indices.");
    DMLC_DECLARE_FIELD(lhs_end)
        .set_default(dmlc::optional<int>())
        .describe("Defaults to None. "
                  "The ending index along which the lhs dimensions are to be "
                  "reshaped. Supports negative indices.");
    DMLC_DECLARE_FIELD(rhs_begin).set_default(0).describe(
        "Defaults to 0. "
        "The beginning index along which the rhs dimensions are to be used for "
        "reshaping. Supports negative indices.");
    // Adjacent string literals are concatenated verbatim, so the separating
    // space after "None." has to be part of the literal itself.
    DMLC_DECLARE_FIELD(rhs_end)
        .set_default(dmlc::optional<int>())
        .describe("Defaults to None. "
                  "The ending index along which the rhs dimensions are to be "
                  "used for reshaping. Supports negative indices.");
  }
};
|
||
/*! \brief Unary compute */ | ||
#define MXNET_OPERATOR_REGISTER_UNARY(__name$) \ | ||
NNVM_REGISTER_OP(__name$) \ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -350,10 +350,108 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) | |
.add_argument("lhs", "NDArray-or-Symbol", "First input.") | ||
.add_argument("rhs", "NDArray-or-Symbol", "Second input."); | ||
|
||
void ReshapeLikeRangeCanonicalize(int ndims, const char *side, int begin, | ||
const dmlc::optional<int> &end, int *cbegin, | ||
int *cend) { | ||
*cbegin = begin; | ||
if (*cbegin < 0) | ||
*cbegin += ndims; | ||
|
||
if (!static_cast<bool>(end)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can use has_value for better readability. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, that is cleaner. Fixed. Btw: the reason I used this casting method because it was done here. Should this be changed as well? |
||
*cend = ndims; | ||
} else { | ||
*cend = end.value(); | ||
if (*cend < 0) { | ||
*cend += ndims; | ||
} | ||
} | ||
CHECK(*cend <= ndims) << "Invalid end for " << side << "_end=" << end | ||
<< " as dimension number is " << ndims; | ||
CHECK((*cbegin < *cend)) << "Invalid begin, end, get " << side | ||
<< "_begin=" << begin << ", " << side | ||
<< "_end=" << end; | ||
|
||
CHECK(*cend >= 0) << "Invalid end for " << side << "_end=" << end; | ||
CHECK(*cbegin >= 0) << "Invalid begin for " << side << "_begin=" << begin; | ||
} | ||
|
||
void GetReshapeLikeParams(const ReshapeLikeParam ¶m, const TShape &lshape, | ||
const TShape &rshape, int *lhs_begin, int *lhs_end, | ||
int *rhs_begin, int *rhs_end) { | ||
// LHS params | ||
ReshapeLikeRangeCanonicalize(lshape.ndim(), "lhs", param.lhs_begin, | ||
param.lhs_end, lhs_begin, lhs_end); | ||
// RHS params | ||
ReshapeLikeRangeCanonicalize(rshape.ndim(), "rhs", param.rhs_begin, | ||
param.rhs_end, rhs_begin, rhs_end); | ||
} | ||
|
||
/*!
 * \brief Shape inference for reshape_like.
 *
 * The output shape is assembled from three pieces:
 *   1. the lhs dimensions before lhs_begin,
 *   2. the rhs dimensions in [rhs_begin, rhs_end),
 *   3. the lhs dimensions from lhs_end onwards.
 * The total number of elements must equal that of lhs.
 *
 * \param attrs     node attributes carrying the parsed ReshapeLikeParam
 * \param in_attrs  input shapes: [0] is lhs, [1] is rhs
 * \param out_attrs output shapes: [0] receives the inferred shape
 * \return true once the output shape has been assigned, false when an input
 *         shape is not available yet
 */
bool ReshapeLikeShapeCompute(const nnvm::NodeAttrs &attrs,
                             std::vector<TShape> *in_attrs,
                             std::vector<TShape> *out_attrs) {
  const ReshapeLikeParam &param = nnvm::get<ReshapeLikeParam>(attrs.parsed);
  const TShape &lshape = (*in_attrs)[0];
  const TShape &rshape = (*in_attrs)[1];

  // A zero-dimensional input cannot hold a non-empty [begin, end) range, so
  // canonicalizing it would abort on the begin < end CHECK. Such a shape is
  // treated here as "not inferred yet" and inference is deferred.
  // NOTE(review): confirm that ndim() == 0 always means "unknown" for TShape.
  if (lshape.ndim() == 0 || rshape.ndim() == 0) {
    return false;
  }

  int lhs_begin, lhs_end, rhs_begin, rhs_end;
  GetReshapeLikeParams(param, lshape, rshape, &lhs_begin, &lhs_end, &rhs_begin,
                       &rhs_end);

  const int lhsrank = static_cast<int>(lshape.ndim());
  const int orank = lhsrank + (rhs_end - rhs_begin) - (lhs_end - lhs_begin);
  TShape oshape(orank);

  int opos = 0;
  // Leading lhs dimensions that are left untouched.
  for (int i = 0; i < lhs_begin; ++i) {
    oshape[opos++] = lshape[i];
  }
  // Dimensions taken over from the selected rhs range.
  for (int i = rhs_begin; i < rhs_end; ++i) {
    oshape[opos++] = rshape[i];
  }
  // Trailing lhs dimensions that are left untouched.
  for (int i = lhs_end; i < lhsrank; ++i) {
    oshape[opos++] = lshape[i];
  }

  CHECK_EQ(lshape.Size(), oshape.Size())
      << "Cannot reshape lhs with shape " << lshape << " to new "
      << "shape " << oshape << " because they have different "
      << "size.";
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
  return true;
}
|
||
DMLC_REGISTER_PARAMETER(ReshapeLikeParam); | ||
NNVM_REGISTER_OP(reshape_like) | ||
.describe("Reshape lhs to have the same shape as rhs.") | ||
.describe(R"code(Reshape `lhs` to have the same shape as `rhs`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would alter this line to say something more along the line: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, much better. Changed. |
||
|
||
Returns a **view** of the `lhs` array with a new shape without altering any data. | ||
|
||
Example:: | ||
|
||
x = [1, 2, 3, 4, 5, 6] | ||
y = [[0, -4], [3, 2], [2, 2]] | ||
reshape_like(x, y) = [[1, 2], [3, 4], [5, 6]] | ||
|
||
More precise control over how dimensions are inherited is achieved by specifying \ | ||
slices over the `lhs` and `rhs` array dimensions. Only the sliced `lhs` dimensions \ | ||
are reshaped to the `rhs` sliced dimensions, with the non-sliced `lhs` dimensions staying the same. | ||
|
||
Examples:: | ||
|
||
- lhs shape = (30,7), rhs shape = (15,2,4), lhs_begin=0, lhs_end=1, rhs_begin=0, rhs_end=2, output shape = (15,2,7) | ||
- lhs shape = (3, 5), rhs shape = (1,15,4), lhs_begin=0, lhs_end=2, rhs_begin=1, rhs_end=2, output shape = (15) | ||
|
||
Negative indices are supported, and `None` can be used for either `lhs_end` or `rhs_end` to indicate the end of the range. | ||
|
||
Example:: | ||
|
||
- lhs shape = (30, 12), rhs shape = (4, 2, 2, 3), lhs_begin=-1, lhs_end=None, rhs_begin=1, rhs_end=None, output shape = (30, 2, 2, 3) | ||
|
||
)code" ADD_FILELINE) | ||
.set_num_inputs(2) | ||
.set_attr_parser(ParamParser<ReshapeLikeParam>) | ||
.set_attr<nnvm::FListInputNames>("FListInputNames", | ||
[](const NodeAttrs& attrs) { return std::vector<std::string>{"lhs", "rhs"}; }) | ||
.set_attr<nnvm::FInplaceOption>( | ||
|
@@ -365,19 +463,7 @@ NNVM_REGISTER_OP(reshape_like) | |
.set_attr<nnvm::FIgnoreInputs>("FIgnoreInputs", | ||
[](const NodeAttrs& attrs) { return std::vector<uint32_t>(1, 1); }) | ||
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>) | ||
.set_attr<nnvm::FInferShape>("FInferShape", | ||
[](const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape> *in_attrs, | ||
std::vector<TShape> *out_attrs) { | ||
if ((*in_attrs)[0].ndim()) { | ||
CHECK_EQ((*in_attrs)[0].Size(), (*in_attrs)[1].Size()) | ||
<< "Cannot reshape lhs with shape " << (*in_attrs)[0] << "to rhs " | ||
<< "with shape " << (*in_attrs)[1] << " because they have different " | ||
<< "size."; | ||
} | ||
SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[1]); | ||
return true; | ||
}) | ||
.set_attr<nnvm::FInferShape>("FInferShape", ReshapeLikeShapeCompute) | ||
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) | ||
.set_attr<nnvm::FGradient>( | ||
"FGradient", [](const nnvm::NodePtr& n, | ||
|
@@ -438,7 +524,8 @@ Example:: | |
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); | ||
return out_attrs->at(0) != -1; | ||
}) | ||
.add_argument("data", "NDArray-or-Symbol", "Input Array."); | ||
.add_argument("data", "NDArray-or-Symbol", "Input Array.") | ||
.add_arguments(ReshapeLikeParam::__FIELDS__()); | ||
|
||
void SizeComputeCPU(const nnvm::NodeAttrs& attrs, | ||
const OpContext& ctx, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if these two parameters should be made optional too, given that reshape_like op has been around for 10 months.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are optional in the sense that they have default values (`DMLC_DECLARE_FIELD(lhs_begin).set_default(0)`) that match the old behaviour if not explicitly specified, so backward compatibility is not broken. I thought `dmlc::optional<int>` was simply for the case where you had to handle the `None` value, which is not necessary to support for `lhs_begin` and `rhs_begin`. Note that the same is done for `slice_axis`: `begin` is an `int` and `end` is a `dmlc::optional<int>`. Or am I misunderstanding something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Considering the case where the parameters may be in some serialized format, it may be necessary to support null values to ensure compatibility there too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, if it's necessary, I will change it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@szha Can I clarify something? Maybe I am misunderstanding. Are you saying that it is required to make new parameters on existing layers optional<...> for backward compatibility reasons, even if they have defaults?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@taliesinb I think it doesn't hurt in this case and was suggesting that we lean on the safer side. There might be cases where the default value filling on the frontend fails, such as deserializing a graph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@szha ok. so it sounds like there isn't a particular policy at the moment about how to add new parameters to existing layers. do you mind if i ask on the mailing list about what such a policy should be, just for the next time this happens?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
of course not. good idea :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@szha: I've changed `lhs_begin` and `rhs_begin` to use optional (and added tests for this case). Can we merge?