Skip to content

Commit

Permalink
refactor: fix Lapack function selection by array type
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Aug 6, 2023
1 parent 41f1b0c commit 4bc09c5
Showing 1 changed file with 16 additions and 35 deletions.
51 changes: 16 additions & 35 deletions lib/numo/tiny_linalg.rb
Original file line number Diff line number Diff line change
Expand Up @@ -279,22 +279,20 @@ def qr(a, mode: 'reduce')
# pp (b - a.dot(x)).abs.max
# # => 2.1081041547796492e-16
#
# @param a [Numo::NArray] The n-by-n square matrix (>= 2-dimensinal NArray).
# @param a [Numo::NArray] The n-by-n square matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix (>= 1-dimensinal NArray).
# @param driver [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
# @param uplo [String] This argument is for compatibility with Numo::Linalg.solver, and is not used.
# @return [Numo::NArray] The solusion vector / matrix `x`.
def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArgument
case blas_char(a, b)
when 'd'
Lapack.dgesv(a.dup, b.dup)[1]
when 's'
Lapack.sgesv(a.dup, b.dup)[1]
when 'z'
Lapack.zgesv(a.dup, b.dup)[1]
when 'c'
Lapack.cgesv(a.dup, b.dup)[1]
end
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

bchr = blas_char(a, b)
raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

gesv = "#{bchr}gesv".to_sym
Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
end

# Calculates the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
Expand Down Expand Up @@ -336,33 +334,16 @@ def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArg
def svd(a, driver: 'svd', job: 'A')
raise ArgumentError, "invalid job: #{job}" unless /^[ASN]/i.match?(job.to_s)

bchr = blas_char(a)
raise ArgumentError, "invalid array type: #{a.class}" if bchr == 'n'

case driver.to_s
when 'sdd'
s, u, vt, info = case a
when Numo::DFloat
Numo::TinyLinalg::Lapack.dgesdd(a.dup, jobz: job)
when Numo::SFloat
Numo::TinyLinalg::Lapack.sgesdd(a.dup, jobz: job)
when Numo::DComplex
Numo::TinyLinalg::Lapack.zgesdd(a.dup, jobz: job)
when Numo::SComplex
Numo::TinyLinalg::Lapack.cgesdd(a.dup, jobz: job)
else
raise ArgumentError, "invalid array type: #{a.class}"
end
gesdd = "#{bchr}gesdd".to_sym
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesdd, a.dup, jobz: job)
when 'svd'
s, u, vt, info = case a
when Numo::DFloat
Numo::TinyLinalg::Lapack.dgesvd(a.dup, jobu: job, jobvt: job)
when Numo::SFloat
Numo::TinyLinalg::Lapack.sgesvd(a.dup, jobu: job, jobvt: job)
when Numo::DComplex
Numo::TinyLinalg::Lapack.zgesvd(a.dup, jobu: job, jobvt: job)
when Numo::SComplex
Numo::TinyLinalg::Lapack.cgesvd(a.dup, jobu: job, jobvt: job)
else
raise ArgumentError, "invalid array type: #{a.class}"
end
gesvd = "#{bchr}gesvd".to_sym
s, u, vt, info = Numo::TinyLinalg::Lapack.send(gesvd, a.dup, jobu: job, jobvt: job)
else
raise ArgumentError, "invalid driver: #{driver}"
end
Expand Down

0 comments on commit 4bc09c5

Please sign in to comment.