Skip to content

Commit

Permalink
Optimized multiplications involving diagonal matrix (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed May 23, 2024
1 parent 425e772 commit 693e21a
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/admm_MADMMplasso.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ arma::field<arma::cube> admm_MADMMplasso_cpp(
arma::mat H(y.n_cols * C.n_rows, p + p * K);
arma::cube HH(p, 1 + K, D, arma::fill::zeros);


// for response groups =======================================================
const arma::ivec input = Rcpp::seq_len(D * C.n_rows);
arma::mat I = arma::zeros<arma::mat>(C.n_rows * D, D);
Expand Down Expand Up @@ -155,10 +154,16 @@ arma::field<arma::cube> admm_MADMMplasso_cpp(
arma::vectorise(rho * (Q.slice(jj) - P.slice(jj))) +\
arma::vectorise(rho * (EE.slice(jj) - HH.slice(jj)));
DD3 = arma::diagmat(1 / invmat.slice(jj));
arma::vec DD3_diag = arma::diagvec(DD3);

part_z = DD3 * W_hat_t;
part_y = DD3 * my_beta_jj;
part_y -= part_z * arma::solve(R_svd_inv + svd_w_tv * part_z, svd_w_tv * part_y, arma::solve_opts::fast);
for (arma::uword j = 0; j < W_hat_t.n_cols; ++j) {
part_z.col(j) = DD3_diag % W_hat_t.col(j);
}
part_y = DD3_diag % my_beta_jj;
arma::mat A = R_svd_inv + svd_w_tv * part_z;
arma::vec B = svd_w_tv * part_y;
arma::mat solve_A_B = arma::solve(A, B, arma::solve_opts::fast);
part_y -= part_z * solve_A_B;
beta_hat.col(jj) = part_y;
arma::mat beta_hat1 = arma::reshape(part_y, p, 1 + K);
arma::mat b_hat = alph * beta_hat1 + (1 - alph) * Q.slice(jj);
Expand Down

0 comments on commit 693e21a

Please sign in to comment.