Skip to content

Commit

Permalink
support conv_transpose output_size attr test=develop (#2749)
Browse files Browse the repository at this point in the history
* support conv_transpose output_size attr test=develop
  • Loading branch information
jiweibo authored Jan 10, 2020
1 parent 9df558f commit 0c0a8a9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
16 changes: 16 additions & 0 deletions lite/operators/conv_transpose_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ bool ConvTransposeOpLite::InferShape() const {
paddings[i * 2 + 1],
param_.strides[i]));
}
if (!param_.output_size.empty()) {
for (size_t i = 0; i < param_.output_size.size(); ++i) {
CHECK_LT(param_.output_size[i], output_shape[i + 2] + param_.strides[i])
<< "set output_size error, the output_size should less than "
<< output_shape[i + 2] + param_.strides[i] << ", but the value is "
<< param_.output_size[i];
CHECK_GE(param_.output_size[i], output_shape[i + 2])
<< "set output_size error, the output_size should greater than or "
<< "equal to " << output_shape[i + 2] << ", but the value is "
<< param_.output_size[i];
output_shape[i + 2] = param_.output_size[i];
}
}

// Set output dims
param_.output->Resize(lite::DDim(output_shape));
Expand Down Expand Up @@ -157,6 +170,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
if (op_desc.HasAttr("fuse_relu")) {
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
}
if (op_desc.HasAttr("output_size")) {
param_.output_size = op_desc.GetAttr<std::vector<int>>("output_size");
}
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions lite/operators/op_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ struct ConvParam {
ActivationParam activation_param;
// support var_length or not
bool var_length{false};
// only used in conv_transpose.
std::vector<int> output_size;
// for int8
WITH_INT8_CONFIG
};
Expand Down

0 comments on commit 0c0a8a9

Please sign in to comment.