From fd9c29ef86a28453ceed8644ab565797f95dc781 Mon Sep 17 00:00:00 2001 From: yoshoku Date: Fri, 19 Jan 2024 00:22:21 +0900 Subject: [PATCH] feat: add solve_triangular module function to TinyLinalg --- lib/numo/tiny_linalg.rb | 41 ++++++++++++++++++++++++++++++++++++++++ test/test_tiny_linalg.rb | 9 +++++++++ 2 files changed, 50 insertions(+) diff --git a/lib/numo/tiny_linalg.rb b/lib/numo/tiny_linalg.rb index 9cfa169..e03046c 100644 --- a/lib/numo/tiny_linalg.rb +++ b/lib/numo/tiny_linalg.rb @@ -399,6 +399,47 @@ def solve(a, b, driver: 'gen', uplo: 'U') # rubocop:disable Lint/UnusedMethodArg Numo::TinyLinalg::Lapack.send(gesv, a.dup, b.dup)[1] end + # Solves linear equation `A * x = b` or `A * X = B` for `x` assuming `A` is a triangular matrix. + # + # @example + # require 'numo/tiny_linalg' + # + # Numo::Linalg = Numo::TinyLinalg unless defined?(Numo::Linalg) + # + # a = Numo::DFloat.new(3, 3).rand.triu + # b = Numo::DFloat.eye(3) + # + # x = Numo::Linalg.solve(a, b) + # + # pp x + # # => + # # Numo::DFloat#shape=[3,3] + # # [[16.1932, -52.0604, 30.5283], + # # [0, 8.61765, -17.9585], + # # [0, 0, 6.05735]] + # + # pp (b - a.dot(x)).abs.max + # # => 4.071100642430302e-16 + # + # @param a [Numo::NArray] The n-by-n triangular matrix. + # @param b [Numo::NArray] The n right-hand side vector, or n-by-nrhs right-hand side matrix. + # @param lower [Boolean] The flag indicating whether to use the lower-triangular part of `a`. + # @return [Numo::NArray] The solusion vector / matrix `X`. + def solve_triangular(a, b, lower: false) + raise ArgumentError, 'input array a must be 2-dimensional' if a.ndim != 2 + raise ArgumentError, 'input array a must be square' if a.shape[0] != a.shape[1] + + bchr = blas_char(a, b) + raise ArgumentError, "invalid array type: #{a.class}, #{b.class}" if bchr == 'n' + + trtrs = :"#{bchr}trtrs" + uplo = lower ? 'L' : 'U' + x, info = Numo::TinyLinalg::Lapack.send(trtrs, a, b.dup, uplo: uplo) + raise "wrong value is given to the #{info}-th argument of #{trtrs} used internally" if info.negative? + + x + end + # Computes the Singular Value Decomposition (SVD) of a matrix: `A = U * S * V^T` # # @example diff --git a/test/test_tiny_linalg.rb b/test/test_tiny_linalg.rb index 65df684..0d36d92 100644 --- a/test/test_tiny_linalg.rb +++ b/test/test_tiny_linalg.rb @@ -208,6 +208,15 @@ def test_solve assert(error_ab < 1e-7) end + def test_solve_triangular + a = Numo::SFloat.new(3, 3).rand.triu + b = Numo::DComplex.new(3).rand + x = Numo::TinyLinalg.solve_triangular(a, b) + error_ab = (b - a.dot(x)).abs.max + + assert(error_ab < 1e-7) + end + def test_svd x = Numo::DFloat.new(5, 3).rand.dot(Numo::DFloat.new(3, 2).rand) s, u, vt, = Numo::TinyLinalg.svd(x, driver: 'sdd', job: 'S')