Skip to content

Commit

Permalink
feat: add dsyevd, ssyevd, zheevd, and cheevd module functions to Tiny…
Browse files Browse the repository at this point in the history
…Linalg::Lapack
  • Loading branch information
yoshoku committed Aug 7, 2023
1 parent 9c8e98d commit d07a5de
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 0 deletions.
87 changes: 87 additions & 0 deletions ext/numo/tiny_linalg/lapack/heevd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
namespace TinyLinalg {

struct ZHeEvd {
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_double* a, lapack_int lda, double* w) {
return LAPACKE_zheevd(matrix_layout, jobz, uplo, n, a, lda, w);
}
};

struct CHeEvd {
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, lapack_complex_float* a, lapack_int lda, float* w) {
return LAPACKE_cheevd(matrix_layout, jobz, uplo, n, a, lda, w);
}
};

template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
class HeEvd {
public:
static void define_module_function(VALUE mLapack, const char* fnc_name) {
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_heevd), -1);
}

private:
struct heevd_opt {
int matrix_layout;
char jobz;
char uplo;
};

static void iter_heevd(na_loop_t* const lp) {
dtype* a = (dtype*)NDL_PTR(lp, 0);
rtype* w = (rtype*)NDL_PTR(lp, 1);
int* info = (int*)NDL_PTR(lp, 2);
heevd_opt* opt = (heevd_opt*)(lp->opt_ptr);
const lapack_int n = NDL_SHAPE(lp, 0)[1];
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
*info = static_cast<int>(i);
}

static VALUE tiny_linalg_heevd(int argc, VALUE* argv, VALUE self) {
VALUE nary_dtype = NaryTypes[nary_dtype_id];
VALUE nary_rtype = NaryTypes[nary_rtype_id];

VALUE a_vnary = Qnil;
VALUE kw_args = Qnil;
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;

if (CLASS_OF(a_vnary) != nary_dtype) {
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
}
if (!RTEST(nary_check_contiguous(a_vnary))) {
a_vnary = nary_dup(a_vnary);
}

narray_t* a_nary = nullptr;
GetNArray(a_vnary, a_nary);
if (NA_NDIM(a_nary) != 2) {
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
return Qnil;
}
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
rb_raise(rb_eArgError, "input array a must be square");
return Qnil;
}

const size_t n = NA_SHAPE(a_nary)[1];
size_t shape[1] = { n };
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
ndfunc_arg_out_t aout[2] = { { nary_rtype, 1, shape }, { numo_cInt32, 0 } };
ndfunc_t ndf = { iter_heevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
heevd_opt opt = { matrix_layout, jobz, uplo };
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));

RB_GC_GUARD(a_vnary);

return ret;
}
};

} // namespace TinyLinalg
86 changes: 86 additions & 0 deletions ext/numo/tiny_linalg/lapack/syevd.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
namespace TinyLinalg {

struct DSyEvd {
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, double* a, lapack_int lda, double* w) {
return LAPACKE_dsyevd(matrix_layout, jobz, uplo, n, a, lda, w);
}
};

struct SSyEvd {
lapack_int call(int matrix_layout, char jobz, char uplo, lapack_int n, float* a, lapack_int lda, float* w) {
return LAPACKE_ssyevd(matrix_layout, jobz, uplo, n, a, lda, w);
}
};

template <int nary_dtype_id, typename dtype, class LapackFn>
class SyEvd {
public:
static void define_module_function(VALUE mLapack, const char* fnc_name) {
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_syevd), -1);
}

private:
struct syevd_opt {
int matrix_layout;
char jobz;
char uplo;
};

static void iter_syevd(na_loop_t* const lp) {
dtype* a = (dtype*)NDL_PTR(lp, 0);
dtype* w = (dtype*)NDL_PTR(lp, 1);
int* info = (int*)NDL_PTR(lp, 2);
syevd_opt* opt = (syevd_opt*)(lp->opt_ptr);
const lapack_int n = NDL_SHAPE(lp, 0)[1];
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->jobz, opt->uplo, n, a, lda, w);
*info = static_cast<int>(i);
}

