
support complex attribute
yangguohao committed Nov 10, 2023
1 parent 30451eb commit 4db8899
Showing 5 changed files with 70 additions and 33 deletions.
17 changes: 17 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -4351,6 +4351,23 @@ void TileInferMeta(const MetaTensor& x,
   out->set_dtype(x.dtype());
 }
 
+void PowInferMeta(const MetaTensor& x, const Scalar& y, MetaTensor* out) {
+  if (y.dtype() == DataType::COMPLEX128 &&
+      !(x.dtype() == DataType::COMPLEX64 ||
+        x.dtype() == DataType::COMPLEX128)) {
+    if (x.dtype() == DataType::FLOAT64) {
+      out->set_dtype(phi::DataType::COMPLEX128);
+    } else {
+      out->set_dtype(phi::DataType::COMPLEX64);
+    }
+  } else if (y.dtype() == DataType::FLOAT64 &&
+             (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64)) {
+    out->set_dtype(phi::DataType::FLOAT32);
+  } else {
+    out->set_dtype(x.dtype());
+  }
+}
+
 void TopKInferMeta(const MetaTensor& x,
                    const Scalar& k_scalar,
                    int axis,
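The new PowInferMeta encodes the output-dtype promotion rule for pow with a scalar exponent: a complex128 exponent promotes a non-complex input to the matching complex type, a float64 exponent promotes integer inputs to float32, and everything else keeps the input dtype. A self-contained sketch of the same rule, for illustration only (the DataType enum and promote_pow_dtype helper below are stand-ins, not phi API):

#include <cassert>

enum class DataType { FLOAT32, FLOAT64, INT32, INT64, COMPLEX64, COMPLEX128 };

// Mirrors the branches of PowInferMeta above.
DataType promote_pow_dtype(DataType x, DataType y) {
  bool x_complex = x == DataType::COMPLEX64 || x == DataType::COMPLEX128;
  if (y == DataType::COMPLEX128 && !x_complex) {
    // float64 widens to complex128; every other real dtype goes to complex64.
    return x == DataType::FLOAT64 ? DataType::COMPLEX128 : DataType::COMPLEX64;
  }
  if (y == DataType::FLOAT64 && (x == DataType::INT32 || x == DataType::INT64)) {
    return DataType::FLOAT32;  // integer base with a float exponent
  }
  return x;  // default: output keeps the input dtype
}

int main() {
  assert(promote_pow_dtype(DataType::FLOAT64, DataType::COMPLEX128) == DataType::COMPLEX128);
  assert(promote_pow_dtype(DataType::FLOAT32, DataType::COMPLEX128) == DataType::COMPLEX64);
  assert(promote_pow_dtype(DataType::INT32, DataType::FLOAT64) == DataType::FLOAT32);
  return 0;
}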
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -567,6 +567,8 @@ int GetSplitAxisValue(const MetaTensor& x,
                       const Scalar& axis,
                       MetaConfig config);
 
+void PowInferMeta(const MetaTensor& x, const Scalar& y, MetaTensor* out);
+
 void FillSplitOutDims(const MetaTensor& x,
                       const int axis_value,
                       const std::vector<int64_t>& sections_vec,
36 changes: 17 additions & 19 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -2742,30 +2742,29 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
 
 template <typename T>
 struct PowFunctor : public BaseActivationFunctor<T> {
-  float factor;
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"factor", &factor}};
-  }
+  T factor;
+  using AttrPair = std::vector<std::pair<const char*, ELEMENT_TYPE*>>;
+
+  typename AttrPair GetAttrs() { return {{"factor", &factor}}; }
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.pow(static_cast<T>(factor));  // NOLINT
+    out.device(d) = x.template cast<T>().pow(factor);  // NOLINT
   }
 };
 
 template <typename T>
 struct PowGradFunctor : public BaseActivationFunctor<T> {
-  float factor;
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"factor", &factor}};
-  }
+  T factor;
+  using AttrPair = std::vector<std::pair<const char*, ELEMENT_TYPE*>>;
+
+  typename AttrPair GetAttrs() { return {{"factor", &factor}}; }
   template <typename Device,
             typename X,
             typename Out,
             typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
-    dx.device(d) = dout * static_cast<T>(factor) *
-                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
+    dx.device(d) = dout * factor * x.pow(factor - static_cast<T>(1));
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
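The change above replaces the hard-coded float attribute with the functor's element type T, so a double or complex factor reaches the functor without being narrowed. A minimal sketch of how the GetAttrs mechanism is consumed (SimplePowFunctor is a made-up stand-in for the Paddle functor, not real phi code):

#include <cmath>
#include <iostream>
#include <utility>
#include <vector>

// The attribute is exposed as a (name, pointer) pair so a generic kernel
// can fill it in by name at run time before invoking the functor.
template <typename T>
struct SimplePowFunctor {
  T factor;
  std::vector<std::pair<const char*, T*>> GetAttrs() {
    return {{"factor", &factor}};
  }
  T operator()(T x) const { return std::pow(x, factor); }
};

int main() {
  SimplePowFunctor<double> functor;
  auto attrs = functor.GetAttrs();
  *(attrs[0].second) = 3.0;  // what PowGradKernel does via factor.to<T>()
  std::cout << functor(2.0) << std::endl;  // prints 8
}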
@@ -2774,20 +2773,19 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
 template <typename T>
 struct PowGradFunctor<ComplexType<T>>
     : public BaseActivationFunctor<ComplexType<T>> {
-  float factor;
-  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
-    return {{"factor", &factor}};
-  }
+  ComplexType<T> factor;
+  using AttrPair = std::vector<std::pair<const char*, ComplexType<T>*>>;
+
+  typename AttrPair GetAttrs() { return {{"factor", &factor}}; }
   template <typename Device,
             typename X,
             typename Out,
             typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
-    dx.device(d) = dout * (static_cast<ComplexType<T>>(factor) *
-                           x.pow(static_cast<ComplexType<T>>(factor) -
-                                 static_cast<ComplexType<T>>(1)))
-                       .unaryExpr(Conj<T>());
+    dx.device(d) =
+        dout * (factor * x.pow(factor - static_cast<ComplexType<T>>(1)))
+                   .unaryExpr(Conj<T>());
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
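In the ComplexType<T> specialization above, the result is additionally conjugated through .unaryExpr(Conj<T>()). Writing f for factor, the backward pass computes

    dx = dout * conj(f * x^(f - 1))

that is, it propagates the conjugate of the holomorphic derivative of x^f, the usual convention for complex-valued gradients in autodiff frameworks.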
3 changes: 2 additions & 1 deletion paddle/phi/kernels/impl/activation_grad_impl.h
@@ -334,10 +334,11 @@ void PowGradKernel(const Context& dev_ctx,
       GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad"));
   auto x_flatten =
       EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad"));
+  std::cout << dout.dtype() << dx->dtype() << std::endl;
   auto* place = dev_ctx.eigen_device();
   phi::funcs::PowGradFunctor<T> functor;
   auto attrs = functor.GetAttrs();
-  *(attrs[0].second) = factor.to<float>();
+  *(attrs[0].second) = factor.to<T>();
   functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
 }
 
45 changes: 32 additions & 13 deletions paddle/phi/kernels/impl/activation_impl.h
@@ -16,9 +16,9 @@
 
 #include "paddle/phi/backends/all_context.h"
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/funcs/activation_functor.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
 
 namespace phi {
 
 #define ToString(x) #x
@@ -62,23 +62,42 @@ void LogitKernel(const Context& dev_ctx,
   functor(place, eigen_in, eigen_out, eigen_p, eps);
 }
 
-template <typename T, typename Context>
-void PowKernel(const Context& dev_ctx,
-               const DenseTensor& x,
-               const Scalar& factor,
-               DenseTensor* out) {
+template <typename InT, typename OutT, typename Context>
+void PowImpl(const Context& dev_ctx,
+             const DenseTensor& x,
+             const Scalar& factor,
+             DenseTensor* out) {
   PADDLE_ENFORCE_NOT_NULL(out,
                           errors::NotFound("Output Out should not be nullptr"));
-  dev_ctx.template Alloc<T>(out);
-  auto x_flatten = phi::EigenVector<T>::Flatten(
-      GET_DATA_SAFELY(&x, "Input", "X", "Activation"));
-  auto out_flatten = phi::EigenVector<T>::Flatten(
-      GET_DATA_SAFELY(out, "Output", "Out", "Activation"));
+  dev_ctx.template Alloc<OutT>(out, out->numel() * sizeof(OutT));
+  auto x_flatten = phi::EigenVector<InT>::Flatten(x);
+  auto out_flatten = phi::EigenVector<OutT>::Flatten(*out);
   auto* place = dev_ctx.eigen_device();
-  phi::funcs::PowFunctor<T> functor;
+  phi::funcs::PowFunctor<OutT> functor;
   auto attrs = functor.GetAttrs();
-  *(attrs[0].second) = factor.to<float>();
+  *(attrs[0].second) = factor.to<OutT>();
   functor(*place, x_flatten, out_flatten);
 }
 
+template <typename T, typename Context>
+void PowKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const Scalar& factor,
+               DenseTensor* out) {
+  if (factor.dtype() == DataType::COMPLEX128 &&
+      !(x.dtype() == DataType::COMPLEX64 ||
+        x.dtype() == DataType::COMPLEX128)) {
+    if (x.dtype() == DataType::FLOAT64) {
+      PowImpl<T, phi::dtype::complex<double>, Context>(dev_ctx, x, factor, out);
+    } else {
+      PowImpl<T, phi::dtype::complex<float>, Context>(dev_ctx, x, factor, out);
+    }
+  } else if (factor.dtype() == DataType::FLOAT64 &&
+             (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64)) {
+    PowImpl<T, float, Context>(dev_ctx, x, factor, out);
+  } else {
+    PowImpl<T, T, Context>(dev_ctx, x, factor, out);
+  }
+}
+
 } // namespace phi
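PowImpl reads the input as InT but instantiates PowFunctor<OutT>, so the functor's x.template cast<T>().pow(factor) widens the input to the output type before exponentiation. A minimal standard-C++ illustration of that cast-then-pow step (made-up values, std::complex in place of phi::dtype::complex):

#include <complex>
#include <iostream>

int main() {
  // Real input, complex exponent: widen the input to the output type first,
  // then raise it to the power, as PowFunctor does elementwise.
  double x = -4.0;
  std::complex<double> factor(0.5, 0.0);
  std::complex<double> out = std::pow(std::complex<double>(x), factor);
  std::cout << out << std::endl;  // approximately (0,2): (-4)^0.5 == 2i
}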
