Skip to content

Commit

Permalink
refactor: unify notation rules
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Aug 6, 2023
1 parent aca3fb5 commit 4b3ef05
Show file tree
Hide file tree
Showing 15 changed files with 122 additions and 122 deletions.
8 changes: 4 additions & 4 deletions ext/numo/tiny_linalg/lapack/geqrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct CGeQrf {
}
};

template <int nary_dtype_id, typename DType, typename FncType>
template <int nary_dtype_id, typename dtype, class LapackFn>
class GeQrf {
public:
static void define_module_function(VALUE mLapack, const char* fnc_name) {
Expand All @@ -41,14 +41,14 @@ class GeQrf {
};

static void iter_geqrf(na_loop_t* const lp) {
DType* a = (DType*)NDL_PTR(lp, 0);
DType* tau = (DType*)NDL_PTR(lp, 1);
dtype* a = (dtype*)NDL_PTR(lp, 0);
dtype* tau = (dtype*)NDL_PTR(lp, 1);
int* info = (int*)NDL_PTR(lp, 2);
geqrf_opt* opt = (geqrf_opt*)(lp->opt_ptr);
const lapack_int m = NDL_SHAPE(lp, 0)[0];
const lapack_int n = NDL_SHAPE(lp, 0)[1];
const lapack_int lda = n;
const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, tau);
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, a, lda, tau);
*info = static_cast<int>(i);
}

