-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support search gan model. 1.add pixel_unshuffle support &2.enable fil…
…l_constant calc offline on arm and opencl & 3.enable reshape_calc_offline_pass on arm and opencl (#10537) * support search gan model. 1. add pixel_unshuffle support 2. enable fill_constant calc offline on arm and opencl 3. enable reshape_calc_offline_pass on arm and opencl 4. use chinese comments 5. add test for new kernel. test=develop * support search gan model. 1. add pixel_unshuffle support 2. enable fill_constant calc offline on arm and opencl 3. enable reshape_calc_offline_pass on arm and opencl 4. use chinese comments 5. add test for new kernel. test=develop * support search gan model. 1. add pixel_unshuffle support 2. enable fill_constant calc offline on arm and opencl 3. enable reshape_calc_offline_pass on arm and opencl 4. use chinese comments 5. add test for new kernel. 6. fix metal pre-commit test=develop
- Loading branch information
1 parent
a903554
commit 1d0284e
Showing
12 changed files
with
391 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "lite/kernels/host/pixel_unshuffle_compute.h" | ||
|
||
namespace paddle { | ||
namespace lite { | ||
namespace kernels { | ||
namespace host { | ||
|
||
void PixelUnShuffleCompute::Run() { | ||
auto& param = Param<operators::PixelUnShuffleParam>(); | ||
|
||
const float* x_data = param.x->data<float>(); | ||
float* output_data = param.output->mutable_data<float>(); | ||
int downscale_factor = param.downscale_factor; | ||
|
||
int batch_size = param.x->dims()[0]; | ||
int in_channels = param.x->dims()[1]; | ||
int height = param.x->dims()[2]; | ||
int width = param.x->dims()[3]; | ||
int out_channels = param.output->dims()[1]; | ||
int out_height = param.output->dims()[2]; | ||
int out_width = param.output->dims()[3]; | ||
|
||
for (int b = 0; b < batch_size; ++b) { | ||
for (int c = 0; c < in_channels; ++c) { | ||
for (int y = 0; y < height; ++y) { | ||
for (int x = 0; x < width; ++x) { | ||
int out_c = c * downscale_factor * downscale_factor + | ||
(y % downscale_factor) * downscale_factor + | ||
(x % downscale_factor); | ||
int out_y = y / downscale_factor; | ||
int out_x = x / downscale_factor; | ||
int in_index = ((b * in_channels + c) * height + y) * width + x; | ||
int out_index = | ||
((b * out_channels + out_c) * out_height + out_y) * out_width + | ||
out_x; | ||
output_data[out_index] = x_data[in_index]; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
} // namespace host | ||
} // namespace kernels | ||
} // namespace lite | ||
} // namespace paddle | ||
|
||
REGISTER_LITE_KERNEL(pixel_unshuffle, | ||
kHost, | ||
kFloat, | ||
kNCHW, | ||
paddle::lite::kernels::host::PixelUnShuffleCompute, | ||
def) | ||
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost))}) | ||
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kHost))}) | ||
.Finalize(); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
#include <algorithm> | ||
#include "lite/core/kernel.h" | ||
#include "lite/core/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace lite { | ||
namespace kernels { | ||
namespace host { | ||
|
||
class PixelUnShuffleCompute | ||
: public KernelLite<TARGET(kHost), PRECISION(kFloat)> { | ||
public: | ||
using param_t = operators::PixelUnShuffleParam; | ||
|
||
void Run() override; | ||
|
||
virtual ~PixelUnShuffleCompute() = default; | ||
}; | ||
|
||
} // namespace host | ||
} // namespace kernels | ||
} // namespace lite | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "lite/operators/pixel_unshuffle_op.h" | ||
#include "lite/core/op_registry.h" | ||
|
||
namespace paddle { | ||
namespace lite { | ||
namespace operators { | ||
|
||
bool PixelUnShuffleOpLite::CheckShape() const { | ||
CHECK_OR_FALSE(param_.x); | ||
CHECK_OR_FALSE(param_.output); | ||
CHECK_OR_FALSE(param_.downscale_factor > 0); | ||
|
||
const auto x_dims = param_.x->dims(); | ||
const auto downscale_factor = param_.downscale_factor; | ||
|
||
// check input tensor dims size | ||
CHECK_EQ_OR_FALSE(x_dims.size(), 4); | ||
|
||
// check if the height and width can be divided by downscale_factor | ||
CHECK_EQ_OR_FALSE(x_dims[2] % downscale_factor, 0); | ||
CHECK_EQ_OR_FALSE(x_dims[3] % downscale_factor, 0); | ||
|
||
return true; | ||
} | ||
|
||
bool PixelUnShuffleOpLite::InferShapeImpl() const { | ||
const auto x_dims = param_.x->dims(); | ||
const auto downscale_factor = param_.downscale_factor; | ||
|
||
// input tensor dims | ||
int N = x_dims[0]; | ||
int C = x_dims[1]; | ||
int H = x_dims[2]; | ||
int W = x_dims[3]; | ||
|
||
// output tensor dims | ||
int out_C = C * (downscale_factor * downscale_factor); | ||
int out_H = H / downscale_factor; | ||
int out_W = W / downscale_factor; | ||
|
||
// make sure the output height and width can be divided by downscale_factor | ||
if (H % downscale_factor != 0 || W % downscale_factor != 0) { | ||
return false; | ||
} | ||
|
||
DDim output_dims({N, out_C, out_H, out_W}); | ||
param_.output->Resize(output_dims); | ||
return true; | ||
} | ||
|
||
bool PixelUnShuffleOpLite::AttachImpl(const cpp::OpDesc& opdesc, | ||
lite::Scope* scope) { | ||
auto input = opdesc.Input("X").front(); | ||
auto out = opdesc.Output("Out").front(); | ||
|
||
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>(); | ||
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>(); | ||
|
||
if (opdesc.HasAttr("downscale_factor")) { | ||
param_.downscale_factor = opdesc.GetAttr<int>("downscale_factor"); | ||
} | ||
|
||
return true; | ||
} | ||
|
||
} // namespace operators | ||
} // namespace lite | ||
} // namespace paddle | ||
|
||
REGISTER_LITE_OP(pixel_unshuffle, | ||
paddle::lite::operators::PixelUnShuffleOpLite); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <string> | ||
#include <vector> | ||
#include "lite/core/op_lite.h" | ||
|
||
namespace paddle { | ||
namespace lite { | ||
namespace operators { | ||
|
||
class PixelUnShuffleOpLite : public OpLite { | ||
public: | ||
PixelUnShuffleOpLite() {} | ||
explicit PixelUnShuffleOpLite(const std::string &op_type) : OpLite(op_type) {} | ||
|
||
bool CheckShape() const override; | ||
|
||
bool InferShapeImpl() const override; | ||
|
||
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override; | ||
|
||
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } | ||
std::string DebugString() const override { return "pixel_unshuffle"; } | ||
|
||
#ifdef LITE_WITH_PROFILE | ||
void GetOpRuntimeInfo(paddle::lite::profile::OpCharacter *ch) { | ||
auto input_dims = param_.x->dims(); | ||
auto output_dims = param_.output->dims(); | ||
ch->input_shape = ch->DimToStr(input_dims); | ||
ch->output_shape = ch->DimToStr(output_dims); | ||
ch->remark = "downscale_factor" + std::to_string(param_.downscale_factor); | ||
|
||
ch->macs = 1; | ||
} | ||
#endif | ||
|
||
private: | ||
mutable PixelUnShuffleParam param_; | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace lite | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.