Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lstsq: return correct array size #818

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/specs/stdlib_linalg.md
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ Result vector `x` returns the approximate solution that minimizes the 2-norm \(

`b`: Shall be a rank-1 or rank-2 array of the same kind as `a`, containing one or more right-hand-side vector(s), each in its leading dimension. It is an `intent(in)` argument.

`x`: Shall be an array of same kind and rank as `b`, containing the solution(s) to the least squares system. It is an `intent(inout)` argument.
`x`: Shall be an array of same kind and rank as `b`, and leading dimension of at least `n`, containing the solution(s) to the least squares system. It is an `intent(inout)` argument.

`real_storage` (optional): Shall be a `real` rank-1 array of the same kind `a`, providing working storage for the solver. It minimum size can be determined with a call to [[stdlib_linalg(module):lstsq_space(interface)]]. It is an `intent(inout)` argument.

Expand Down
53 changes: 41 additions & 12 deletions src/stdlib_linalg_least_squares.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
pure module subroutine stdlib_linalg_${ri}$_lstsq_space_${ndsuf}$(a,b,lrwork,liwork#{if rt.startswith('c')}#,lcwork#{endif}#)
!> Input matrix a[m,n]
${rt}$, intent(in), target :: a(:,:)
!> Right hand side vector or array, b[n] or b[n,nrhs]
!> Right hand side vector or array, b[m] or b[m,nrhs]
${rt}$, intent(in) :: b${nd}$
!> Size of the working space arrays
integer(ilp), intent(out) :: lrwork,liwork
Expand All @@ -111,7 +111,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
!! This function computes the least-squares solution of a linear matrix problem.
!!
!! param: a Input matrix of size [m,n].
!! param: b Right-hand-side vector of size [n] or matrix of size [n,nrhs].
!! param: b Right-hand-side vector of size [m] or matrix of size [m,nrhs].
!! param: cond [optional] Real input threshold indicating that singular values `s_i <= cond*maxval(s)`
!! do not contribute to the matrix rank.
!! param: overwrite_a [optional] Flag indicating if the input matrix can be overwritten.
Expand All @@ -121,7 +121,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
!!
!> Input matrix a[m,n]
${rt}$, intent(inout), target :: a(:,:)
!> Right hand side vector or array, b[n] or b[n,nrhs]
!> Right hand side vector or array, b[m] or b[m,nrhs]
${rt}$, intent(in) :: b${nd}$
!> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
real(${rk}$), optional, intent(in) :: cond
Expand All @@ -134,9 +134,19 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
!> Result array/matrix x[n] or x[n,nrhs]
${rt}$, allocatable, target :: x${nd}$

! Initialize solution with the shape of the rhs
allocate(x,mold=b)
integer(ilp) :: n,nrhs,ldb

n = size(a,2,kind=ilp)
ldb = size(b,1,kind=ilp)
nrhs = size(b,kind=ilp)/ldb

! Initialize solution with the shape of the rhs
#:if ndsuf=="one"
allocate(x(n))
#:else
allocate(x(n,nrhs))
#:endif

call stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x,&
cond=cond,overwrite_a=overwrite_a,rank=rank,err=err)

Expand All @@ -155,7 +165,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
!!
!! param: a Input matrix of size [m,n].
!! param: b Right-hand-side vector of size [n] or matrix of size [n,nrhs].
!! param: x Solution vector of size [n] or solution matrix of size [n,nrhs].
!! param: x Solution vector of size at [>=n] or solution matrix of size [>=n,nrhs].
!! param: real_storage [optional] Real working space
!! param: int_storage [optional] Integer working space
#:if rt.startswith('c')
Expand Down Expand Up @@ -198,7 +208,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
integer(ilp) :: m,n,lda,ldb,nrhs,ldx,nrhsx,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
integer(ilp) :: nrs,nis,ncs,nsvd
integer(ilp), pointer :: iwork(:)
logical(lk) :: copy_a
logical(lk) :: copy_a,large_enough_x
real(${rk}$) :: acond,rcond
real(${rk}$), pointer :: rwork(:),singular(:)
${rt}$, pointer :: xmat(:,:),amat(:,:),cwork(:)
Expand All @@ -214,8 +224,8 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
mnmin = min(m,n)
mnmax = max(m,n)

if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m .or. ldx/=m) then
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m .or. ldx<n) then
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'insufficient sizes: a=',[lda,n], &
'b=',[ldb,nrhs],' x=',[ldx,nrhsx])
call linalg_error_handling(err0,err)
if (present(rank)) rank = 0
Expand All @@ -236,9 +246,19 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
amat => a
endif

! Initialize solution with the rhs
x = b
xmat(1:n,1:nrhs) => x
! If x is large enough to store b, use it as temporary rhs storage.
large_enough_x = ldx>=m
if (large_enough_x) then
xmat(1:ldx,1:nrhs) => x
else
allocate(xmat(m,nrhs))
endif

#:if ndsuf=="one"
xmat(1:m,1) = b
#:else
xmat(1:m,1:nrhs) = b
#:endif

! Singular values array (in decreasing order)
if (present(singvals)) then
Expand Down Expand Up @@ -316,7 +336,16 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
endif

! Process output and return
if (.not.large_enough_x) then
#:if ndsuf=="one"
x(1:n) = xmat(1:n,1)
#:else
x(1:n,1:nrhs) = xmat(1:n,1:nrhs)
#:endif
deallocate(xmat)
endif
if (copy_a) deallocate(amat)

if (present(rank)) rank = arank
if (.not.present(real_storage)) deallocate(rwork)
if (.not.present(int_storage)) deallocate(iwork)
Expand Down
7 changes: 6 additions & 1 deletion test/linalg/test_linalg_lstsq.fypp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ module test_linalg_least_squares
type(linalg_state_type) :: state
integer(ilp), parameter :: n = 12, m = 3
real :: Arnd(n,m),xrnd(m)
${rt}$ :: xsol(m),x(m),y(n),A(n,m)
${rt}$, allocatable :: x(:)
${rt}$ :: xsol(m),y(n),A(n,m)

! Random coefficient matrix and solution
call random_number(Arnd)
Expand All @@ -88,6 +89,10 @@ module test_linalg_least_squares
call check(error,state%ok(),state%print())
if (allocated(error)) return

! Check size
call check(error,size(x)==m)
if (allocated(error)) return

call check(error, all(abs(x-xsol)<1.0e-4_${rk}$), 'data converged')
if (allocated(error)) return

Expand Down
Loading