diff --git a/paddle/fluid/operators/pixel_unshuffle_op.cc b/paddle/fluid/operators/pixel_unshuffle_op.cc new file mode 100644 index 00000000000000..8464940ac1e031 --- /dev/null +++ b/paddle/fluid/operators/pixel_unshuffle_op.cc @@ -0,0 +1,142 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + +namespace paddle { +namespace operators { + +class PixelUnshuffleOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; +}; + +class PixelUnshuffleOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), " + "the input feature data of PixelUnshuffleOp, the layout is [N, C, " + "H, W] or [N, H, W, C]."); + AddOutput("Out", + "(Tensor, default Tensor), the output of " + "PixelUnshuffleOp. The layout is [N, C*factor^2, H/factor, " + "W/factor] or [N, H/factor, W/factor, C*factor^2]."); + AddAttr("downscale_factor", + "the factor to descrease spatial resolution by.") + .SetDefault(1) + .AddCustomChecker([](const int& downscale_factor) { + PADDLE_ENFORCE_GE(downscale_factor, 1, + platform::errors::InvalidArgument( + "downscale_factor should be larger than 0.")); + }); + AddAttr( + "data_format", + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\", Specify the data format of the input data.") + .SetDefault("NCHW"); + + AddComment(R"DOC( + Pixel Unshuffle operator + This operator rearranges elements in a tensor of shape :math:`(*, C , H \times r, W \times r)` + to a tensor of shape :math:`( C \times r^2, H, W)`, where r is downscale factor. + + This is taken as the inverse operator of Pixel Shuffle . + + Please refer to the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient + Sub-Pixel Convolutional Neural Network `_ + by Shi et. al (2016) for more details. + + )DOC"); + } +}; + +template +class PixelUnshuffleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("pixel_unshuffle_grad"); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetAttrMap(this->Attrs()); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +class PixelUnshuffleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Out")), true, + platform::errors::NotFound("Input(Out@Grad) should not be null")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput(framework::GradVarName("X")), true, + platform::errors::NotFound("Output(X@Grad) should not be null")); + + auto do_dims = ctx->GetInputDim(framework::GradVarName("Out")); + PADDLE_ENFORCE_EQ(do_dims.size(), 4, + platform::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + do_dims.size())); + + auto downscale_factor = ctx->Attrs().Get("downscale_factor"); + + const std::string data_format = + ctx->Attrs().Get("data_format"); + const bool channel_last = (data_format == "NHWC"); + + auto dx_dims = do_dims; + dx_dims[0] = do_dims[0]; + + if (!channel_last) { + dx_dims[1] = do_dims[1] / (downscale_factor * downscale_factor); + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] * downscale_factor; + } else { + dx_dims[1] = do_dims[1] * downscale_factor; + dx_dims[2] = do_dims[2] * downscale_factor; + dx_dims[3] = do_dims[3] / (downscale_factor * downscale_factor); + } + ctx->SetOutputDim(framework::GradVarName("X"), dx_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +DECLARE_INFER_SHAPE_FUNCTOR(pixel_unshuffle, PixelUnshuffleInferShapeFunctor, + PD_INFER_META(phi::PixelUnshuffleInferMeta)); + +REGISTER_OPERATOR(pixel_unshuffle, ops::PixelUnshuffleOp, + ops::PixelUnshuffleOpMaker, + ops::PixelUnshuffleGradMaker, + ops::PixelUnshuffleGradMaker, + PixelUnshuffleInferShapeFunctor); + +REGISTER_OPERATOR(pixel_unshuffle_grad, ops::PixelUnshuffleGradOp); + +REGISTER_OP_VERSION(pixel_unshuffle) + .AddCheckpoint( + R"ROC( + Compatible upgrade of pixel_unshuffle, add a new attribute [data_format])ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "data_format", "Specify the data format of the input data", true)); diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 199029a3a094aa..f6ade7884a7a52 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1204,6 +1204,71 @@ void PixelShuffleInferMeta(const MetaTensor& x, out->set_dims(output_dims); } +void PixelUnshuffleInferMeta(const MetaTensor& x, + int downscale_factor, + const std::string& data_format, + MetaTensor* out) { + auto input_dims = x.dims(); + PADDLE_ENFORCE_EQ(input_dims.size(), + 4, + phi::errors::InvalidArgument( + "Input should be a 4-D tensor of format [N, C, H, W] " + "or [N, H, W, C], but got %u.", + input_dims.size())); + + const bool channel_last = (data_format == "NHWC"); + + if (!channel_last) { + PADDLE_ENFORCE_EQ( + input_dims[2] % downscale_factor, + 0, + phi::errors::InvalidArgument( + "The square of downscale_factor[%u] should divide the " + "height[%u]", + downscale_factor, + input_dims[2])); + PADDLE_ENFORCE_EQ( + input_dims[3] % downscale_factor, + 0, + phi::errors::InvalidArgument( + "The square of downscale_factor[%u] should divide the " + "height[%u]", + downscale_factor, + input_dims[3])); + } else { + PADDLE_ENFORCE_EQ( + input_dims[1] % downscale_factor, + 0, + phi::errors::InvalidArgument( + "The square of downscale_factor[%u] should divide the " + "width[%u]", + downscale_factor, + input_dims[1])); + PADDLE_ENFORCE_EQ( + input_dims[2] % downscale_factor, + 0, + phi::errors::InvalidArgument( + "The square of downscale_factor[%u] should divide the " + "width[%u]", + downscale_factor, + input_dims[2])); + } + + auto output_dims = input_dims; + output_dims[0] = input_dims[0]; + if (!channel_last) { + output_dims[1] = input_dims[1] * (downscale_factor * downscale_factor); + output_dims[2] = input_dims[2] / downscale_factor; + output_dims[3] = input_dims[3] / downscale_factor; + } else { + output_dims[1] = input_dims[1] / downscale_factor; + output_dims[2] = input_dims[2] / downscale_factor; + output_dims[3] = input_dims[3] * (downscale_factor * downscale_factor); + } + out->set_dtype(x.dtype()); + out->set_dims(output_dims); +} + void PNormInferMeta(const MetaTensor& x, float porder, int axis, diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index bae8083ef71916..f55c2f4aaa9086 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -191,6 +191,11 @@ void PixelShuffleInferMeta(const MetaTensor& x, const std::string& data_format, MetaTensor* out); +void PixelUnshuffleInferMeta(const MetaTensor& x, + int downscale_factor, + const std::string& data_format, + MetaTensor* out); + void PNormInferMeta(const MetaTensor& x, float porder, int axis, diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc new file mode 100644 index 00000000000000..ef61fca35957e8 --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_grad_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc new file mode 100644 index 00000000000000..9f4bc747f3209b --- /dev/null +++ b/paddle/phi/kernels/cpu/pixel_unshuffle_kernel.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + CPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu new file mode 100644 index 00000000000000..9cbbc5072aa256 --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_grad_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle_grad, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu new file mode 100644 index 00000000000000..ca2e520ffde10e --- /dev/null +++ b/paddle/phi/kernels/gpu/pixel_unshuffle_kernel.cu @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h" +#include "paddle/phi/kernels/pixel_unshuffle_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(pixel_unshuffle, + GPU, + ALL_LAYOUT, + phi::PixelUnshuffleKernel, + float, + double) {} diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h new file mode 100644 index 00000000000000..5674a9f7080504 --- /dev/null +++ b/paddle/phi/kernels/impl/pixel_unshuffle_grad_kernel_impl.h @@ -0,0 +1,58 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PixelUnshuffleGradKernel(const Context& ctx, + const DenseTensor& out_grad, + int upscale_factor, + const std::string& data_format, + DenseTensor* x_grad) { + auto* dout = &out_grad; + auto* dx = x_grad; + ctx.template Alloc(dx); + int factor = upscale_factor; + bool channel_last = (data_format == "NHWC"); + auto do_dims = dout->dims(); + auto dx_dims = dx->dims(); + + DenseTensor t(*dout); + if (!channel_last) { + t.Resize({do_dims[0], dx_dims[1], factor, factor, do_dims[2], do_dims[3]}); + } else { + t.Resize({do_dims[0], do_dims[1], do_dims[2], dx_dims[3], factor, factor}); + } + + std::vector axis = {0, 1, 4, 2, 5, 3}; + + DenseTensor o(*dx); + if (!channel_last) { + o.Resize({do_dims[0], dx_dims[1], do_dims[2], factor, do_dims[3], factor}); + } else { + o.Resize({do_dims[0], do_dims[1], factor, do_dims[2], factor, dx_dims[3]}); + } + phi::funcs::Transpose trans; + trans(ctx, t, &o, axis); + dx->Resize(dx_dims); +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h new file mode 100644 index 00000000000000..f829c9e28a5f5b --- /dev/null +++ b/paddle/phi/kernels/impl/pixel_unshuffle_kernel_impl.h @@ -0,0 +1,57 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { + +template +void PixelUnshuffleKernel(const Context& ctx, + const DenseTensor& x, + int downscale_factor, + const std::string& data_format, + DenseTensor* out) { + auto* in = &x; + ctx.template Alloc(out); + int factor = downscale_factor; + bool channel_last = (data_format == "NHWC"); + auto in_dims = in->dims(); + auto o_dims = out->dims(); + + DenseTensor t(*in); + if (!channel_last) { + t.Resize({in_dims[0], in_dims[1], o_dims[2], factor, o_dims[3], factor}); + } else { + t.Resize({in_dims[0], o_dims[1], factor, o_dims[2], factor, in_dims[3]}); + } + std::vector axis = {0, 1, 3, 5, 2, 4}; + + DenseTensor o(*out); + if (!channel_last) { + o.Resize({in_dims[0], in_dims[1], factor, factor, o_dims[2], o_dims[3]}); + } else { + o.Resize({in_dims[0], o_dims[1], o_dims[2], in_dims[3], factor, factor}); + } + + phi::funcs::Transpose trans; + trans(ctx, t, &o, axis); + out->Resize(o_dims); +} + +} // namespace phi diff --git a/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h new file mode 100644 index 00000000000000..f62f1f5b4c7b72 --- /dev/null +++ b/paddle/phi/kernels/pixel_unshuffle_grad_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PixelUnshuffleGradKernel(const Context& ctx, + const DenseTensor& out_grad, + int downscale_factor, + const std::string& data_format, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/pixel_unshuffle_kernel.h b/paddle/phi/kernels/pixel_unshuffle_kernel.h new file mode 100644 index 00000000000000..a631223034e96b --- /dev/null +++ b/paddle/phi/kernels/pixel_unshuffle_kernel.h @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void PixelUnshuffleKernel(const Context& ctx, + const DenseTensor& x, + int downscale_factor, + const std::string& data_format, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/pixel_unshuffle_sig.cc b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc new file mode 100644 index 00000000000000..e78b676dd629e2 --- /dev/null +++ b/paddle/phi/ops/compat/pixel_unshuffle_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature PixelUnshuffleOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "pixel_unshuffle", {"X"}, {"downscale_factor", "data_format"}, {"Out"}); +} + +KernelSignature PixelUnshuffleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("pixel_unshuffle_grad", + {GradVarName("Out")}, + {"downscale_factor", "data_format"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle, + phi::PixelUnshuffleOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(pixel_unshuffle_grad, + phi::PixelUnshuffleGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py new file mode 100644 index 00000000000000..8668a2a9d0caa0 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pixel_unshuffle.py @@ -0,0 +1,269 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np + +from op_test import OpTest +import paddle +import paddle.nn.functional as F +import paddle.fluid.core as core +import paddle.fluid as fluid + + +def pixel_unshuffle_np(x, down_factor, data_format="NCHW"): + if data_format == "NCHW": + n, c, h, w = x.shape + new_shape = (n, c, h // down_factor, down_factor, w // down_factor, + down_factor) + # reshape to (num,output_channel,h,downscale_factor,w,downscale_factor) + npresult = np.reshape(x, new_shape) + # transpose to (num,output_channel,downscale_factor,downscale_factor, h, w) + npresult = npresult.transpose(0, 1, 3, 5, 2, 4) + oshape = [ + n, c * (down_factor * down_factor), h // down_factor, + w // down_factor + ] + npresult = np.reshape(npresult, oshape) + return npresult + else: + n, h, w, c = x.shape + new_shape = (n, h // down_factor, down_factor, w // down_factor, + down_factor, c) + # reshape to (num,h,w,downscale_factor,downscale_factor, output_channel) + npresult = np.reshape(x, new_shape) + # transpose to (num,h, w, output_channel,downscale_factor,downscale_factor) + npresult = npresult.transpose(0, 1, 3, 5, 2, 4) + oshape = [ + n, h // down_factor, w // down_factor, + c * (down_factor * down_factor) + ] + npresult = np.reshape(npresult, oshape) + return npresult + + +class TestPixelUnshuffleOp(OpTest): + def setUp(self): + self.op_type = "pixel_unshuffle" + self.init_data_format() + n, c, h, w = 2, 1, 12, 12 + + if self.format == "NCHW": + shape = [n, c, h, w] + if self.format == "NHWC": + shape = [n, h, w, c] + + down_factor = 3 + + x = np.random.random(shape).astype("float64") + npresult = pixel_unshuffle_np(x, down_factor, self.format) + self.inputs = {'X': x} + self.outputs = {'Out': npresult} + self.attrs = { + 'downscale_factor': down_factor, + "data_format": self.format + } + + def init_data_format(self): + self.format = "NCHW" + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestChannelLast(TestPixelUnshuffleOp): + def init_data_format(self): + self.format = "NHWC" + + +class TestPixelUnshuffleAPI(unittest.TestCase): + def setUp(self): + self.x_1_np = np.random.random([2, 1, 12, 12]).astype("float64") + self.x_2_np = np.random.random([2, 12, 12, 1]).astype("float64") + self.out_1_np = pixel_unshuffle_np(self.x_1_np, 3) + self.out_2_np = pixel_unshuffle_np(self.x_2_np, 3, "NHWC") + + def test_static_graph_functional(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=[2, 1, 12, 12], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 12, 12, 1], dtype="float64") + out_1 = F.pixel_unshuffle(x_1, 3) + out_2 = F.pixel_unshuffle(x_2, 3, "NHWC") + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_1_np}, + fetch_list=out_1, + use_prune=True) + + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_2_np}, + fetch_list=out_2, + use_prune=True) + + assert np.allclose(res_1, self.out_1_np) + assert np.allclose(res_2, self.out_2_np) + + # same test between layer and functional in this op. + def test_static_graph_layer(self): + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.enable_static() + x_1 = paddle.fluid.data( + name="x", shape=[2, 1, 12, 12], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 12, 12, 1], dtype="float64") + + # init instance + ps_1 = paddle.nn.PixelUnshuffle(3) + ps_2 = paddle.nn.PixelUnshuffle(3, "NHWC") + out_1 = ps_1(x_1) + out_2 = ps_2(x_2) + out_1_np = pixel_unshuffle_np(self.x_1_np, 3) + out_2_np = pixel_unshuffle_np(self.x_2_np, 3, "NHWC") + + exe = paddle.static.Executor(place=place) + res_1 = exe.run(fluid.default_main_program(), + feed={"x": self.x_1_np}, + fetch_list=out_1, + use_prune=True) + + res_2 = exe.run(fluid.default_main_program(), + feed={"x2": self.x_2_np}, + fetch_list=out_2, + use_prune=True) + + assert np.allclose(res_1, out_1_np) + assert np.allclose(res_2, out_2_np) + + def run_dygraph(self, down_factor, data_format): + n, c, h, w = 2, 1, 12, 12 + + if data_format == "NCHW": + shape = [n, c, h, w] + if data_format == "NHWC": + shape = [n, h, w, c] + + x = np.random.random(shape).astype("float64") + npresult = pixel_unshuffle_np(x, down_factor, data_format) + + for use_cuda in ([False, True] + if core.is_compiled_with_cuda() else [False]): + place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() + + paddle.disable_static(place=place) + + pixel_unshuffle = paddle.nn.PixelUnshuffle( + down_factor, data_format=data_format) + result = pixel_unshuffle(paddle.to_tensor(x)) + self.assertTrue(np.allclose(result.numpy(), npresult)) + result_functional = F.pixel_unshuffle( + paddle.to_tensor(x), 3, data_format) + self.assertTrue(np.allclose(result_functional.numpy(), npresult)) + + def test_dygraph1(self): + + self.run_dygraph(3, "NCHW") + + def test_dygraph2(self): + self.run_dygraph(3, "NHWC") + + +class TestPixelUnshuffleError(unittest.TestCase): + def test_error_functional(self): + def error_downscale_factor_negative(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), -1) + + def error_downscale_factor_zero(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 0) + + def error_downscale_factor_float(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 3.33) + + def error_downscale_factor_divide(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle(paddle.to_tensor(x), 5) + + self.assertRaises(ValueError, error_downscale_factor_negative) + self.assertRaises(ValueError, error_downscale_factor_zero) + self.assertRaises(TypeError, error_downscale_factor_float) + self.assertRaises(ValueError, error_downscale_factor_divide) + + def error_data_format(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = F.pixel_unshuffle( + paddle.to_tensor(x), 3, "WOW") + + self.assertRaises(ValueError, error_data_format) + + def test_error_layer(self): + def error_downscale_factor_negative(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = paddle.nn.PixelUnshuffle(-1) + pixel_unshuffle(x) + + def error_downscale_factor_zero(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = paddle.nn.PixelUnshuffle(0) + pixel_unshuffle(x) + + def error_downscale_factor_float(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = paddle.nn.PixelUnshuffle(3.33) + pixel_unshuffle(x) + + def error_downscale_factor_divide(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + pixel_unshuffle = paddle.nn.PixelUnshuffle(5) + pixel_unshuffle(x) + + self.assertRaises(ValueError, error_downscale_factor_negative) + self.assertRaises(ValueError, error_downscale_factor_zero) + self.assertRaises(TypeError, error_downscale_factor_float) + + def error_data_format_layer(): + with paddle.fluid.dygraph.guard(): + x = np.random.random([2, 1, 12, 12]).astype("float64") + ps = paddle.nn.PixelUnshuffle(3, "MEOW") + + self.assertRaises(ValueError, error_data_format_layer) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py index c0820e140268b6..3b0eafa23cde9d 100644 --- a/python/paddle/nn/__init__.py +++ b/python/paddle/nn/__init__.py @@ -138,6 +138,7 @@ from .layer.vision import PixelShuffle # noqa: F401 from .layer.container import LayerDict # noqa: F401 +from .layer.vision import PixelUnshuffle # noqa: F401pix from .utils.spectral_norm_hook import spectral_norm @@ -306,4 +307,5 @@ def weight_norm(*args): 'MaxUnPool2D', 'MaxUnPool3D', 'HingeEmbeddingLoss', + 'PixelUnshuffle', ] diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py index a24afc45a59951..b756b4c236bb26 100644 --- a/python/paddle/nn/functional/__init__.py +++ b/python/paddle/nn/functional/__init__.py @@ -118,6 +118,7 @@ from .input import embedding # noqa: F401 from ...fluid.layers import gather_tree # noqa: F401 from ...fluid.layers import temporal_shift # noqa: F401 +from .vision import pixel_unshuffle # noqa: F401 from .sparse_attention import sparse_attention @@ -224,4 +225,5 @@ 'class_center_sample', 'sparse_attention', 'fold', + 'pixel_unshuffle', ] diff --git a/python/paddle/nn/functional/vision.py b/python/paddle/nn/functional/vision.py index 43c7757a8777ba..56a177f9e0fbde 100644 --- a/python/paddle/nn/functional/vision.py +++ b/python/paddle/nn/functional/vision.py @@ -344,3 +344,88 @@ def pixel_shuffle(x, upscale_factor, data_format="NCHW", name=None): attrs={"upscale_factor": upscale_factor, "data_format": data_format}) return out + + +def pixel_unshuffle(x, downscale_factor, data_format="NCHW", name=None): + """ + PixelUnshuffle Layer + + Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements + in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape + :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et. al (2016) for more details. + + Args: + downscale_factor (int): factor to decrease spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \times \text{downscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \div \text{downscale\_factor} + + .. math:: + W_{out} = W_{in} \div \text{downscale\_factor} + + + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + + x = np.random.randn(2, 1, 12, 12).astype(np.float32) + x_var = paddle.to_tensor(x) + pixel_unshuffle = nn.PixelUnshuffle(3) + out_var = pixel_unshuffle(x_var) + out = out_var.numpy() + print(out.shape) + # (2, 9, 4, 4) + + """ + if not isinstance(downscale_factor, int): + raise TypeError("downscale factor must be int type") + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'." + "But recevie Attr(data_format): {} ".format( + data_format)) + if downscale_factor < 1: + raise ValueError("downscale factor should not less than 1." + "But recevie downscale factor: {} ".format( + downscale_factor)) + + _, _, h, w = x.shape + if data_format == "NHWC": + _, h, w, _ = x.shape + + if h % downscale_factor != 0 or w % downscale_factor != 0: + raise ValueError( + "Both height and width should be divided by downscale_factor." + "But recevie downscale factor: {} ".format(downscale_factor)) + + if in_dynamic_mode(): + return _C_ops.pixel_unshuffle(x, "downscale_factor", downscale_factor, + "data_format", data_format) + + helper = LayerHelper("pixel_unshuffle", **locals()) + check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'pixel_unshuffle') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type="pixel_unshuffle", + inputs={"X": x}, + outputs={"Out": out}, + attrs={ + "downscale_factor": downscale_factor, + "data_format": data_format + }) + return out diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py index 2b505080656050..799b61a927953e 100644 --- a/python/paddle/nn/layer/__init__.py +++ b/python/paddle/nn/layer/__init__.py @@ -89,5 +89,6 @@ from .vision import PixelShuffle # noqa: F401 from .distance import PairwiseDistance # noqa: F401 from .container import LayerDict # noqa: F401 +from .vision import PixelUnshuffle # noqa: F401 __all__ = [] diff --git a/python/paddle/nn/layer/vision.py b/python/paddle/nn/layer/vision.py index 0531afb4eeeeb9..3fdc1473fea847 100644 --- a/python/paddle/nn/layer/vision.py +++ b/python/paddle/nn/layer/vision.py @@ -87,3 +87,86 @@ def extra_repr(self): if self._name is not None: main_str += ', name={}'.format(self._name) return main_str + + +class PixelUnshuffle(Layer): + """ + + PixelUnshuffle Layer + + Reverses the :class:`~torch.nn.PixelShuffle` operation by rearranging elements + in a tensor of shape :math:`(*, C, H \times r, W \times r)` to a tensor of shape + :math:`(*, C \times r^2, H, W)`, where r is a downscale factor. + + See the paper: + `Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`_ + by Shi et. al (2016) for more details. + + Args: + downscale_factor (int): factor to decrease spatial resolution by + + Shape: + - Input: :math:`(*, C_{in}, H_{in}, W_{in})`, where * is zero or more batch dimensions + - Output: :math:`(*, C_{out}, H_{out}, W_{out})`, where + + .. math:: + C_{out} = C_{in} \times \text{downscale\_factor}^2 + + .. math:: + H_{out} = H_{in} \div \text{downscale\_factor} + + .. math:: + W_{out} = W_{in} \div \text{downscale\_factor} + + + + Examples: + .. code-block:: python + + import paddle + import paddle.nn as nn + import numpy as np + + x = np.random.randn(2, 1, 12, 12).astype(np.float32) + x_var = paddle.to_tensor(x) + pixel_unshuffle = nn.PixelUnshuffle(3) + out_var = pixel_unshuffle(x_var) + out = out_var.numpy() + print(out.shape) + # (2, 9, 4, 4) + + """ + + def __init__(self, downscale_factor, data_format="NCHW", name=None): + super(PixelUnshuffle, self).__init__() + + if not isinstance(downscale_factor, int): + raise TypeError("downscale factor must be int type." + "But recevie downscale factor: {} ".format( + downscale_factor)) + + if downscale_factor < 1: + raise ValueError("downscale factor should not less than 1." + "But recevie downscale factor: {} ".format( + downscale_factor)) + + if data_format not in ["NCHW", "NHWC"]: + raise ValueError("Attr(data_format) should be 'NCHW' or 'NHWC'." + "But recevie Attr(data_format): {} ".format( + data_format)) + + self._downscale_factor = downscale_factor + self._data_format = data_format + self._name = name + + def forward(self, x): + return functional.pixel_unshuffle(x, self._downscale_factor, + self._data_format, self._name) + + def extra_repr(self): + main_str = 'downscale_factor={}'.format(self._downscale_factor) + if self._data_format != 'NCHW': + main_str += ', data_format={}'.format(self._data_format) + if self._name is not None: + main_str += ', name={}'.format(self._name) + return main_str diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 365047f7e8382a..36fcb6780fe117 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -732,4 +732,5 @@ 'test_pull_gpups_sparse_op', 'test_fused_gemm_epilogue_op', 'test_fused_gemm_epilogue_grad_op', + 'test_pixel_unshuffle', ]