Skip to content

Commit

Permalink
refactor: cut out common processing into utility class
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Aug 6, 2023
1 parent d265723 commit 439acc5
Show file tree
Hide file tree
Showing 16 changed files with 141 additions and 546 deletions.
52 changes: 3 additions & 49 deletions ext/numo/tiny_linalg/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class Gemm {

dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? get_cblas_trans(kw_values[4]) : CblasNoTrans;
enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor;
enum CBLAS_TRANSPOSE transa = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans;
enum CBLAS_TRANSPOSE transb = kw_values[4] != Qundef ? Util().get_cblas_trans(kw_values[4]) : CblasNoTrans;

narray_t* a_nary = NULL;
GetNArray(a, a_nary);
Expand Down Expand Up @@ -172,52 +172,6 @@ class Gemm {

return ret;
};

static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
const char* option_str = StringValueCStr(val);
enum CBLAS_TRANSPOSE res = CblasNoTrans;

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'n':
case 'N':
res = CblasNoTrans;
break;
case 't':
case 'T':
res = CblasTrans;
break;
case 'c':
case 'C':
res = CblasConjTrans;
break;
}
}

RB_GC_GUARD(val);

return res;
}

static enum CBLAS_ORDER get_cblas_order(VALUE val) {
const char* option_str = StringValueCStr(val);

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'r':
case 'R':
break;
case 'c':
case 'C':
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
break;
}
}

RB_GC_GUARD(val);

return CblasRowMajor;
}
};

} // namespace TinyLinalg
50 changes: 2 additions & 48 deletions ext/numo/tiny_linalg/blas/gemv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class Gemv {

dtype alpha = kw_values[0] != Qundef ? Converter().to_dtype(kw_values[0]) : Converter().one();
dtype beta = kw_values[1] != Qundef ? Converter().to_dtype(kw_values[1]) : Converter().zero();
enum CBLAS_ORDER order = kw_values[2] != Qundef ? get_cblas_order(kw_values[2]) : CblasRowMajor;
enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? get_cblas_trans(kw_values[3]) : CblasNoTrans;
enum CBLAS_ORDER order = kw_values[2] != Qundef ? Util().get_cblas_order(kw_values[2]) : CblasRowMajor;
enum CBLAS_TRANSPOSE trans = kw_values[3] != Qundef ? Util().get_cblas_trans(kw_values[3]) : CblasNoTrans;

narray_t* a_nary = NULL;
GetNArray(a, a_nary);
Expand Down Expand Up @@ -160,52 +160,6 @@ class Gemv {

return ret;
}

static enum CBLAS_TRANSPOSE get_cblas_trans(VALUE val) {
const char* option_str = StringValueCStr(val);
enum CBLAS_TRANSPOSE res = CblasNoTrans;

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'n':
case 'N':
res = CblasNoTrans;
break;
case 't':
case 'T':
res = CblasTrans;
break;
case 'c':
case 'C':
res = CblasConjTrans;
break;
}
}

RB_GC_GUARD(val);

return res;
}

static enum CBLAS_ORDER get_cblas_order(VALUE val) {
const char* option_str = StringValueCStr(val);

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'r':
case 'R':
break;
case 'c':
case 'C':
rb_warn("Numo::TinyLinalg::BLAS.gemm does not support column major.");
break;
}
}

RB_GC_GUARD(val);

return CblasRowMajor;
}
};

} // namespace TinyLinalg
22 changes: 1 addition & 21 deletions ext/numo/tiny_linalg/lapack/geqrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class GeQrf {
ID kw_table[1] = { rb_intern("order") };
VALUE kw_values[1] = { Qundef };
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;

if (CLASS_OF(a_vnary) != nary_dtype) {
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
Expand Down Expand Up @@ -93,26 +93,6 @@ class GeQrf {

return ret;
}

static int get_matrix_layout(VALUE val) {
const char* option_str = StringValueCStr(val);

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'r':
case 'R':
break;
case 'c':
case 'C':
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
break;
}
}

RB_GC_GUARD(val);

return LAPACK_ROW_MAJOR;
}
};

} // namespace TinyLinalg
22 changes: 1 addition & 21 deletions ext/numo/tiny_linalg/lapack/gesv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GeSv {

rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);

const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;

if (CLASS_OF(a_vnary) != nary_dtype) {
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
Expand Down Expand Up @@ -123,26 +123,6 @@ class GeSv {

return ret;
}

static int get_matrix_layout(VALUE val) {
const char* option_str = StringValueCStr(val);

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'r':
case 'R':
break;
case 'c':
case 'C':
rb_warn("Numo::TinyLinalg::Lapack.gesv does not support column major.");
break;
}
}

RB_GC_GUARD(val);

