Skip to content

Commit

Permalink
A bit more optimization for #17
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed May 23, 2024
1 parent 693e21a commit cd39aea
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/admm_MADMMplasso.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ arma::field<arma::cube> admm_MADMMplasso_cpp(
arma::mat part_z(W_hat_t.n_rows, W_hat_t.n_cols);
arma::vec part_y(W_hat_t.n_rows);
arma::vec my_beta_jj(W_hat_t.n_rows);
arma::mat beta_hat1(p, 1 + K);
arma::mat b_hat(p, 1 + K);
for (int i = 1; i < max_it + 1; i++) {
r_current = y - model_intercept(beta_hat, W_hat);
b = reg(r_current, Z);
Expand All @@ -153,20 +155,17 @@ arma::field<arma::cube> admm_MADMMplasso_cpp(
arma::vectorise(new_group) + res_val.row(jj).t() +\
arma::vectorise(rho * (Q.slice(jj) - P.slice(jj))) +\
arma::vectorise(rho * (EE.slice(jj) - HH.slice(jj)));

DD3 = arma::diagmat(1 / invmat.slice(jj));
arma::vec DD3_diag = arma::diagvec(DD3);

for (arma::uword j = 0; j < W_hat_t.n_cols; ++j) {
part_z.col(j) = DD3_diag % W_hat_t.col(j);
}
part_y = DD3_diag % my_beta_jj;
arma::mat A = R_svd_inv + svd_w_tv * part_z;
arma::vec B = svd_w_tv * part_y;
arma::mat solve_A_B = arma::solve(A, B, arma::solve_opts::fast);
part_y -= part_z * solve_A_B;
part_y -= part_z * arma::solve(R_svd_inv + svd_w_tv * part_z, svd_w_tv * part_y, arma::solve_opts::fast);
beta_hat.col(jj) = part_y;
arma::mat beta_hat1 = arma::reshape(part_y, p, 1 + K);
arma::mat b_hat = alph * beta_hat1 + (1 - alph) * Q.slice(jj);
beta_hat1 = arma::reshape(part_y, p, 1 + K);
b_hat = alph * beta_hat1 + (1 - alph) * Q.slice(jj);
Q.slice(jj).col(0) = b_hat.col(0) + P.slice(jj).col(0);
arma::mat new_mat = b_hat.tail_cols(b_hat.n_cols - 1) + P.slice(jj).tail_cols(P.slice(jj).n_cols - 1);
Q.slice(jj).tail_cols(Q.n_cols - 1) = arma::sign(new_mat) % arma::max(arma::abs(new_mat) - ((alpha * lambda(jj)) / rho), arma::zeros(arma::size(new_mat)));
Expand Down

0 comments on commit cd39aea

Please sign in to comment.