Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【complex op】NO.40 add complex support for inv (cuda only) #57862

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/platform/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,14 @@ namespace dynload {
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/backends/dynload/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,14 @@ extern void *cublas_dso_handle;
__macro(cublasSgetriBatched); \
__macro(cublasDgetrfBatched); \
__macro(cublasDgetriBatched); \
__macro(cublasCgetrfBatched); \
__macro(cublasCgetriBatched); \
__macro(cublasZgetrfBatched); \
__macro(cublasZgetriBatched); \
__macro(cublasSmatinvBatched); \
__macro(cublasDmatinvBatched); \
__macro(cublasCmatinvBatched); \
__macro(cublasZmatinvBatched); \
__macro(cublasSgetrsBatched); \
__macro(cublasDgetrsBatched);

Expand Down
114 changes: 114 additions & 0 deletions paddle/phi/kernels/funcs/blas/blas_impl.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,63 @@ struct CUBlas<phi::dtype::complex<float>> {
ldb,
batch_size));
}

static void MATINV_BATCH(cublasHandle_t handle,
int n,
const phi::dtype::complex<float> **A,
int lda,
phi::dtype::complex<float> **Ainv,
int lda_inv,
int *info,
int batchSize) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCmatinvBatched(
handle,
n,
reinterpret_cast<const cuFloatComplex **>(A),
lda,
reinterpret_cast<cuFloatComplex **>(Ainv),
lda_inv,
info,
batchSize));
}

static void GETRF_BATCH(cublasHandle_t handle,
int n,
phi::dtype::complex<float> **A,
int lda,
int *P,
int *info,
int batchSize) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetrfBatched(
handle,
n,
reinterpret_cast<cuFloatComplex **>(A),
lda,
P,
info,
batchSize));
}

static void GETRI_BATCH(cublasHandle_t handle,
int n,
const phi::dtype::complex<float> **A,
int lda,
const int *P,
phi::dtype::complex<float> **C,
int ldc,
int *info,
int batchSize) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasCgetriBatched(
handle,
n,
reinterpret_cast<const cuFloatComplex **>(A),
lda,
P,
reinterpret_cast<cuFloatComplex **>(C),
ldc,
info,
batchSize));
}
};

template <>
Expand Down Expand Up @@ -924,6 +981,63 @@ struct CUBlas<phi::dtype::complex<double>> {
"cublasGemmEx is not supported on cuda <= 7.5"));
#endif
}

static void MATINV_BATCH(cublasHandle_t handle,
int n,
const phi::dtype::complex<double> **A,
int lda,
phi::dtype::complex<double> **Ainv,
int lda_inv,
int *info,
int batchSize) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZmatinvBatched(
handle,
n,
reinterpret_cast<const cuDoubleComplex **>(A),
lda,
reinterpret_cast<cuDoubleComplex **>(Ainv),
lda_inv,
info,
batchSize));
}

static void GETRF_BATCH(cublasHandle_t handle,
int n,
phi::dtype::complex<double> **A,
int lda,
int *P,
int *info,
int batchSize) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetrfBatched(
handle,
n,
reinterpret_cast<cuDoubleComplex **>(A),
lda,
P,
info,
batchSize));
}

static void GETRI_BATCH(cublasHandle_t handle,
int n,
const phi::dtype::complex<double> **A,
int lda,
const int *P,
phi::dtype::complex<double> **C,
int ldc,
int *info,
int batchSize) {
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasZgetriBatched(
handle,
n,
reinterpret_cast<const cuDoubleComplex **>(A),
lda,
P,
reinterpret_cast<cuDoubleComplex **>(C),
ldc,
info,
batchSize));
}
};

template <>
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_inverse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

template class MatrixInverseFunctor<CPUContext, float>;
template class MatrixInverseFunctor<CPUContext, double>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<CPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
2 changes: 2 additions & 0 deletions paddle/phi/kernels/funcs/matrix_inverse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,

template class MatrixInverseFunctor<GPUContext, float>;
template class MatrixInverseFunctor<GPUContext, double>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<float>>;
template class MatrixInverseFunctor<GPUContext, phi::dtype::complex<double>>;

} // namespace funcs
} // namespace phi
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/inverse_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/inverse_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
inverse_grad, GPU, ALL_LAYOUT, phi::InverseGradKernel, float, double) {}
PD_REGISTER_KERNEL(inverse_grad,
GPU,
ALL_LAYOUT,
phi::InverseGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
10 changes: 8 additions & 2 deletions paddle/phi/kernels/gpu/inverse_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,11 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/inverse_kernel_impl.h"

PD_REGISTER_KERNEL(
inverse, GPU, ALL_LAYOUT, phi::InverseKernel, float, double) {}
PD_REGISTER_KERNEL(inverse,
GPU,
ALL_LAYOUT,
phi::InverseKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
37 changes: 37 additions & 0 deletions test/legacy_test/test_inverse_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def setUp(self):

np.random.seed(123)
mat = np.random.random(self.matrix_shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
mat = (
np.random.random(self.matrix_shape)
+ 1j * np.random.random(self.matrix_shape)
).astype(self.dtype)
inverse = np.linalg.inv(mat)

self.inputs = {'Input': mat}
Expand All @@ -46,6 +51,38 @@ def test_grad(self):
self.check_grad(['Input'], 'Output')


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not complied with CUDA"
)
class TestInverseOp_Complex64(TestInverseOp):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = np.complex64
self.python_api = paddle.tensor.math.inverse

def test_check_output(self):
self.check_output_with_place(paddle.CUDAPlace())

def test_grad(self):
self.check_grad_with_place(paddle.CUDAPlace(), ['Input'], 'Output')


@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not complied with CUDA"
)
class TestInverseOp_Complex128(TestInverseOp):
def config(self):
self.matrix_shape = [10, 10]
self.dtype = np.complex128
self.python_api = paddle.tensor.math.inverse

def test_check_output(self):
self.check_output_with_place(paddle.CUDAPlace())

def test_grad(self):
self.check_grad_with_place(paddle.CUDAPlace(), ['Input'], 'Output')


class TestInverseOpBatched(TestInverseOp):
def config(self):
self.matrix_shape = [8, 4, 4]
Expand Down