Skip to content

Commit

Permalink
Adjusted argument order for C++ calls (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
wleoncio committed Jun 24, 2024
1 parent 7da0620 commit 27a6772
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions R/MADMMplasso.R
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,19 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma


cl1 <- cl

# Adjusting objects for C++
if (!legacy) {
C <- TT$Tree
CW <- TT$Tw
svd_w_tu <- t(svd.w$u)
svd_w_tv <- t(svd.w$v)
svd_w_d <- svd.w$d
BETA <- array(0, c(p, D, nlambda))
BETA_hat <- array(0, c(p + p * K, D, nlambda))
}
if (parallel) {
cl <- makeCluster(cl1, type = "FORK")

doParallel::registerDoParallel(cl = cl)
foreach::getDoParRegistered()
if (legacy) {
Expand All @@ -210,15 +220,15 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma
my_values_matrix <- foreach(i = 1:nlambda, .packages = "MADMMplasso", .combine = rbind) %dopar% {
admm_MADMMplasso_cpp(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd.w, tree, my_print,
gg[i, ]
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd_w_tu, svd_w_tv, svd_w_d,
C, CW, gg[i, ], my_print
)
}
}
parallel::stopCluster(cl)

# Converting to list so hh_nlambda_loop_cpp can handle it
for (hh in seq_len(nrow(my_values_matrix))) {
for (hh in seq_len(nlambda)) {
my_values[[hh]] <- my_values_matrix[hh, ]
}
} else if (!parallel && !pal) {
Expand All @@ -236,11 +246,11 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma
} else {
my_values <- lapply(
seq_len(nlambda),
function(g) {
function(i) {
admm_MADMMplasso_cpp(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat,
XtY, y, N, e.abs, e.rel, alpha, lam[g, ], alph, svd.w, tree, my_print,
gg[g, ]
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd_w_tu, svd_w_tv, svd_w_d,
C, CW, gg[i, ], my_print
)
}
)
Expand All @@ -258,13 +268,6 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma
BETA_hat, Y_HAT, THETA, D, my_values
)
} else {
C <- TT$Tree
CW <- TT$Tw
svd_w_tu <- t(svd.w$u)
svd_w_tv <- t(svd.w$v)
svd_w_d <- svd.w$d
BETA <- array(0, c(p, D, nlambda))
BETA_hat <- array(0, c(p + p * K, D, nlambda))
loop_output <- hh_nlambda_loop_cpp(
lam, as.integer(nlambda), beta0, theta0, beta, beta_hat, theta, rho1, X, Z, as.integer(max_it),
my_W_hat, XtY, y, as.integer(N), e.abs, e.rel, alpha, alph, my_print,
Expand Down

0 comments on commit 27a6772

Please sign in to comment.