diff --git a/src/admm_MADMMplasso.cpp b/src/admm_MADMMplasso.cpp index 5ce3c7a..ee263f8 100644 --- a/src/admm_MADMMplasso.cpp +++ b/src/admm_MADMMplasso.cpp @@ -78,7 +78,6 @@ arma::field admm_MADMMplasso_cpp( arma::mat H(y.n_cols * C.n_rows, p + p * K); arma::cube HH(p, 1 + K, D, arma::fill::zeros); - // for response groups ======================================================= const arma::ivec input = Rcpp::seq_len(D * C.n_rows); arma::mat I = arma::zeros(C.n_rows * D, D); @@ -155,10 +154,16 @@ arma::field admm_MADMMplasso_cpp( arma::vectorise(rho * (Q.slice(jj) - P.slice(jj))) +\ arma::vectorise(rho * (EE.slice(jj) - HH.slice(jj))); DD3 = arma::diagmat(1 / invmat.slice(jj)); + arma::vec DD3_diag = arma::diagvec(DD3); - part_z = DD3 * W_hat_t; - part_y = DD3 * my_beta_jj; - part_y -= part_z * arma::solve(R_svd_inv + svd_w_tv * part_z, svd_w_tv * part_y, arma::solve_opts::fast); + for (arma::uword j = 0; j < W_hat_t.n_cols; ++j) { + part_z.col(j) = DD3_diag % W_hat_t.col(j); + } + part_y = DD3_diag % my_beta_jj; + arma::mat A = R_svd_inv + svd_w_tv * part_z; + arma::vec B = svd_w_tv * part_y; + arma::mat solve_A_B = arma::solve(A, B, arma::solve_opts::fast); + part_y -= part_z * solve_A_B; beta_hat.col(jj) = part_y; arma::mat beta_hat1 = arma::reshape(part_y, p, 1 + K); arma::mat b_hat = alph * beta_hat1 + (1 - alph) * Q.slice(jj);