From eff6a7ddb191608191d178d4e419dd97b1f65661 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Tue, 19 Sep 2023 13:42:53 -0500 Subject: [PATCH] [CUSOLVER] Add a method for geqrf! --- lib/cusolver/dense.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/cusolver/dense.jl b/lib/cusolver/dense.jl index 27712a0fd5..1b49291f55 100644 --- a/lib/cusolver/dense.jl +++ b/lib/cusolver/dense.jl @@ -144,7 +144,7 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F (:cusolverDnCgeqrf_bufferSize, :cusolverDnCgeqrf, :ComplexF32), (:cusolverDnZgeqrf_bufferSize, :cusolverDnZgeqrf, :ComplexF64)) @eval begin - function geqrf!(A::StridedCuMatrix{$elty}) + function geqrf!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) m, n = size(A) lda = max(1, stride(A, 2)) @@ -154,7 +154,6 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F return out[] end - tau = CuArray{$elty}(undef, min(m, n)) devinfo = CuArray{Cint}(undef, 1) with_workspace($elty, bufferSize) do buffer $fname(dense_handle(), m, n, A, lda, tau, buffer, length(buffer), devinfo) @@ -166,6 +165,12 @@ for (bname, fname,elty) in ((:cusolverDnSgeqrf_bufferSize, :cusolverDnSgeqrf, :F A, tau end + + function geqrf!(A::StridedCuMatrix{$elty}) + m, n = size(A) + tau = CuArray{$elty}(undef, min(m, n)) + geqrf!(A, tau) + end end end @@ -891,6 +896,7 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64) LinearAlgebra.LAPACK.potri!(uplo::Char, A::StridedCuMatrix{$elty}) = CUSOLVER.potri!(uplo, A) LinearAlgebra.LAPACK.getrf!(A::StridedCuMatrix{$elty}) = CUSOLVER.getrf!(A) LinearAlgebra.LAPACK.geqrf!(A::StridedCuMatrix{$elty}) = CUSOLVER.geqrf!(A) + LinearAlgebra.LAPACK.geqrf!(A::StridedCuMatrix{$elty}, tau::CuVector{$elty}) = CUSOLVER.geqrf!(A, tau) LinearAlgebra.LAPACK.sytrf!(uplo::Char, A::StridedCuMatrix{$elty}) = sytrf!(uplo, A) LinearAlgebra.LAPACK.getrs!(trans::Char, A::StridedCuMatrix{$elty}, ipiv::CuVector{Cint}, B::StridedCuVecOrMat{$elty}) = CUSOLVER.getrs!(trans, A, ipiv, B) LinearAlgebra.LAPACK.ormqr!(side::Char, trans::Char, A::CuMatrix{$elty}, tau::CuVector{$elty}, C::CuVecOrMat{$elty}) = CUSOLVER.ormqr!(side, trans, A, tau, C)