Skip to content

Commit

Permalink
Merge pull request #44 from ocbe-uio/issue-17
Browse files Browse the repository at this point in the history
Resolving issues #17, #34, #42
  • Loading branch information
Theo-qua authored Jun 26, 2024
2 parents dca38ae + c455763 commit 23530d3
Show file tree
Hide file tree
Showing 26 changed files with 620 additions and 436 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
CITATION.cff
aux/
^\.github$
^.*\.out
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ src/*.o
src/*.so
src/*.dll
aux/
*.out
118 changes: 85 additions & 33 deletions R/MADMMplasso.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@

#' @example inst/examples/MADMMplasso_example.R
#' @export
MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, max_it = 50000, e.abs = 1E-3, e.rel = 1E-3, maxgrid, nlambda, rho = 5, my_print = FALSE, alph = 1.8, tree, parallel = TRUE, pal = FALSE, gg = NULL, tol = 1E-4, cl = 4, legacy = FALSE) {
MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, max_it = 50000, e.abs = 1E-3, e.rel = 1E-3, maxgrid, nlambda, rho = 5, my_print = FALSE, alph = 1.8, tree, parallel = TRUE, pal = !parallel, gg = NULL, tol = 1E-4, cl = 4, legacy = FALSE) {
if (parallel && pal) {
stop("parallel and pal cannot be TRUE at the same time")
}
N <- nrow(X)

p <- ncol(X)
Expand Down Expand Up @@ -181,60 +184,109 @@ MADMMplasso <- function(X, Z, y, alpha, my_lambda = NULL, lambda_min = 0.001, ma

r_current <- y
b <- reg(r_current, Z)
beta0 <- b$beta0
theta0 <- b$theta0
beta0 <- b[1, ]
theta0 <- b[-1, ]

new_y <- y - (matrix(1, N) %*% beta0 + Z %*% ((theta0)))

XtY <- crossprod((my_W_hat), (new_y))


cl1 <- cl

# Adjusting objects for C++
if (!legacy) {
C <- TT$Tree
CW <- TT$Tw
svd_w_tu <- t(svd.w$u)
svd_w_tv <- t(svd.w$v)
svd_w_d <- svd.w$d
BETA <- array(0, c(p, D, nlambda))
BETA_hat <- array(0, c(p + p * K, D, nlambda))
}

# Pre-calculating my_values through my_values_matrix
if (parallel) {
cl <- makeCluster(cl1, type = "FORK")

doParallel::registerDoParallel(cl = cl)
foreach::getDoParRegistered()

my_values_matrix <- foreach(i = 1:nlambda, .packages = "MADMMplasso", .combine = rbind) %dopar% {
admm_MADMMplasso(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd.w, tree, my_print,
invmat, gg[i, ], legacy
)
if (legacy) {
my_values_matrix <- foreach(i = 1:nlambda, .packages = "MADMMplasso", .combine = rbind) %dopar% {
admm_MADMMplasso(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd.w, tree, my_print,
invmat, gg[i, ]
)
}
} else {
my_values_matrix <- foreach(i = 1:nlambda, .packages = "MADMMplasso", .combine = rbind) %dopar% {
admm_MADMMplasso_cpp(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd_w_tu, svd_w_tv, svd_w_d,
C, CW, gg[i, ], my_print
)
}
}
parallel::stopCluster(cl)

# Converting to list so hh_nlambda_loop_cpp can handle it
for (hh in seq_len(nrow(my_values_matrix))) {
my_values[[hh]] <- my_values_matrix[hh, ]
if (nlambda == 1) {
my_values <- list(my_values_matrix)
} else {
my_values <- list()
for (hh in seq_len(nlambda)) {
my_values[[hh]] <- my_values_matrix[hh, ]
}
}
} else if (!parallel && !pal) {
my_values <- lapply(
seq_len(nlambda),
function(g) {
admm_MADMMplasso(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat,
XtY, y, N, e.abs, e.rel, alpha, lam[g, ], alph, svd.w, tree, my_print,
invmat, gg[g, ], legacy
)
}
)
if (legacy) {
my_values <- lapply(
seq_len(nlambda),
function(g) {
admm_MADMMplasso(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat,
XtY, y, N, e.abs, e.rel, alpha, lam[g, ], alph, svd.w, tree, my_print,
invmat, gg[g, ]
)
}
)
} else {
my_values <- lapply(
seq_len(nlambda),
function(i) {
admm_MADMMplasso_cpp(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY,
y, N, e.abs, e.rel, alpha, lam[i, ], alph, svd_w_tu, svd_w_tv, svd_w_d,
C, CW, gg[i, ], my_print
)
}
)
}
} else {
# This is triggered when parallel is FALSE and pal is 1
# This is triggered when parallel is FALSE and pal is TRUE
my_values <- list()
}

loop_output <- hh_nlambda_loop(
lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it,
my_W_hat, XtY, y, N, e.abs, e.rel, alpha, alph, svd.w, tree, my_print,
invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA,
BETA_hat, Y_HAT, THETA, D, my_values, legacy
)

remove(invmat)
remove(my_W_hat)
# Big calculations
if (legacy) {
loop_output <- hh_nlambda_loop(
lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it,
my_W_hat, XtY, y, N, e.abs, e.rel, alpha, alph, svd.w, tree, my_print,
invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA,
BETA_hat, Y_HAT, THETA, D, my_values
)
} else {
loop_output <- hh_nlambda_loop_cpp(
lam, as.integer(nlambda), beta0, theta0, beta, beta_hat, theta, rho1, X, Z, as.integer(max_it),
my_W_hat, XtY, y, as.integer(N), e.abs, e.rel, alpha, alph, my_print,
gg, tol, parallel, pal, simplify2array(BETA0), simplify2array(THETA0),
BETA, BETA_hat, simplify2array(Y_HAT),
as.integer(D), C, CW, svd_w_tu, svd_w_tv, svd_w_d, my_values
)
loop_output <- post_process_cpp(loop_output)
}

# Final adjustments in output
loop_output$obj[1] <- loop_output$obj[2]

pred <- data.frame(
Expand Down
30 changes: 18 additions & 12 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@
#' @param alpha mixing parameter, usually obtained from the MADMMplasso call. When the goal is to include more interactions, alpha should be very small and vice versa.
#' @param lambda a vector lambda_3 values for the admm call with length ncol(y). This is usually calculated in the MADMMplasso call. In our current setting, we use the same the lambda_3 value for all responses.
#' @param alph an overrelaxation parameter in \[1, 1.8\], usually obtained from the MADMMplasso call.
#' @param svd_w singular value decomposition of W
#' @param tree The results from the hierarchical clustering of the response matrix.
#' @param svd_w_tu the transpose of the U matrix from the SVD of W_hat
#' @param svd_w_tv the transpose of the V matrix from the SVD of W_hat
#' @param svd_w_d the D matrix from the SVD of W_hat
#' @param C the trained tree
#' @param CW weights for the trained tree
#' The easy way to obtain this is by using the function (tree_parms) which gives a default clustering.
#' However, user decide on a specific structure and then input a tree that follows such structure.
#' @param my_print Should information form each ADMM iteration be printed along the way? Default TRUE. This prints the dual and primal residuals
#' @param invmat A list of length ncol(y), each containing the C_d part of equation 32 in the paper
#' @param gg penalty terms for the tree structure for lambda_1 and lambda_2 for the admm call.
#' @return predicted values for the ADMM part
#' @description TODO: add description
#' @description This function fits a multi-response pliable lasso model over a path of regularization values.
#' @export
admm_MADMMplasso_cpp <- function(beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y, N, e_abs, e_rel, alpha, lambda, alph, svd_w, tree, invmat, gg, my_print = TRUE) {
.Call(`_MADMMplasso_admm_MADMMplasso_cpp`, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y, N, e_abs, e_rel, alpha, lambda, alph, svd_w, tree, invmat, gg, my_print)
admm_MADMMplasso_cpp <- function(beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y, N, e_abs, e_rel, alpha, lambda, alph, svd_w_tu, svd_w_tv, svd_w_d, C, CW, gg, my_print = TRUE) {
.Call(`_MADMMplasso_admm_MADMMplasso_cpp`, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y, N, e_abs, e_rel, alpha, lambda, alph, svd_w_tu, svd_w_tv, svd_w_d, C, CW, gg, my_print)
}

count_nonzero_a_cpp <- function(x) {
Expand All @@ -49,16 +51,20 @@ count_nonzero_a_cube <- function(x) {
.Call(`_MADMMplasso_count_nonzero_a_cube`, x)
}

hh_nlambda_loop_cpp <- function(lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY, y, N, e_abs, e_rel, alpha, alph, svd_w, tree, my_print, invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA, BETA_hat, Y_HAT, THETA, D, my_values) {
.Call(`_MADMMplasso_hh_nlambda_loop_cpp`, lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY, y, N, e_abs, e_rel, alpha, alph, svd_w, tree, my_print, invmat, gg, tol, parallel, pal, BETA0, THETA0, BETA, BETA_hat, Y_HAT, THETA, D, my_values)
count_nonzero_a_mat <- function(x) {
.Call(`_MADMMplasso_count_nonzero_a_mat`, x)
}

model_intercept <- function(beta0, theta0, beta, theta, X, Z) {
.Call(`_MADMMplasso_model_intercept`, beta0, theta0, beta, theta, X, Z)
hh_nlambda_loop_cpp <- function(lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY, y, N, e_abs, e_rel, alpha, alph, my_print, gg, tol, parallel, pal, BETA0, THETA0, BETA, BETA_hat, Y_HAT, D, C, CW, svd_w_tu, svd_w_tv, svd_w_d, my_values) {
.Call(`_MADMMplasso_hh_nlambda_loop_cpp`, lam, nlambda, beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, my_W_hat, XtY, y, N, e_abs, e_rel, alpha, alph, my_print, gg, tol, parallel, pal, BETA0, THETA0, BETA, BETA_hat, Y_HAT, D, C, CW, svd_w_tu, svd_w_tv, svd_w_d, my_values)
}

model_p <- function(beta0, theta0, beta, theta, X, Z) {
.Call(`_MADMMplasso_model_p`, beta0, theta0, beta, theta, X, Z)
model_intercept <- function(beta, X) {
.Call(`_MADMMplasso_model_intercept`, beta, X)
}

model_p <- function(beta0, theta0, beta, X, Z) {
.Call(`_MADMMplasso_model_p`, beta0, theta0, beta, X, Z)
}

modulo <- function(x, n) {
Expand Down
22 changes: 5 additions & 17 deletions R/admm_MADMMplasso.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,7 @@


#' @export
admm_MADMMplasso <- function(beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y, N, e.abs, e.rel, alpha, lambda, alph, svd.w, tree, my_print, invmat, gg = 0.2, legacy = FALSE) {
if (!legacy) {
out <- admm_MADMMplasso_cpp(
beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y,
N, e.abs, e.rel, alpha, lambda, alph, svd.w, tree, invmat, gg, my_print
)
return(out)
}
warning(
"Using legacy R code for MADMMplasso. ",
"This functionality will be removed in a future release. ",
"Please consider using legacy = FALSE instead."
)
admm_MADMMplasso <- function(beta0, theta0, beta, beta_hat, theta, rho1, X, Z, max_it, W_hat, XtY, y, N, e.abs, e.rel, alpha, lambda, alph, svd.w, tree, my_print, invmat, gg = 0.2) {
TT <- tree

C <- TT$Tree
Expand Down Expand Up @@ -115,10 +103,10 @@ admm_MADMMplasso <- function(beta0, theta0, beta, beta_hat, theta, rho1, X, Z, m
rho <- rho1
Big_beta11 <- V
for (i in 2:max_it) {
r_current <- (y - model_intercept(beta0, theta0, beta = beta_hat, theta, X = W_hat, Z))
r_current <- (y - model_intercept(beta_hat, W_hat))
b <- reg(r_current, Z) # Analytic solution how no sample lower bound (Z.T @ Z + cI)^-1 @ (Z.T @ r)
beta0 <- b$beta0
theta0 <- b$theta0
beta0 <- b[1, ]
theta0 <- b[-1, ]

new_y <- y - (matrix(1, N) %*% beta0 + Z %*% ((theta0)))

Expand Down Expand Up @@ -368,7 +356,7 @@ admm_MADMMplasso <- function(beta0, theta0, beta, beta_hat, theta, rho1, X, Z, m
theta[, , jj] <- (beta_hat1[, -1])
beta_hat[, jj] <- c(c(beta_hat1[, 1], as.vector(theta[, , jj])))
}
y_hat <- model_p(beta0, theta0, beta = beta_hat, theta, X = W_hat, Z)
y_hat <- model_p(beta0, theta0, beta = beta_hat, X = W_hat, Z)

out <- list(beta0 = beta0, theta0 = theta0, beta = beta, theta = theta, converge = converge, obj = obj, beta_hat = beta_hat, y_hat = y_hat)

Expand Down
Loading

0 comments on commit 23530d3

Please sign in to comment.