Skip to content

Commit

Permalink
feat: add solve_triangular module function to TinyLinalg
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Jan 18, 2024
1 parent 07786bb commit fd9c29e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
41 changes: 41 additions & 0 deletions lib/numo/tiny_linalg.rb
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,47 @@ def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArg
Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1]
end

# Solves linear equation `A * x = b` or `A * X = B` for `x` assuming `A` is a triangular matrix.
#
# @example
# require 'numo/tiny_linalg'
#
# Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg)
#
# a = Numo::DFloat.new(3, 3).rand.triu
# b = Numo::DFloat.eye(3)
#
# x = Numo::Linalg.solve(a, b)
#
# pp x
# # =>
# # Numo::DFloat#shape=[3,3]
# # [[16.1932, -52.0604, 30.5283],
# # [0, 8.61765, -17.9585],
# # [0, 0, 6.05735]]
#
# pp (b - a.dot(x)).abs.max
# # => 4.071100642430302e-16
#
# @param a [Numo::NArray] The n-by-n triangular matrix.
# @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix.
# @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`.
# @return [Numo::NArray] The solusion vector / matrix `X`.
def solve_triangular(a, b, lower: false)
raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2
raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1]

bchr = blas_char(a, b)
raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n'

trtrs = :"#{bchr}trtrs"
uplo = lower ? 'L' : 'U'
x, info = Numo::TinyLinalg::Lapack.send(trtrs, a, b.dup, uplo: uplo)
raise "wrong value is given to the #{info}-th argument of #{trtrs} used internally" if info.negative?

x
end

# Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T`
#
# @example
Expand Down
9 changes: 9 additions & 0 deletions test/test_tiny_linalg.rb
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,15 @@ def test_solve
assert(error_ab < 1e-7)
end

def test_solve_triangular
a = Numo::SFloat.new(3, 3).rand.triu
b = Numo::DComplex.new(3).rand
x = Numo::TinyLinalg.solve_triangular(a, b)
error_ab = (b - a.dot(x)).abs.max

assert(error_ab < 1e-7)
end

def test_svd
x = Numo::DFloat.new(5, 3).rand.dot(Numo::DFloat.new(3, 2).rand)
s, u, vt, = Numo::TinyLinalg.svd(x, driver: 'sdd', job: 'S')
Expand Down

0 comments on commit fd9c29e

Please sign in to comment.