static VALUE tiny_linalg_syevd(int argc, VALUE* argv, VALUE self) {
VALUE nary_dtype = NaryTypes[nary_dtype_id];

VALUE a_vnary = Qnil;
VALUE kw_args = Qnil;
rb_scan_args(argc, argv, "1:", &a_vnary, &kw_args);
ID kw_table[3] = { rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
VALUE kw_values[3] = { Qundef, Qundef, Qundef };
rb_get_kwargs(kw_args, kw_table, 0, 3, kw_values);
const char jobz = kw_values[0] != Qundef ? Util().get_jobz(kw_values[0]) : 'V';
const char uplo = kw_values[1] != Qundef ? Util().get_uplo(kw_values[1]) : 'U';
const int matrix_layout = kw_values[2] != Qundef ? Util().get_matrix_layout(kw_values[2]) : LAPACK_ROW_MAJOR;

if (CLASS_OF(a_vnary) != nary_dtype) {
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
}
if (!RTEST(nary_check_contiguous(a_vnary))) {
a_vnary = nary_dup(a_vnary);
}

narray_t* a_nary = nullptr;
GetNArray(a_vnary, a_nary);
if (NA_NDIM(a_nary) != 2) {
rb_raise(rb_eArgError, "input array a must be 2-dimensional");
return Qnil;
}
if (NA_SHAPE(a_nary)[0] != NA_SHAPE(a_nary)[1]) {
rb_raise(rb_eArgError, "input array a must be square");
return Qnil;
}

const size_t n = NA_SHAPE(a_nary)[1];
size_t shape[1] = { n };
ndfunc_arg_in_t ain[1] = { { OVERWRITE, 2 } };
ndfunc_arg_out_t aout[2] = { { nary_dtype, 1, shape }, { numo_cInt32, 0 } };
ndfunc_t ndf = { iter_syevd, NO_LOOP | NDF_EXTRACT, 1, 2, ain, aout };
syevd_opt opt = { matrix_layout, jobz, uplo };
VALUE res = na_ndloop3(&ndf, &opt, 1, a_vnary);
VALUE ret = rb_ary_new3(3, a_vnary, rb_ary_entry(res, 0), rb_ary_entry(res, 1));

RB_GC_GUARD(a_vnary);

return ret;
}
};

} // namespace TinyLinalg
6 changes: 6 additions & 0 deletions ext/numo/tiny_linalg/tiny_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@
#include "lapack/getrf.hpp"
#include "lapack/getri.hpp"
#include "lapack/heev.hpp"
#include "lapack/heevd.hpp"
#include "lapack/hegv.hpp"
#include "lapack/hegvd.hpp"
#include "lapack/hegvx.hpp"
#include "lapack/orgqr.hpp"
#include "lapack/syev.hpp"
#include "lapack/syevd.hpp"
#include "lapack/sygv.hpp"
#include "lapack/sygvd.hpp"
#include "lapack/sygvx.hpp"
Expand Down Expand Up @@ -318,6 +320,10 @@ extern "C" void Init_tiny_linalg(void) {
TinyLinalg::SyEv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEv>::define_module_function(rb_mTinyLinalgLapack, "ssyev");
TinyLinalg::HeEv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEv>::define_module_function(rb_mTinyLinalgLapack, "zheev");
TinyLinalg::HeEv<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEv>::define_module_function(rb_mTinyLinalgLapack, "cheev");
TinyLinalg::SyEvd<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyEvd>::define_module_function(rb_mTinyLinalgLapack, "dsyevd");
TinyLinalg::SyEvd<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyEvd>::define_module_function(rb_mTinyLinalgLapack, "ssyevd");
TinyLinalg::HeEvd<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeEvd>::define_module_function(rb_mTinyLinalgLapack, "zheevd");
TinyLinalg::HeEvd<TinyLinalg::numo_cSComplexId, TinyLinalg::numo_cSFloatId, lapack_complex_float, float, TinyLinalg::CHeEvd>::define_module_function(rb_mTinyLinalgLapack, "cheevd");
TinyLinalg::SyGv<TinyLinalg::numo_cDFloatId, double, TinyLinalg::DSyGv>::define_module_function(rb_mTinyLinalgLapack, "dsygv");
TinyLinalg::SyGv<TinyLinalg::numo_cSFloatId, float, TinyLinalg::SSyGv>::define_module_function(rb_mTinyLinalgLapack, "ssygv");
TinyLinalg::HeGv<TinyLinalg::numo_cDComplexId, TinyLinalg::numo_cDFloatId, lapack_complex_double, double, TinyLinalg::ZHeGv>::define_module_function(rb_mTinyLinalgLapack, "zhegv");
Expand Down
40 changes: 40 additions & 0 deletions test/test_tiny_linalg_lapack.rb
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,46 @@ def test_lapack_cheev
assert(error < 1e-5)
end

def test_lapack_dsyevd
n = 5
a = Numo::DFloat.new(n, n).rand - 0.5
c = 0.5 * (a.transpose + a)
v, w, _info = Numo::TinyLinalg::Lapack.dsyevd(c.dup, jobz: 'V', uplo: 'U')
error = (c - v.dot(w.diag).dot(v.transpose)).abs.max

assert(error < 1e-7)
end

def test_lapack_ssyevd
n = 5
a = Numo::SFloat.new(n, n).rand - 0.5
c = 0.5 * (a.transpose + a)
v, w, _info = Numo::TinyLinalg::Lapack.ssyevd(c.dup, jobz: 'V', uplo: 'U')
error = (c - v.dot(w.diag).dot(v.transpose)).abs.max

assert(error < 1e-5)
end

def test_lapack_zheevd
n = 5
a = Numo::DComplex.new(n, n).rand - 0.5
c = a.transpose.conjugate.dot(a)
v, w, _info = Numo::TinyLinalg::Lapack.zheevd(c.dup, jobz: 'V', uplo: 'U')
error = (c - v.dot(w.diag).dot(v.transpose.conjugate)).abs.max

assert(error < 1e-7)
end

def test_lapack_cheevd
n = 5
a = Numo::SComplex.new(n, n).rand - 0.5
c = a.transpose.conjugate.dot(a)
v, w, _info = Numo::TinyLinalg::Lapack.cheevd(c.dup, jobz: 'V', uplo: 'U')
error = (c - v.dot(w.diag).dot(v.transpose.conjugate)).abs.max

assert(error < 1e-5)
end

def test_lapack_dsygv
n = 5
a = Numo::DFloat.new(n, n).rand - 0.5
Expand Down

0 comments on commit d07a5de

Please sign in to comment.