modify Ops to complex template (#33041)
* migrate the conj, real, and imag ops to the complex template

* migrate the dot op to the complex template

* migrate the abs op to the complex template

* add support for complex64 and complex128
MingMingShangTian authored May 25, 2021
1 parent 86ea8dc commit 5fa44c3
Showing 11 changed files with 103 additions and 95 deletions.
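
The change is mechanical across all eleven files: every use of the concrete types paddle::platform::complex64 and paddle::platform::complex128 becomes an instantiation of the single class template paddle::platform::complex<T> at float or double. As a rough illustration of the shape of such a template (a minimal sketch only, not Paddle's actual complex.h; note that in this commit the legacy types evidently still coexist as distinct types during the migration, since the traits in complex_functors.h below test for both spellings):

// Minimal sketch of a templated complex type. Paddle's real complex.h
// additionally carries HOSTDEVICE qualifiers so the same code runs on
// CPU and GPU, as the functors later in this diff suggest.
template <typename T>
struct complex {
  T real;
  T imag;
  complex(T r = T(0), T i = T(0)) : real(r), imag(i) {}
};

complex<float> a(1.0f, 2.0f);  // plays the role of the old complex64
complex<double> b(1.0, 2.0);   // plays the role of the old complex128
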
12 changes: 6 additions & 6 deletions paddle/fluid/operators/abs_op.cc
@@ -164,19 +164,19 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsKernel<paddle::platform::CPUDeviceContext, int>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex64>,
+                   paddle::platform::complex<float>>,
     ops::AbsKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex128>);
+                   paddle::platform::complex<double>>);
 
 REGISTER_OP_CPU_KERNEL(
     abs_grad, ops::AbsGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex64>,
+                       paddle::platform::complex<float>>,
     ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex128>);
+                       paddle::platform::complex<double>>);
 
 REGISTER_OP_CPU_KERNEL(
     abs_grad_grad,
@@ -187,6 +187,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
                              paddle::platform::float16>,
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex64>,
+                             paddle::platform::complex<float>>,
     ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
-                             paddle::platform::complex128>);
+                             paddle::platform::complex<double>>);
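
One subtlety visible in these registrations: abs maps complex inputs to real outputs, so AbsKernel instantiated at complex<float> reads complex values but writes float magnitudes (the Real<T> trait in complex_functors.h below computes exactly this output type, and the grad functors take a const float* dout accordingly). A standalone std::complex analogue of the forward computation, purely for illustration:

#include <complex>
#include <cstdio>

int main() {
  std::complex<float> x(3.0f, 4.0f);
  float magnitude = std::abs(x);  // |3 + 4i| = sqrt(9 + 16) = 5
  std::printf("|x| = %f\n", magnitude);
  return 0;
}
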
12 changes: 6 additions & 6 deletions paddle/fluid/operators/abs_op.cu
@@ -70,23 +70,23 @@ REGISTER_OP_CUDA_KERNEL(
     ops::AbsKernel<plat::CUDADeviceContext, int>,
     ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsKernel<plat::CUDADeviceContext, plat::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
     ops::AbsGradKernel<plat::CUDADeviceContext, double>,
     ops::AbsGradKernel<plat::CUDADeviceContext, int>,
     ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
 
 REGISTER_OP_CUDA_KERNEL(
     abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, double>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
     ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
-    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
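
The abs_grad and abs_grad_grad kernels registered here are backed by the functors in complex_functors.h (changed later in this diff). For a complex input x they compute, in mathematical terms,

    dx = dout * x / |x|    (and dx = 0 where x = 0),

with the double-grad case analogously forming ddx * x / |x|. Because |x| is real while the result must be complex, both dout and |x| are first wrapped in the complex type, which is exactly why each functor constructs complex<float>(dout_[idx]) and complex<float>(abs(x_[idx])) before multiplying and dividing.
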
4 changes: 2 additions & 2 deletions paddle/fluid/operators/conj_op.cc
@@ -78,9 +78,9 @@ REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker,
 
 REGISTER_OP_CPU_KERNEL(
     conj, ops::ConjKernel<paddle::platform::CPUDeviceContext,
-                          paddle::platform::complex64>,
+                          paddle::platform::complex<float>>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex128>,
+                    paddle::platform::complex<double>>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ConjKernel<paddle::platform::CPUDeviceContext, int>,
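
Conjugation is the simplest of the migrated ops: conj(a + bi) = a - bi for complex inputs and the identity for real ones, which is consistent with the kernel also being registered for float, double, int, and int64_t above. A hedged, self-contained sketch of the idea using std::complex rather than Paddle's actual ConjKernel/ConjFunctor:

#include <complex>

// Identity on real types: conjugating a real number changes nothing.
template <typename T>
T conj_value(T x) {
  return x;
}

// Complex overload: negate the imaginary part.
template <typename T>
std::complex<T> conj_value(std::complex<T> x) {
  return std::complex<T>(x.real(), -x.imag());
}
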
7 changes: 3 additions & 4 deletions paddle/fluid/operators/conj_op.cu
@@ -13,15 +13,14 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/conj_op.h"
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     conj, ops::ConjKernel<paddle::platform::CUDADeviceContext,
-                          paddle::platform::complex64>,
+                          paddle::platform::complex<float>>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex128>,
+                    paddle::platform::complex<double>>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, float>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, double>,
     ops::ConjKernel<paddle::platform::CUDADeviceContext, int>,
10 changes: 5 additions & 5 deletions paddle/fluid/operators/dot_op.cc
@@ -33,7 +33,7 @@ class DotOp : public framework::OperatorWithKernel {
                           "Output(Out) of DotOp should not be null."));
 
     auto x_dims = ctx->GetInputDim("X");
-    auto x_rank = (size_t)x_dims.size();
+    auto x_rank = static_cast<size_t>(x_dims.size());
     PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank,
                       platform::errors::PreconditionNotMet(
                           "ShapeError: The dimensions of input tensor X (%s) "
@@ -154,15 +154,15 @@ REGISTER_OP_CPU_KERNEL(
     ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
     ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::DotKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex64>,
+                   paddle::platform::complex<float>>,
     ops::DotKernel<paddle::platform::CPUDeviceContext,
-                   paddle::platform::complex128>);
+                   paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(
     dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex64>,
+                       paddle::platform::complex<float>>,
     ops::DotGradKernel<paddle::platform::CPUDeviceContext,
-                       paddle::platform::complex128>);
+                       paddle::platform::complex<double>>);
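
Once complex<T> supplies operator* and operator+, a dot kernel templated on T needs no complex-specific branch for the forward pass. A minimal sketch of the idea over contiguous 1-D buffers (dot_1d is a hypothetical helper, not Paddle's DotKernel, and it deliberately takes no position on whether a complex dot should conjugate one operand):

#include <cstdint>

// Hypothetical helper: plain sum of products, valid for any T with
// value-initialization, operator+=, and operator*.
template <typename T>
T dot_1d(const T* x, const T* y, int64_t n) {
  T acc{};
  for (int64_t i = 0; i < n; ++i) {
    acc += x[i] * y[i];
  }
  return acc;
}
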
20 changes: 11 additions & 9 deletions paddle/fluid/operators/dot_op.cu
@@ -22,12 +22,14 @@ REGISTER_OP_CUDA_KERNEL(
     ops::DotKernel<plat::CUDADeviceContext, double>,
     ops::DotKernel<plat::CUDADeviceContext, int>,
     ops::DotKernel<plat::CUDADeviceContext, int64_t>,
-    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
-    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
-REGISTER_OP_CUDA_KERNEL(
-    dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
-    ops::DotGradKernel<plat::CUDADeviceContext, double>,
-    ops::DotGradKernel<plat::CUDADeviceContext, int>,
-    ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
-    ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
-    ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
+    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<float>>,
+    ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<double>>);
+REGISTER_OP_CUDA_KERNEL(dot_grad,
+                        ops::DotGradKernel<plat::CUDADeviceContext, float>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, double>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, int>,
+                        ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
+                        ops::DotGradKernel<plat::CUDADeviceContext,
+                                           paddle::platform::complex<float>>,
+                        ops::DotGradKernel<plat::CUDADeviceContext,
+                                           paddle::platform::complex<double>>);
8 changes: 4 additions & 4 deletions paddle/fluid/operators/imag_op.cc
@@ -96,11 +96,11 @@ REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
 REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
 
 REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel<paddle::platform::CPUDeviceContext,
-                                 paddle::platform::complex64>,
+                                 paddle::platform::complex<float>>,
                        ops::ImagKernel<paddle::platform::CPUDeviceContext,
-                                 paddle::platform::complex128>);
+                                 paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(imag_grad,
                        ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex64>,
+                                           paddle::platform::complex<float>>,
                        ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex128>);
+                                           paddle::platform::complex<double>>);
8 changes: 4 additions & 4 deletions paddle/fluid/operators/imag_op.cu
@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(imag,
                         ops::ImagKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
                         ops::ImagKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(imag_grad,
                         ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex64>,
+                                            paddle::platform::complex<float>>,
                         ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex128>);
+                                            paddle::platform::complex<double>>);
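
Unlike conj, the imag and real kernels (this file and real_op.cc/real_op.cu below) project complex<T> down to T, and this commit registers them only for the two complex instantiations. A simplified element-wise sketch with hypothetical helpers (not Paddle's ImagFunctor/RealFunctor):

#include <complex>
#include <cstdint>

// Hypothetical helpers: copy one component of each complex element.
template <typename T>
void take_real(const std::complex<T>* in, T* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) out[i] = in[i].real();
}

template <typename T>
void take_imag(const std::complex<T>* in, T* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) out[i] = in[i].imag();
}
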
101 changes: 54 additions & 47 deletions paddle/fluid/operators/math/complex_functors.h
@@ -16,8 +16,7 @@ limitations under the License. */
 
 #include <type_traits>
 
-#include "paddle/fluid/platform/complex128.h"
-#include "paddle/fluid/platform/complex64.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/hostdevice.h"
 
 namespace paddle {
@@ -66,7 +65,10 @@ using select_t = typename select<Head, Tail...>::type;
 template <typename T>
 using Real =
     select_t<cond<std::is_same<T, platform::complex64>::value, float>,
-             cond<std::is_same<T, platform::complex128>::value, double>, T>;
+             cond<std::is_same<T, platform::complex128>::value, double>,
+             cond<std::is_same<T, platform::complex<float>>::value, float>,
+             cond<std::is_same<T, platform::complex<double>>::value, double>,
+             T>;
 
 template <typename T, typename RealT>
 using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
@@ -76,14 +78,18 @@ template <typename T, typename RealT>
 using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
 
 template <typename T>
-using EnableComplex =
-    typename std::enable_if<std::is_same<T, platform::complex64>::value ||
-                            std::is_same<T, platform::complex128>::value>::type;
+using EnableComplex = typename std::enable_if<
+    std::is_same<T, platform::complex64>::value ||
+    std::is_same<T, platform::complex128>::value ||
+    std::is_same<T, platform::complex<float>>::value ||
+    std::is_same<T, platform::complex<double>>::value>::type;
 
 template <typename T>
 using DisableComplex = typename std::enable_if<
     !std::is_same<T, platform::complex64>::value &&
-    !std::is_same<T, platform::complex128>::value>::type;
+    !std::is_same<T, platform::complex128>::value &&
+    !std::is_same<T, platform::complex<float>>::value &&
+    !std::is_same<T, platform::complex<double>>::value>::type;
 
 template <typename T, typename Enable = void>
 struct RealFunctor;
Expand Down Expand Up @@ -173,44 +179,45 @@ struct AbsGradFunctor {
};

template <>
struct AbsGradFunctor<paddle::platform::complex64> {
AbsGradFunctor(const float* dout, const paddle::platform::complex64* x,
paddle::platform::complex64* output, int64_t numel)
struct AbsGradFunctor<paddle::platform::complex<float>> {
AbsGradFunctor(const float* dout, const paddle::platform::complex<float>* x,
paddle::platform::complex<float>* output, int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
if (x_[idx] == paddle::platform::complex64(0)) {
output_[idx] = paddle::platform::complex64(0);
if (x_[idx] == paddle::platform::complex<float>(0)) {
output_[idx] = paddle::platform::complex<float>(0);
} else {
output_[idx] = paddle::platform::complex64(dout_[idx]) *
(x_[idx] / paddle::platform::complex64(abs(x_[idx])));
output_[idx] = paddle::platform::complex<float>(dout_[idx]) *
(x_[idx] / paddle::platform::complex<float>(abs(x_[idx])));
}
}

const float* dout_;
const paddle::platform::complex64* x_;
paddle::platform::complex64* output_;
const paddle::platform::complex<float>* x_;
paddle::platform::complex<float>* output_;
int64_t numel_;
};

template <>
struct AbsGradFunctor<paddle::platform::complex128> {
AbsGradFunctor(const double* dout, const paddle::platform::complex128* x,
paddle::platform::complex128* output, int64_t numel)
struct AbsGradFunctor<paddle::platform::complex<double>> {
AbsGradFunctor(const double* dout, const paddle::platform::complex<double>* x,
paddle::platform::complex<double>* output, int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}

HOSTDEVICE void operator()(int64_t idx) const {
if (x_[idx] == paddle::platform::complex128(0)) {
output_[idx] = paddle::platform::complex128(0);
if (x_[idx] == paddle::platform::complex<double>(0)) {
output_[idx] = paddle::platform::complex<double>(0);
} else {
output_[idx] = paddle::platform::complex128(dout_[idx]) *
(x_[idx] / paddle::platform::complex128(abs(x_[idx])));
output_[idx] =
paddle::platform::complex<double>(dout_[idx]) *
(x_[idx] / paddle::platform::complex<double>(abs(x_[idx])));
}
}

const double* dout_;
const paddle::platform::complex128* x_;
paddle::platform::complex128* output_;
const paddle::platform::complex<double>* x_;
paddle::platform::complex<double>* output_;
int64_t numel_;
};

@@ -234,46 +241,46 @@ struct AbsGradGradFunctor {
 };
 
 template <>
-struct AbsGradGradFunctor<paddle::platform::complex128> {
-  AbsGradGradFunctor(const paddle::platform::complex128* ddx,
-                     const paddle::platform::complex128* x,
-                     paddle::platform::complex128* output, int64_t numel)
+struct AbsGradGradFunctor<paddle::platform::complex<double>> {
+  AbsGradGradFunctor(const paddle::platform::complex<double>* ddx,
+                     const paddle::platform::complex<double>* x,
+                     paddle::platform::complex<double>* output, int64_t numel)
       : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex128(0)) {
-      output_[idx] = paddle::platform::complex128(0);
+    if (x_[idx] == paddle::platform::complex<double>(0)) {
+      output_[idx] = paddle::platform::complex<double>(0);
     } else {
-      output_[idx] = paddle::platform::complex128(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex128(abs(x_[idx]));
+      output_[idx] = paddle::platform::complex<double>(ddx_[idx]) * x_[idx] /
+                     paddle::platform::complex<double>(abs(x_[idx]));
     }
   }
 
-  const paddle::platform::complex128* ddx_;
-  const paddle::platform::complex128* x_;
-  paddle::platform::complex128* output_;
+  const paddle::platform::complex<double>* ddx_;
+  const paddle::platform::complex<double>* x_;
+  paddle::platform::complex<double>* output_;
   int64_t numel_;
 };
 
 template <>
-struct AbsGradGradFunctor<paddle::platform::complex64> {
-  AbsGradGradFunctor(const paddle::platform::complex64* ddx,
-                     const paddle::platform::complex64* x,
-                     paddle::platform::complex64* output, int64_t numel)
+struct AbsGradGradFunctor<paddle::platform::complex<float>> {
+  AbsGradGradFunctor(const paddle::platform::complex<float>* ddx,
+                     const paddle::platform::complex<float>* x,
+                     paddle::platform::complex<float>* output, int64_t numel)
      : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
 
   HOSTDEVICE void operator()(int64_t idx) const {
-    if (x_[idx] == paddle::platform::complex64(0)) {
-      output_[idx] = paddle::platform::complex64(0);
+    if (x_[idx] == paddle::platform::complex<float>(0)) {
+      output_[idx] = paddle::platform::complex<float>(0);
     } else {
-      output_[idx] = paddle::platform::complex64(ddx_[idx]) * x_[idx] /
-                     paddle::platform::complex64(abs(x_[idx]));
+      output_[idx] = paddle::platform::complex<float>(ddx_[idx]) * x_[idx] /
+                     paddle::platform::complex<float>(abs(x_[idx]));
    }
  }
 
-  const paddle::platform::complex64* ddx_;
-  const paddle::platform::complex64* x_;
-  paddle::platform::complex64* output_;
+  const paddle::platform::complex<float>* ddx_;
+  const paddle::platform::complex<float>* x_;
+  paddle::platform::complex<float>* output_;
   int64_t numel_;
 };
 template <typename T, typename Enable = void>
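
The Real<T> alias extended above is built from a small compile-time type switch: cond pairs a boolean predicate with a candidate type, and select/select_t walk the list and yield the first candidate whose predicate holds, with a bare trailing type as the default. A self-contained reconstruction under those assumptions (simplified; Paddle's actual definitions live earlier in this same header and may differ in detail):

#include <type_traits>

// A predicate/type pair for the type switch.
template <bool B, typename T>
struct cond {
  static constexpr bool value = B;
  using type = T;
};

// Walk the list; the first cond whose predicate holds wins.
template <typename Head, typename... Tail>
struct select {
  using type = std::conditional_t<Head::value, typename Head::type,
                                  typename select<Tail...>::type>;
};

// A lone trailing (non-cond) type is the fallback result.
template <typename T>
struct select<T> {
  using type = T;
};

template <typename... Ts>
using select_t = typename select<Ts...>::type;

// Stand-in for paddle::platform::complex<T>, just to exercise the trait.
template <typename T>
struct complex {};

template <typename T>
using Real = select_t<cond<std::is_same<T, complex<float>>::value, float>,
                      cond<std::is_same<T, complex<double>>::value, double>,
                      T>;

static_assert(std::is_same<Real<complex<float>>, float>::value,
              "complex<float> maps to float");
static_assert(std::is_same<Real<int>, int>::value,
              "non-complex types pass through unchanged");
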
8 changes: 4 additions & 4 deletions paddle/fluid/operators/real_op.cc
@@ -95,11 +95,11 @@ REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
 REGISTER_OPERATOR(real_grad, ops::RealGradOp);
 
 REGISTER_OP_CPU_KERNEL(real, ops::RealKernel<paddle::platform::CPUDeviceContext,
-                                 paddle::platform::complex64>,
+                                 paddle::platform::complex<float>>,
                        ops::RealKernel<paddle::platform::CPUDeviceContext,
-                                 paddle::platform::complex128>);
+                                 paddle::platform::complex<double>>);
 REGISTER_OP_CPU_KERNEL(real_grad,
                        ops::RealGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex64>,
+                                           paddle::platform::complex<float>>,
                        ops::RealGradKernel<paddle::platform::CPUDeviceContext,
-                                           paddle::platform::complex128>);
+                                           paddle::platform::complex<double>>);
8 changes: 4 additions & 4 deletions paddle/fluid/operators/real_op.cu
@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
 
 REGISTER_OP_CUDA_KERNEL(real,
                         ops::RealKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex64>,
+                                        paddle::platform::complex<float>>,
                         ops::RealKernel<paddle::platform::CUDADeviceContext,
-                                        paddle::platform::complex128>);
+                                        paddle::platform::complex<double>>);
 REGISTER_OP_CUDA_KERNEL(real_grad,
                         ops::RealGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex64>,
+                                            paddle::platform::complex<float>>,
                         ops::RealGradKernel<paddle::platform::CUDADeviceContext,
-                                            paddle::platform::complex128>);
+                                            paddle::platform::complex<double>>);