return LAPACK_ROW_MAJOR;
}
};

} // namespace TinyLinalg
22 changes: 1 addition & 21 deletions ext/numo/tiny_linalg/lapack/getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class GeTrf {
ID kw_table[1] = { rb_intern("order") };
VALUE kw_values[1] = { Qundef };
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;

if (CLASS_OF(a_vnary) != nary_dtype) {
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
Expand Down Expand Up @@ -93,26 +93,6 @@ class GeTrf {

return ret;
}

static int get_matrix_layout(VALUE val) {
const char* option_str = StringValueCStr(val);

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'r':
case 'R':
break;
case 'c':
case 'C':
rb_warn("Numo::TinyLinalg::Lapack.getrf does not support column major.");
break;
}
}

RB_GC_GUARD(val);

return LAPACK_ROW_MAJOR;
}
};

} // namespace TinyLinalg
22 changes: 1 addition & 21 deletions ext/numo/tiny_linalg/lapack/getri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class GeTri {
ID kw_table[1] = { rb_intern("order") };
VALUE kw_values[1] = { Qundef };
rb_get_kwargs(kw_args, kw_table, 0, 1, kw_values);
const int matrix_layout = kw_values[0] != Qundef ? get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;
const int matrix_layout = kw_values[0] != Qundef ? Util().get_matrix_layout(kw_values[0]) : LAPACK_ROW_MAJOR;

if (CLASS_OF(a_vnary) != nary_dtype) {
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
Expand Down Expand Up @@ -102,26 +102,6 @@ class GeTri {

return ret;
}

static int get_matrix_layout(VALUE val) {
const char* option_str = StringValueCStr(val);

if (std::strlen(option_str) > 0) {
switch (option_str[0]) {
case 'r':
case 'R':
break;
case 'c':
case 'C':
rb_warn("Numo::TinyLinalg::Lapack.getri does not support column major.");
break;
}
}

RB_GC_GUARD(val);

return LAPACK_ROW_MAJOR;
}
};

} // namespace TinyLinalg
54 changes: 4 additions & 50 deletions ext/numo/tiny_linalg/lapack/hegv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ class HeGv {
ID kw_table[4] = { rb_intern("itype"), rb_intern("jobz"), rb_intern("uplo"), rb_intern("order") };
VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
rb_get_kwargs(kw_args, kw_table, 0, 4, kw_values);
const lapack_int itype = kw_values[0] != Qundef ? get_itype(kw_values[0]) : 1;
const char jobz = kw_values[1] != Qundef ? get_jobz(kw_values[1]) : 'V';
const char uplo = kw_values[2] != Qundef ? get_uplo(kw_values[2]) : 'U';
const int matrix_layout = kw_values[3] != Qundef ? get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;
const lapack_int itype = kw_values[0] != Qundef ? Util().get_itype(kw_values[0]) : 1;
const char jobz = kw_values[1] != Qundef ? Util().get_jobz(kw_values[1]) : 'V';
const char uplo = kw_values[2] != Qundef ? Util().get_uplo(kw_values[2]) : 'U';
const int matrix_layout = kw_values[3] != Qundef ? Util().get_matrix_layout(kw_values[3]) : LAPACK_ROW_MAJOR;

if (CLASS_OF(a_vnary) != nary_dtype) {
a_vnary = rb_funcall(nary_dtype, rb_intern("cast"), 1, a_vnary);
Expand Down Expand Up @@ -116,52 +116,6 @@ class HeGv {

return ret;
}

static lapack_int get_itype(VALUE val) {
const lapack_int itype = NUM2INT(val);

if (itype != 1 && itype != 2 && itype != 3) {
rb_raise(rb_eArgError, "itype must be 1, 2 or 3");
}

return itype;
}

static char get_jobz(VALUE val) {
const char jobz = NUM2CHR(val);

if (jobz != 'n' && jobz != 'N' && jobz != 'v' && jobz != 'V') {
rb_raise(rb_eArgError, "jobz must be 'N' or 'V'");
}

return jobz;
}

static char get_uplo(VALUE val) {
const char uplo = NUM2CHR(val);

if (uplo != 'u' && uplo != 'U' && uplo != 'l' && uplo != 'L') {
rb_raise(rb_eArgError, "uplo must be 'U' or 'L'");
}

return uplo;
}

static int get_matrix_layout(VALUE val) {
const char option = NUM2CHR(val);

switch (option) {
case 'r':
case 'R':
break;
case 'c':
case 'C':
rb_warn("Numo::TinyLinalg::Lapack.sygv does not support column major.");
break;
}

return LAPACK_ROW_MAJOR;
}
};

} // namespace TinyLinalg
Loading

0 comments on commit 439acc5

Please sign in to comment.