diff --git a/ext/numo/tiny_linalg/blas/gemm.hpp b/ext/numo/tiny_linalg/blas/gemm.hpp index 550bb79..11462e9 100644 --- a/ext/numo/tiny_linalg/blas/gemm.hpp +++ b/ext/numo/tiny_linalg/blas/gemm.hpp @@ -102,9 +102,9 @@ class Gemm { dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one(); dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero(); - enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor; - enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans; - enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans; + enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor; + enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans; + enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? Util().get_cblas_trans(kw_values[4]) : CblasNoTrans; narray_t* a_nary = NULL; GetNArray(a, a_nary); @@ -172,52 +172,6 @@ class Gemm { return ret; }; - - static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) { - const char* option_str = StringValueCStr(val); - enum CBLAS_TRANSPOSE res = CblasNoTrans; - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'n': - case 'N': - res = CblasNoTrans; - break; - case 't': - case 'T': - res = CblasTrans; - break; - case 'c': - case 'C': - res = CblasConjTrans; - break; - } - } - - RB_GC_GUARD(val); - - return res; - } - - static enum CBLAS_ORDER get_cblas_order(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return CblasRowMajor; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/blas/gemv.hpp b/ext/numo/tiny_linalg/blas/gemv.hpp index 70f9353..db547a8 100644 --- a/ext/numo/tiny_linalg/blas/gemv.hpp +++ b/ext/numo/tiny_linalg/blas/gemv.hpp @@ -94,8 +94,8 @@ class Gemv { dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one(); dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero(); - enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor; - enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans; + enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor; + enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans; narray_t* a_nary = NULL; GetNArray(a, a_nary); @@ -160,52 +160,6 @@ class Gemv { return ret; } - - static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) { - const char* option_str = StringValueCStr(val); - enum CBLAS_TRANSPOSE res = CblasNoTrans; - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'n': - case 'N': - res = CblasNoTrans; - break; - case 't': - case 'T': - res = CblasTrans; - break; - case 'c': - case 'C': - res = CblasConjTrans; - break; - } - } - - RB_GC_GUARD(val); - - return res; - } - - static enum CBLAS_ORDER get_cblas_order(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return CblasRowMajor; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/geqrf.hpp b/ext/numo/tiny_linalg/lapack/geqrf.hpp index ec76c42..c990aa9 100644 --- a/ext/numo/tiny_linalg/lapack/geqrf.hpp +++ b/ext/numo/tiny_linalg/lapack/geqrf.hpp @@ -61,7 +61,7 @@ class GeQrf { ID kw_table[1] = { rb_intern("order") }; VALUE kw_values[1] = { Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values); - const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -93,26 +93,6 @@ class GeQrf { return ret; } - - static int get_matrix_layout(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/gesv.hpp b/ext/numo/tiny_linalg/lapack/gesv.hpp index 95e9e75..4b0deec 100644 --- a/ext/numo/tiny_linalg/lapack/gesv.hpp +++ b/ext/numo/tiny_linalg/lapack/gesv.hpp @@ -71,7 +71,7 @@ class GeSv { rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values); - const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -123,26 +123,6 @@ class GeSv { return ret; } - - static int get_matrix_layout(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.gesv does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/getrf.hpp b/ext/numo/tiny_linalg/lapack/getrf.hpp index cc25780..088218a 100644 --- a/ext/numo/tiny_linalg/lapack/getrf.hpp +++ b/ext/numo/tiny_linalg/lapack/getrf.hpp @@ -61,7 +61,7 @@ class GeTrf { ID kw_table[1] = { rb_intern("order") }; VALUE kw_values[1] = { Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values); - const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -93,26 +93,6 @@ class GeTrf { return ret; } - - static int get_matrix_layout(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/getri.hpp b/ext/numo/tiny_linalg/lapack/getri.hpp index dc56357..491213a 100644 --- a/ext/numo/tiny_linalg/lapack/getri.hpp +++ b/ext/numo/tiny_linalg/lapack/getri.hpp @@ -57,7 +57,7 @@ class GeTri { ID kw_table[1] = { rb_intern("order") }; VALUE kw_values[1] = { Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values); - const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -102,26 +102,6 @@ class GeTri { return ret; } - - static int get_matrix_layout(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.getri does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/hegv.hpp b/ext/numo/tiny_linalg/lapack/hegv.hpp index 9049da0..b9a43ae 100644 --- a/ext/numo/tiny_linalg/lapack/hegv.hpp +++ b/ext/numo/tiny_linalg/lapack/hegv.hpp @@ -63,10 +63,10 @@ class HeGv { ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") }; VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values); - const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1; - const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V'; - const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U'; - const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; + const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1; + const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V'; + const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U'; + const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -116,52 +116,6 @@ class HeGv { return ret; } - - static lapack_int get_itype(VALUE val) { - const lapack_int itype = NUM2INT(val); - - if (itype != 1 && itype != 2 && itype != 3) { - rb_raise(rb_eArgError, "itype must be 1, 2 or 3"); - } - - return itype; - } - - static char get_jobz(VALUE val) { - const char jobz = NUM2CHR(val); - - if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') { - rb_raise(rb_eArgError, "jobz must be 'N' or 'V'"); - } - - return jobz; - } - - static char get_uplo(VALUE val) { - const char uplo = NUM2CHR(val); - - if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') { - rb_raise(rb_eArgError, "uplo must be 'U' or 'L'"); - } - - return uplo; - } - - static int get_matrix_layout(VALUE val) { - const char option = NUM2CHR(val); - - switch (option) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major."); - break; - } - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/hegvd.hpp b/ext/numo/tiny_linalg/lapack/hegvd.hpp index 8b0cf94..f7674da 100644 --- a/ext/numo/tiny_linalg/lapack/hegvd.hpp +++ b/ext/numo/tiny_linalg/lapack/hegvd.hpp @@ -63,10 +63,10 @@ class HeGvd { ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") }; VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values); - const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1; - const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V'; - const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U'; - const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; + const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1; + const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V'; + const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U'; + const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -116,52 +116,6 @@ class HeGvd { return ret; } - - static lapack_int get_itype(VALUE val) { - const lapack_int itype = NUM2INT(val); - - if (itype != 1 && itype != 2 && itype != 3) { - rb_raise(rb_eArgError, "itype must be 1, 2 or 3"); - } - - return itype; - } - - static char get_jobz(VALUE val) { - const char jobz = NUM2CHR(val); - - if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') { - rb_raise(rb_eArgError, "jobz must be 'N' or 'V'"); - } - - return jobz; - } - - static char get_uplo(VALUE val) { - const char uplo = NUM2CHR(val); - - if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') { - rb_raise(rb_eArgError, "uplo must be 'U' or 'L'"); - } - - return uplo; - } - - static int get_matrix_layout(VALUE val) { - const char option = NUM2CHR(val); - - switch (option) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major."); - break; - } - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/hegvx.hpp b/ext/numo/tiny_linalg/lapack/hegvx.hpp index de5162e..6b3462d 100644 --- a/ext/numo/tiny_linalg/lapack/hegvx.hpp +++ b/ext/numo/tiny_linalg/lapack/hegvx.hpp @@ -70,15 +70,15 @@ class HeGvx { rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") }; VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values); - const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1; - const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V'; - const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A'; - const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U'; + const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1; + const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V'; + const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A'; + const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U'; const rtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0; const rtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0; const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0; const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0; - const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -132,62 +132,6 @@ class HeGvx { return ret; } - - static lapack_int get_itype(VALUE val) { - const lapack_int itype = NUM2INT(val); - - if (itype != 1 && itype != 2 && itype != 3) { - rb_raise(rb_eArgError, "itype must be 1, 2 or 3"); - } - - return itype; - } - - static char get_jobz(VALUE val) { - const char jobz = NUM2CHR(val); - - if (jobz != 'N' && jobz != 'V') { - rb_raise(rb_eArgError, "jobz must be 'N' or 'V'"); - } - - return jobz; - } - - static char get_range(VALUE val) { - const char range = NUM2CHR(val); - - if (range != 'A' && range != 'V' && range != 'I') { - rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'"); - } - - return range; - } - - static char get_uplo(VALUE val) { - const char uplo = NUM2CHR(val); - - if (uplo != 'U' && uplo != 'L') { - rb_raise(rb_eArgError, "uplo must be 'U' or 'L'"); - } - - return uplo; - } - - static int get_matrix_layout(VALUE val) { - const char option = NUM2CHR(val); - - switch (option) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.hegvx does not support column major."); - break; - } - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/orgqr.hpp b/ext/numo/tiny_linalg/lapack/orgqr.hpp index 00e69ba..76ca7ce 100644 --- a/ext/numo/tiny_linalg/lapack/orgqr.hpp +++ b/ext/numo/tiny_linalg/lapack/orgqr.hpp @@ -49,7 +49,7 @@ class OrgQr { ID kw_table[1] = { rb_intern("order") }; VALUE kw_values[1] = { Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values); - const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -90,26 +90,6 @@ class OrgQr { return ret; } - - static int get_matrix_layout(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/sygv.hpp b/ext/numo/tiny_linalg/lapack/sygv.hpp index fa60c7f..c829b74 100644 --- a/ext/numo/tiny_linalg/lapack/sygv.hpp +++ b/ext/numo/tiny_linalg/lapack/sygv.hpp @@ -54,10 +54,10 @@ class SyGv { ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") }; VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values); - const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1; - const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V'; - const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U'; - const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; + const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1; + const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V'; + const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U'; + const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -107,52 +107,6 @@ class SyGv { return ret; } - - static lapack_int get_itype(VALUE val) { - const lapack_int itype = NUM2INT(val); - - if (itype != 1 && itype != 2 && itype != 3) { - rb_raise(rb_eArgError, "itype must be 1, 2 or 3"); - } - - return itype; - } - - static char get_jobz(VALUE val) { - const char jobz = NUM2CHR(val); - - if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') { - rb_raise(rb_eArgError, "jobz must be 'N' or 'V'"); - } - - return jobz; - } - - static char get_uplo(VALUE val) { - const char uplo = NUM2CHR(val); - - if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') { - rb_raise(rb_eArgError, "uplo must be 'U' or 'L'"); - } - - return uplo; - } - - static int get_matrix_layout(VALUE val) { - const char option = NUM2CHR(val); - - switch (option) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major."); - break; - } - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/sygvd.hpp b/ext/numo/tiny_linalg/lapack/sygvd.hpp index bc32d24..002d7b7 100644 --- a/ext/numo/tiny_linalg/lapack/sygvd.hpp +++ b/ext/numo/tiny_linalg/lapack/sygvd.hpp @@ -54,10 +54,10 @@ class SyGvd { ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") }; VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values); - const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1; - const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V'; - const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U'; - const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; + const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1; + const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V'; + const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U'; + const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -107,52 +107,6 @@ class SyGvd { return ret; } - - static lapack_int get_itype(VALUE val) { - const lapack_int itype = NUM2INT(val); - - if (itype != 1 && itype != 2 && itype != 3) { - rb_raise(rb_eArgError, "itype must be 1, 2 or 3"); - } - - return itype; - } - - static char get_jobz(VALUE val) { - const char jobz = NUM2CHR(val); - - if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') { - rb_raise(rb_eArgError, "jobz must be 'N' or 'V'"); - } - - return jobz; - } - - static char get_uplo(VALUE val) { - const char uplo = NUM2CHR(val); - - if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') { - rb_raise(rb_eArgError, "uplo must be 'U' or 'L'"); - } - - return uplo; - } - - static int get_matrix_layout(VALUE val) { - const char option = NUM2CHR(val); - - switch (option) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.sygvd does not support column major."); - break; - } - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/sygvx.hpp b/ext/numo/tiny_linalg/lapack/sygvx.hpp index b9ef7aa..cbfee51 100644 --- a/ext/numo/tiny_linalg/lapack/sygvx.hpp +++ b/ext/numo/tiny_linalg/lapack/sygvx.hpp @@ -69,15 +69,15 @@ class SyGvx { rb_intern("vl"), rb_intern("vu"), rb_intern("il"), rb_intern("iu"), rb_intern("order") }; VALUE kw_values[9] = { Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef, Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 9, kw_values); - const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1; - const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V'; - const char range = kw_values[2] != Qundef ? get_range(kw_values[2]) : 'A'; - const char uplo = kw_values[3] != Qundef ? get_uplo(kw_values[3]) : 'U'; + const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1; + const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V'; + const char range = kw_values[2] != Qundef ? Util().get_range(kw_values[2]) : 'A'; + const char uplo = kw_values[3] != Qundef ? Util().get_uplo(kw_values[3]) : 'U'; const dtype vl = kw_values[4] != Qundef ? NUM2DBL(kw_values[4]) : 0.0; const dtype vu = kw_values[5] != Qundef ? NUM2DBL(kw_values[5]) : 0.0; const lapack_int il = kw_values[6] != Qundef ? NUM2INT(kw_values[6]) : 0; const lapack_int iu = kw_values[7] != Qundef ? NUM2INT(kw_values[7]) : 0; - const int matrix_layout = kw_values[8] != Qundef ? get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[8] != Qundef ? Util().get_matrix_layout(kw_values[8]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -131,62 +131,6 @@ class SyGvx { return ret; } - - static lapack_int get_itype(VALUE val) { - const lapack_int itype = NUM2INT(val); - - if (itype != 1 && itype != 2 && itype != 3) { - rb_raise(rb_eArgError, "itype must be 1, 2 or 3"); - } - - return itype; - } - - static char get_jobz(VALUE val) { - const char jobz = NUM2CHR(val); - - if (jobz != 'N' && jobz != 'V') { - rb_raise(rb_eArgError, "jobz must be 'N' or 'V'"); - } - - return jobz; - } - - static char get_range(VALUE val) { - const char range = NUM2CHR(val); - - if (range != 'A' && range != 'V' && range != 'I') { - rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'"); - } - - return range; - } - - static char get_uplo(VALUE val) { - const char uplo = NUM2CHR(val); - - if (uplo != 'U' && uplo != 'L') { - rb_raise(rb_eArgError, "uplo must be 'U' or 'L'"); - } - - return uplo; - } - - static int get_matrix_layout(VALUE val) { - const char option = NUM2CHR(val); - - switch (option) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.sygvx does not support column major."); - break; - } - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/lapack/ungqr.hpp b/ext/numo/tiny_linalg/lapack/ungqr.hpp index 19e815e..7d1198e 100644 --- a/ext/numo/tiny_linalg/lapack/ungqr.hpp +++ b/ext/numo/tiny_linalg/lapack/ungqr.hpp @@ -49,7 +49,7 @@ class UngQr { ID kw_table[1] = { rb_intern("order") }; VALUE kw_values[1] = { Qundef }; rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values); - const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; + const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR; if (CLASS_OF(a_vnary) != nary_dtype) { a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary); @@ -90,26 +90,6 @@ class UngQr { return ret; } - - static int get_matrix_layout(VALUE val) { - const char* option_str = StringValueCStr(val); - - if (std::strlen(option_str) > 0) { - switch (option_str[0]) { - case 'r': - case 'R': - break; - case 'c': - case 'C': - rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major."); - break; - } - } - - RB_GC_GUARD(val); - - return LAPACK_ROW_MAJOR; - } }; } // namespace TinyLinalg diff --git a/ext/numo/tiny_linalg/tiny_linalg.cpp b/ext/numo/tiny_linalg/tiny_linalg.cpp index 31a7af4..0b95f88 100644 --- a/ext/numo/tiny_linalg/tiny_linalg.cpp +++ b/ext/numo/tiny_linalg/tiny_linalg.cpp @@ -1,10 +1,13 @@ #include "tiny_linalg.hpp" + +#include "converter.hpp" +#include "util.hpp" + #include "blas/dot.hpp" #include "blas/dot_sub.hpp" #include "blas/gemm.hpp" #include "blas/gemv.hpp" #include "blas/nrm2.hpp" -#include "converter.hpp" #include "lapack/geqrf.hpp" #include "lapack/gesdd.hpp" #include "lapack/gesv.hpp" diff --git a/ext/numo/tiny_linalg/util.hpp b/ext/numo/tiny_linalg/util.hpp new file mode 100644 index 0000000..3721034 --- /dev/null +++ b/ext/numo/tiny_linalg/util.hpp @@ -0,0 +1,100 @@ +namespace TinyLinalg { + +class Util { +public: + static lapack_int get_itype(VALUE val) { + const lapack_int itype = NUM2INT(val); + + if (itype != 1 && itype != 2 && itype != 3) { + rb_raise(rb_eArgError, "itype must be 1, 2 or 3"); + } + + return itype; + } + + static char get_jobz(VALUE val) { + const char jobz = NUM2CHR(val); + + if (jobz != 'N' && jobz != 'V') { + rb_raise(rb_eArgError, "jobz must be 'N' or 'V'"); + } + + return jobz; + } + + static char get_range(VALUE val) { + const char range = NUM2CHR(val); + + if (range != 'A' && range != 'V' && range != 'I') { + rb_raise(rb_eArgError, "range must be 'A', 'V' or 'I'"); + } + + return range; + } + + static char get_uplo(VALUE val) { + const char uplo = NUM2CHR(val); + + if (uplo != 'U' && uplo != 'L') { + rb_raise(rb_eArgError, "uplo must be 'U' or 'L'"); + } + + return uplo; + } + + static int get_matrix_layout(VALUE val) { + const char option = NUM2CHR(val); + + switch (option) { + case 'r': + case 'R': + break; + case 'c': + case 'C': + rb_warn("Numo::TinyLinalg does not support column major."); + break; + } + + return LAPACK_ROW_MAJOR; + } + + static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) { + const char option = NUM2CHR(val); + enum CBLAS_TRANSPOSE res = CblasNoTrans; + + switch (option) { + case 'n': + case 'N': + res = CblasNoTrans; + break; + case 't': + case 'T': + res = CblasTrans; + break; + case 'c': + case 'C': + res = CblasConjTrans; + break; + } + + return res; + } + + static enum CBLAS_ORDER get_cblas_order(VALUE val) { + const char option = NUM2CHR(val); + + switch (option) { + case 'r': + case 'R': + break; + case 'c': + case 'C': + rb_warn("Numo::TinyLinalg does not support column major."); + break; + } + + return CblasRowMajor; + } +}; + +} // namespace TinyLinalg