Skip to content

Commit

Permalink
simplify nnls solver
Browse files Browse the repository at this point in the history
  • Loading branch information
dselivanov committed Nov 18, 2020
1 parent 940dd1c commit 56e252e
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions inst/include/nnls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,25 @@
#define MAX_ITER 500

template <class T>
void scd_ls_update(arma::subview_col<T> Hj,
const arma::Mat<T> &WtW,
arma::Col<T> scd_ls_update(const arma::Mat<T> &WtW,
arma::Col<T> &mu,
uint max_iter,
double rel_tol) {
double rel_tol,
const arma::Col<T> &initial) {
// Problem: Aj = W * Hj
// Method: sequential coordinate-wise descent when loss function = square error
// WtW = W^T W
// WtAj = W^T Aj

arma::Col<T> res = initial;
auto WtW_diag = WtW.diag();
for (auto t = 0; t < max_iter; t++) {
T rel_err = 0;
for (auto k = 0; k < WtW.n_cols; k++) {
T current = Hj(k);
T current = res(k);
auto update = current - mu(k) / WtW_diag.at(k);
if(update < 0) update = 0;
Hj(k) = update;
res(k) = update;
if (update != current) {
mu += (update - current) * WtW.col(k);
auto current_err = std::abs(current - update) / (std::abs(current) + TINY_NUM);
Expand All @@ -31,6 +32,7 @@ void scd_ls_update(arma::subview_col<T> Hj,
}
if (rel_err <= rel_tol) break;
}
return res;
}

template <class T>
Expand All @@ -41,15 +43,15 @@ arma::Mat<T> c_nnls(const arma::Mat<T> &x,
arma::Mat<T> H(x.n_cols, y.n_cols, arma::fill::randu);
arma::Mat<T> Wt = x.t();

arma::Mat<T> WtW = Wt * Wt.t();
arma::Mat<T> WtW = Wt * x;
arma::Col<T> mu, sumW;

// for stability: avoid divided by 0
WtW.diag() += TINY_NUM;

for (unsigned int j = 0; j < y.n_cols; j++) {
for (auto j = 0; j < y.n_cols; j++) {
mu = WtW * H.col(j) - Wt * y.col(j);
scd_ls_update<T>(H.col(j), WtW, mu, max_iter, rel_tol);
H.col(j) = scd_ls_update<T>(WtW, mu, max_iter, rel_tol, H.col(j));
}
return (H);
}

0 comments on commit 56e252e

Please sign in to comment.