Skip to content

Commit

Permalink
Replacing Rcpp::List with arma::field (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed May 23, 2024
1 parent 6907ad8 commit 0d1d7be
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/MADMMplasso.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#include <RcppArmadillo.h>
Rcpp::List admm_MADMMplasso_cpp(
arma::field<arma::cube> admm_MADMMplasso_cpp(
const arma::vec beta0,
const arma::mat theta0,
arma::mat beta,
Expand Down
2 changes: 1 addition & 1 deletion src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// admm_MADMMplasso_cpp
Rcpp::List admm_MADMMplasso_cpp(arma::vec beta0, arma::mat theta0, arma::mat beta, arma::mat beta_hat, arma::cube theta, const double rho1, const arma::mat X, const arma::mat Z, const int max_it, const arma::mat W_hat, arma::mat XtY, const arma::mat y, const int N, const double e_abs, const double e_rel, const double alpha, const arma::vec lambda, const double alph, const arma::mat svd_w_tu, const arma::mat svd_w_tv, const arma::vec svd_w_d, const arma::sp_mat C, const arma::vec CW, const arma::rowvec gg, const bool my_print);
arma::field<arma::cube> admm_MADMMplasso_cpp(arma::vec beta0, arma::mat theta0, arma::mat beta, arma::mat beta_hat, arma::cube theta, const double rho1, const arma::mat X, const arma::mat Z, const int max_it, const arma::mat W_hat, arma::mat XtY, const arma::mat y, const int N, const double e_abs, const double e_rel, const double alpha, const arma::vec lambda, const double alph, const arma::mat svd_w_tu, const arma::mat svd_w_tv, const arma::vec svd_w_d, const arma::sp_mat C, const arma::vec CW, const arma::rowvec gg, const bool my_print);
RcppExport SEXP _MADMMplasso_admm_MADMMplasso_cpp(SEXP beta0SEXP, SEXP theta0SEXP, SEXP betaSEXP, SEXP beta_hatSEXP, SEXP thetaSEXP, SEXP rho1SEXP, SEXP XSEXP, SEXP ZSEXP, SEXP max_itSEXP, SEXP W_hatSEXP, SEXP XtYSEXP, SEXP ySEXP, SEXP NSEXP, SEXP e_absSEXP, SEXP e_relSEXP, SEXP alphaSEXP, SEXP lambdaSEXP, SEXP alphSEXP, SEXP svd_w_tuSEXP, SEXP svd_w_tvSEXP, SEXP svd_w_dSEXP, SEXP CSEXP, SEXP CWSEXP, SEXP ggSEXP, SEXP my_printSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Expand Down
35 changes: 24 additions & 11 deletions src/admm_MADMMplasso.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
//' @description TODO: add description
//' @export
// [[Rcpp::export]]
Rcpp::List admm_MADMMplasso_cpp(
arma::field<arma::cube> admm_MADMMplasso_cpp(
arma::vec beta0,
arma::mat theta0,
arma::mat beta,
Expand Down Expand Up @@ -371,15 +371,28 @@ Rcpp::List admm_MADMMplasso_cpp(
}
arma::mat y_hat = model_p(beta0, theta0, beta_hat, W_hat, Z);

Rcpp::List out = Rcpp::List::create(
Rcpp::Named("beta0") = beta0,
Rcpp::Named("theta0") = theta0,
Rcpp::Named("beta") = beta,
Rcpp::Named("theta") = theta,
Rcpp::Named("converge") = converge,
Rcpp::Named("obj") = NULL,
Rcpp::Named("beta_hat") = beta_hat,
Rcpp::Named("y_hat") = y_hat
);
// Return important values
arma::field<arma::cube> out(7);
// TODO: print all dimensions to make sure they are correct
out(0) = arma::cube(beta0.n_elem, 1, 1);
out(0).slice(0) = beta0;

out(1) = arma::cube(theta0.n_rows, theta0.n_cols, 1);
out(1).slice(0) = theta0;

out(2) = arma::cube(beta.n_rows, beta.n_cols, 1);
out(2).slice(0) = beta;

out(3) = theta;

out(4) = arma::cube(1, 1, 1);
out(4).slice(0) = converge;

out(5) = arma::cube(beta_hat.n_rows, beta_hat.n_cols, 1);
out(5).slice(0) = beta_hat;

out(6) = arma::cube(y_hat.n_rows, y_hat.n_cols, 1);
out(6).slice(0) = y_hat;

return out;
}
26 changes: 15 additions & 11 deletions src/hh_nlambda_loop_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,33 @@ Rcpp::List hh_nlambda_loop_cpp(
arma::vec lam_list;
arma::mat y_hat = y;
unsigned int hh = 0;
Rcpp::List my_values_hh;
Rcpp::List THETA(nlambda);
arma::field<arma::cube> THETA(nlambda);
while (hh <= nlambda - 1) {
arma::vec lambda = lam.row(hh).t();

if (parallel) { // TODO: recheck all conditions (all parallel-pal combinations)
// my_values is already a list of length hh
my_values_hh = my_values[hh];
beta0 = my_values[hh]["beta0"];
theta0 = my_values[hh]["theta0"];
beta = my_values[hh]["beta"];
theta = my_values[hh]["theta"];
beta_hat = my_values[hh]["beta_hat"];
y_hat = my_values[hh]["y_hat"];
} else if (pal) {
// In this case, my_values is an empty list to be created now
my_values_hh = admm_MADMMplasso_cpp(
arma::field<arma::cube> my_values_hh = admm_MADMMplasso_cpp(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e_abs, e_rel, alpha, lambda, alph, svd_w_tu, svd_w_tv, svd_w_d, C, CW,
gg.row(hh), my_print
);
beta0 = my_values_hh(0).slice(0);
theta0 = my_values_hh(1).slice(0);
beta = my_values_hh(2).slice(0);
theta = my_values_hh(3);
beta_hat = my_values_hh(5).slice(0);
y_hat = my_values_hh(6).slice(0);
}

beta = Rcpp::as<arma::mat>(my_values_hh["beta"]);
theta = Rcpp::as<arma::cube>(my_values_hh["theta"]);
beta0 = Rcpp::as<arma::vec>(my_values_hh["beta0"]);
theta0 = Rcpp::as<arma::mat>(my_values_hh["theta0"]);
beta_hat = Rcpp::as<arma::mat>(my_values_hh["beta_hat"]);
y_hat = Rcpp::as<arma::mat>(my_values_hh["y_hat"]);

// should be sparse, but Arma doesn't have sp_cube; beta1 and beta_hat1
// are going into a cube, so they need to be dense as well
Expand All @@ -95,7 +99,7 @@ Rcpp::List hh_nlambda_loop_cpp(
BETA.slice(hh) = beta1;
BETA_hat.slice(hh) = beta_hat1;
Y_HAT.slice(hh) = y_hat;
THETA[hh] = theta1;
THETA(hh) = theta1;

if (my_print) {
if (hh == 0) {
Expand Down
16 changes: 7 additions & 9 deletions tests/testthat/test-admm_MADMMplasso_cpp.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,15 @@ my_values_cpp <- admm_MADMMplasso_cpp(
)

test_that("C++ function output structure", {
expect_identical(length(my_values_cpp), length(my_values))
expect_identical(names(my_values_cpp), names(my_values))
expect_identical(length(my_values_cpp), length(my_values) - 1L)
})

test_that("Values are the same", {
tl <- 1e-1
expect_equal(my_values$beta0, my_values_cpp$beta0[, 1], tolerance = tl)
expect_equal(my_values$theta0, my_values_cpp$theta0, tolerance = tl)
expect_equal(my_values$beta, my_values_cpp$beta, tolerance = tl)
expect_equal(my_values$theta, my_values_cpp$theta, tolerance = tl)
expect_identical(my_values$converge, my_values_cpp$converge)
expect_equal(my_values$beta_hat, my_values_cpp$beta_hat, tolerance = tl)
expect_equal(my_values$y_hat, my_values_cpp$y_hat, tolerance = tl)
expect_equal(my_values$beta0, my_values_cpp[[1]][, 1, 1], tolerance = tl)
expect_equal(my_values$theta0, my_values_cpp[[2]][, , 1], tolerance = tl)
expect_equal(my_values$beta, my_values_cpp[[3]][, , 1], tolerance = tl)
expect_equal(my_values$theta, my_values_cpp[[4]], tolerance = tl)
expect_equal(my_values$beta_hat, my_values_cpp[[6]][, , 1], tolerance = tl)
expect_equal(my_values$y_hat, my_values_cpp[[7]][, , 1], tolerance = tl)
})

0 comments on commit 0d1d7be

Please sign in to comment.