Expand Down
22 changes: 11 additions & 11 deletions ext/numo/tiny_linalg/lapack/gesdd.hpp
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
namespace TinyLinalg {

struct DGESDD {
struct DGeSdd {
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt) {
return LAPACKE_dgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
};
};

struct SGESDD {
struct SGeSdd {
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt) {
return LAPACKE_sgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
};
};

struct ZGESDD {
struct ZGeSdd {
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt) {
return LAPACKE_zgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
};
};

struct CGESDD {
struct CGeSdd {
lapack_int call(int matrix_order, char jobz, lapack_int m, lapack_int n,
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt) {
return LAPACKE_cgesdd(matrix_order, jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
};
};

template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
class GESDD {
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
class GeSdd {
public:
static void define_module_function(VALUE mLapack, const char* mf_name) {
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesdd), -1);
Expand All @@ -42,10 +42,10 @@ class GESDD {
};

static void iter_gesdd(na_loop_t* const lp) {
DType* a = (DType*)NDL_PTR(lp, 0);
RType* s = (RType*)NDL_PTR(lp, 1);
DType* u = (DType*)NDL_PTR(lp, 2);
DType* vt = (DType*)NDL_PTR(lp, 3);
dtype* a = (dtype*)NDL_PTR(lp, 0);
rtype* s = (rtype*)NDL_PTR(lp, 1);
dtype* u = (dtype*)NDL_PTR(lp, 2);
dtype* vt = (dtype*)NDL_PTR(lp, 3);
int* info = (int*)NDL_PTR(lp, 4);
gesdd_opt* opt = (gesdd_opt*)(lp->opt_ptr);

Expand All @@ -56,7 +56,7 @@ class GESDD {
const lapack_int ldu = opt->jobz == 'S' ? min_mn : m;
const lapack_int ldvt = opt->jobz == 'S' ? min_mn : n;

lapack_int i = FncType().call(opt->matrix_order, opt->jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
lapack_int i = LapackFn().call(opt->matrix_order, opt->jobz, m, n, a, lda, s, u, ldu, vt, ldvt);
*info = static_cast<int>(i);
};

Expand Down
18 changes: 9 additions & 9 deletions ext/numo/tiny_linalg/lapack/gesv.hpp
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
namespace TinyLinalg {

struct DGESV {
struct DGeSv {
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
double* a, lapack_int lda, lapack_int* ipiv,
double* b, lapack_int ldb) {
return LAPACKE_dgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
}
};

struct SGESV {
struct SGeSv {
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
float* a, lapack_int lda, lapack_int* ipiv,
float* b, lapack_int ldb) {
return LAPACKE_sgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
}
};

struct ZGESV {
struct ZGeSv {
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv,
lapack_complex_double* b, lapack_int ldb) {
return LAPACKE_zgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
}
};

struct CGESV {
struct CGeSv {
lapack_int call(int matrix_layout, lapack_int n, lapack_int nrhs,
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv,
lapack_complex_float* b, lapack_int ldb) {
return LAPACKE_cgesv(matrix_layout, n, nrhs, a, lda, ipiv, b, ldb);
}
};

template <int nary_dtype_id, typename DType, typename FncType>
class GESV {
template <int nary_dtype_id, typename dtype, class LapackFn>
class GeSv {
public:
static void define_module_function(VALUE mLapack, const char* fnc_name) {
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_gesv), -1);
Expand All @@ -45,16 +45,16 @@ class GESV {
};

static void iter_gesv(na_loop_t* const lp) {
DType* a = (DType*)NDL_PTR(lp, 0);
DType* b = (DType*)NDL_PTR(lp, 1);
dtype* a = (dtype*)NDL_PTR(lp, 0);
dtype* b = (dtype*)NDL_PTR(lp, 1);
int* ipiv = (int*)NDL_PTR(lp, 2);
int* info = (int*)NDL_PTR(lp, 3);
gesv_opt* opt = (gesv_opt*)(lp->opt_ptr);
const lapack_int n = NDL_SHAPE(lp, 0)[0];
const lapack_int nhrs = lp->args[1].ndim == 1 ? 1 : NDL_SHAPE(lp, 1)[1];
const lapack_int lda = n;
const lapack_int ldb = nhrs;
const lapack_int i = FncType().call(opt->matrix_layout, n, nhrs, a, lda, ipiv, b, ldb);
const lapack_int i = LapackFn().call(opt->matrix_layout, n, nhrs, a, lda, ipiv, b, ldb);
*info = static_cast<int>(i);
}

Expand Down
24 changes: 12 additions & 12 deletions ext/numo/tiny_linalg/lapack/gesvd.hpp
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
namespace TinyLinalg {

struct DGESVD {
struct DGeSvd {
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
double* a, lapack_int lda, double* s, double* u, lapack_int ldu, double* vt, lapack_int ldvt,
double* superb) {
return LAPACKE_dgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
};
};

struct SGESVD {
struct SGeSvd {
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
float* a, lapack_int lda, float* s, float* u, lapack_int ldu, float* vt, lapack_int ldvt,
float* superb) {
return LAPACKE_sgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
};
};

struct ZGESVD {
struct ZGeSvd {
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
lapack_complex_double* a, lapack_int lda, double* s, lapack_complex_double* u, lapack_int ldu, lapack_complex_double* vt, lapack_int ldvt,
double* superb) {
return LAPACKE_zgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
};
};

struct CGESVD {
struct CGeSvd {
lapack_int call(int matrix_order, char jobu, char jobvt, lapack_int m, lapack_int n,
lapack_complex_float* a, lapack_int lda, float* s, lapack_complex_float* u, lapack_int ldu, lapack_complex_float* vt, lapack_int ldvt,
float* superb) {
return LAPACKE_cgesvd(matrix_order, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
};
};

template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
class GESVD {
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
class GeSvd {
public:
static void define_module_function(VALUE mLapack, const char* mf_name) {
rb_define_module_function(mLapack, mf_name, RUBY_METHOD_FUNC(tiny_linalg_gesvd), -1);
Expand All @@ -47,10 +47,10 @@ class GESVD {
};

static void iter_gesvd(na_loop_t* const lp) {
DType* a = (DType*)NDL_PTR(lp, 0);
RType* s = (RType*)NDL_PTR(lp, 1);
DType* u = (DType*)NDL_PTR(lp, 2);
DType* vt = (DType*)NDL_PTR(lp, 3);
dtype* a = (dtype*)NDL_PTR(lp, 0);
rtype* s = (rtype*)NDL_PTR(lp, 1);
dtype* u = (dtype*)NDL_PTR(lp, 2);
dtype* vt = (dtype*)NDL_PTR(lp, 3);
int* info = (int*)NDL_PTR(lp, 4);
gesvd_opt* opt = (gesvd_opt*)(lp->opt_ptr);

Expand All @@ -61,9 +61,9 @@ class GESVD {
const lapack_int ldu = opt->jobu == 'A' ? m : min_mn;
const lapack_int ldvt = n;

RType* superb = (RType*)ruby_xmalloc(min_mn * sizeof(RType));
rtype* superb = (rtype*)ruby_xmalloc(min_mn * sizeof(rtype));

lapack_int i = FncType().call(opt->matrix_order, opt->jobu, opt->jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
lapack_int i = LapackFn().call(opt->matrix_order, opt->jobu, opt->jobvt, m, n, a, lda, s, u, ldu, vt, ldvt, superb);
*info = static_cast<int>(i);

ruby_xfree(superb);
Expand Down
16 changes: 8 additions & 8 deletions ext/numo/tiny_linalg/lapack/getrf.hpp
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
namespace TinyLinalg {

struct DGETRF {
struct DGeTrf {
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
double* a, lapack_int lda, lapack_int* ipiv) {
return LAPACKE_dgetrf(matrix_layout, m, n, a, lda, ipiv);
}
};

struct SGETRF {
struct SGeTrf {
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
float* a, lapack_int lda, lapack_int* ipiv) {
return LAPACKE_sgetrf(matrix_layout, m, n, a, lda, ipiv);
}
};

struct ZGETRF {
struct ZGeTrf {
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
lapack_complex_double* a, lapack_int lda, lapack_int* ipiv) {
return LAPACKE_zgetrf(matrix_layout, m, n, a, lda, ipiv);
}
};

struct CGETRF {
struct CGeTrf {
lapack_int call(int matrix_layout, lapack_int m, lapack_int n,
lapack_complex_float* a, lapack_int lda, lapack_int* ipiv) {
return LAPACKE_cgetrf(matrix_layout, m, n, a, lda, ipiv);
}
};

template <int nary_dtype_id, typename DType, typename FncType>
class GETRF {
template <int nary_dtype_id, typename dtype, class LapackFn>
class GeTrf {
public:
static void define_module_function(VALUE mLapack, const char* fnc_name) {
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getrf), -1);
Expand All @@ -41,14 +41,14 @@ class GETRF {
};

static void iter_getrf(na_loop_t* const lp) {
DType* a = (DType*)NDL_PTR(lp, 0);
dtype* a = (dtype*)NDL_PTR(lp, 0);
int* ipiv = (int*)NDL_PTR(lp, 1);
int* info = (int*)NDL_PTR(lp, 2);
getrf_opt* opt = (getrf_opt*)(lp->opt_ptr);
const lapack_int m = NDL_SHAPE(lp, 0)[0];
const lapack_int n = NDL_SHAPE(lp, 0)[1];
const lapack_int lda = n;
const lapack_int i = FncType().call(opt->matrix_layout, m, n, a, lda, ipiv);
const lapack_int i = LapackFn().call(opt->matrix_layout, m, n, a, lda, ipiv);
*info = static_cast<int>(i);
}

Expand Down
16 changes: 8 additions & 8 deletions ext/numo/tiny_linalg/lapack/getri.hpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
namespace TinyLinalg {

struct DGETRI {
struct DGeTri {
lapack_int call(int matrix_layout, lapack_int n, double* a, lapack_int lda, const lapack_int* ipiv) {
return LAPACKE_dgetri(matrix_layout, n, a, lda, ipiv);
}
};

struct SGETRI {
struct SGeTri {
lapack_int call(int matrix_layout, lapack_int n, float* a, lapack_int lda, const lapack_int* ipiv) {
return LAPACKE_sgetri(matrix_layout, n, a, lda, ipiv);
}
};

struct ZGETRI {
struct ZGeTri {
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_double* a, lapack_int lda, const lapack_int* ipiv) {
return LAPACKE_zgetri(matrix_layout, n, a, lda, ipiv);
}
};

struct CGETRI {
struct CGeTri {
lapack_int call(int matrix_layout, lapack_int n, lapack_complex_float* a, lapack_int lda, const lapack_int* ipiv) {
return LAPACKE_cgetri(matrix_layout, n, a, lda, ipiv);
}
};

template <int nary_dtype_id, typename DType, typename FncType>
class GETRI {
template <int nary_dtype_id, typename dtype, class LapackFn>
class GeTri {
public:
static void define_module_function(VALUE mLapack, const char* fnc_name) {
rb_define_module_function(mLapack, fnc_name, RUBY_METHOD_FUNC(tiny_linalg_getri), -1);
Expand All @@ -37,13 +37,13 @@ class GETRI {
};

static void iter_getri(na_loop_t* const lp) {
DType* a = (DType*)NDL_PTR(lp, 0);
dtype* a = (dtype*)NDL_PTR(lp, 0);
lapack_int* ipiv = (lapack_int*)NDL_PTR(lp, 1);
int* info = (int*)NDL_PTR(lp, 2);
getri_opt* opt = (getri_opt*)(lp->opt_ptr);
const lapack_int n = NDL_SHAPE(lp, 0)[0];
const lapack_int lda = n;
const lapack_int i = FncType().call(opt->matrix_layout, n, a, lda, ipiv);
const lapack_int i = LapackFn().call(opt->matrix_layout, n, a, lda, ipiv);
*info = static_cast<int>(i);
}

Expand Down
10 changes: 5 additions & 5 deletions ext/numo/tiny_linalg/lapack/hegv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct CHeGv {
}
};

template <int nary_dtype_id, int nary_rtype_id, typename DType, typename RType, typename FncType>
template <int nary_dtype_id, int nary_rtype_id, typename dtype, typename rtype, class LapackFn>
class HeGv {
public:
static void define_module_function(VALUE mLapack, const char* fnc_name) {
Expand All @@ -40,15 +40,15 @@ class HeGv {
};

static void iter_hegv(na_loop_t* const lp) {
DType* a = (DType*)NDL_PTR(lp, 0);
DType* b = (DType*)NDL_PTR(lp, 1);
RType* w = (RType*)NDL_PTR(lp, 2);
dtype* a = (dtype*)NDL_PTR(lp, 0);
dtype* b = (dtype*)NDL_PTR(lp, 1);
rtype* w = (rtype*)NDL_PTR(lp, 2);
int* info = (int*)NDL_PTR(lp, 3);
hegv_opt* opt = (hegv_opt*)(lp->opt_ptr);
const lapack_int n = NDL_SHAPE(lp, 0)[1];
const lapack_int lda = NDL_SHAPE(lp, 0)[0];
const lapack_int ldb = NDL_SHAPE(lp, 1)[0];
const lapack_int i = FncType().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
const lapack_int i = LapackFn().call(opt->matrix_layout, opt->itype, opt->jobz, opt->uplo, n, a, lda, b, ldb, w);
*info = static_cast<int>(i);
}

Expand Down
Loading

0 comments on commit 4b3ef05

Please sign in to comment.