Skip to content

Commit

Permalink
Update post selection estimation in FARMM solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
fenguoerbian committed Aug 26, 2024
1 parent a2d05b1 commit 6080f90
Showing 1 changed file with 49 additions and 5 deletions.
54 changes: 49 additions & 5 deletions R/mm_path_solver.R
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,8 @@ Logistic_FARMM_CV_path <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
# each row for one test set
if(post_selection){
loglik_post_mat <- loglik_test_mat
}else{
stop("`post_selection` must be set to `TRUE` for `Logistic_FARMM_CV_path()`!")
}

pb <- progressr::progressor(along = 1 : nfold)
Expand All @@ -757,10 +759,13 @@ Logistic_FARMM_CV_path <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
weight_vec_test <- weight_vec[test_id_vec]
logit_weight_vec_test <- logit_weight_vec[test_id_vec]

rand_eff_df_train <- rand_eff_df[-test_id_vec, , drop = FALSE]
rand_eff_df_test <- rand_eff_df[test_id_vec, , drop = FALSE]

# find solution path on the training set
print(paste("Find solution path on training set..."))
train_res <- Logistic_FARMM_Path(y_vec = y_vec_train, x_mat = x_mat_train,
h = h, kn = kn, p = p, rand_eff_df = rand_eff_df,
h = h, kn = kn, p = p, rand_eff_df = rand_eff_df_train,
p_type = p_type, p_param = p_param,
lambda_seq = lambda_seq, mu2 = mu2,
a = a, bj_vec = bj_vec, cj_vec = cj_vec, rj_vec = rj_vec,
Expand All @@ -783,7 +788,7 @@ Logistic_FARMM_CV_path <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
# post_res <- train_res
for(lam_id in 1 : lambda_length){
post_est <- Logistic_FARMM_Path_Further_Improve(
x_mat = x_mat_train, y_vec = y_vec_train, rand_eff_df = rand_eff_df,
x_mat = x_mat_train, y_vec = y_vec_train, rand_eff_df = rand_eff_df_train,
h = h, k_n = kn, p = p,
delta_vec_init = train_res$delta_path[lam_id, ],
eta_stack_init = train_res$eta_stack_path[lam_id, ],
Expand All @@ -802,11 +807,29 @@ Logistic_FARMM_CV_path <- function(y_vec, x_mat, h, kn, p, rand_eff_df,

delta_vec <- post_est$delta_vec
peta_stack_vec <- post_est$eta_stack_vec
test_pi_vec <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
# fixed effect part of pi vector
test_pi_vec1 <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
# random effect part of pi vector
if(ncol(rand_eff_df_test) == 1){
zmat <- matrix(1, nrow = nrow(rand_eff_df_test), ncol = 1)
}else{
zmat <- rand_eff_df_test[, which(colnames(rand_eff_df_test) != "subj_vec_fct"), drop = FALSE]
}

rand_eff_mat <- data.frame(subj_vec_fct = rand_eff_df_test$subj_vec_fct) %>%
left_join(post_est$rand_eff_est,
by = "subj_vec_fct") %>%
select(-subj_vec_fct) %>%
as.matrix()
test_pi_vec2 <- rowSums(zmat * rand_eff_mat)

# test_pi_vec <- as.vector(x_mat_test %*% c(delta_vec, eta_stack_vec))
test_pi_vec <- test_pi_vec1 + test_pi_vec2
loglik_post_mat[cv_id, lam_id] <- sum((y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec))) * weight_vec_test)
}

}else{
stop("`post_selection` must be set to `TRUE` for `Logistic_FARMM_CV_path()`!")
}
print(paste(nfold, "-fold CV, FINISHED at ", cv_id, "/", nfold, sep = ""))
pb(paste(nfold, "-fold CV, folder id = ", cv_id, " finished at pid = ", Sys.getpid(), "!", sep = ""))
Expand Down Expand Up @@ -846,6 +869,8 @@ Logistic_FARMM_CV_path <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
res$cv_post_id <- lam_post_id
res$loglik_post_mat <- loglik_post_mat
res$post_est <- post_est
}else{
stop("`post_selection` must be set to `TRUE` for `Logistic_FARMM_CV_path()`!")
}
return(res)
}
Expand Down Expand Up @@ -1265,11 +1290,28 @@ Logistic_FARMM_CV_path_par <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
lam = 0.001, tol = 10^{-5}, max_iter = 1000, fastglm = TRUE)
delta_vec <- post_est$delta_vec
peta_stack_vec <- post_est$eta_stack_vec
# fixed effect part of pi vector
test_pi_vec1 <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
# random effect part of pi vector
if(ncol(rand_eff_df_test) == 1){
zmat <- matrix(1, nrow = nrow(rand_eff_df_test), ncol = 1)
}else{
zmat <- rand_eff_df_test[, which(colnames(rand_eff_df_test) != "subj_vec_fct"), drop = FALSE]
}

rand_eff_mat <- data.frame(subj_vec_fct = rand_eff_df_test$subj_vec_fct) %>%
left_join(post_est$rand_eff_est,
by = "subj_vec_fct") %>%
select(-subj_vec_fct) %>%
as.matrix()
test_pi_vec2 <- rowSums(zmat * rand_eff_mat)

# test_pi_vec <- as.vector(x_mat_test %*% c(delta_vec, eta_stack_vec))
# loglik_test_mat[2, lam_id] <- sum(y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec)))
test_pi_vec <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
test_pi_vec <- test_pi_vec1 + test_pi_vec2
loglik_test_mat[2, lam_id] <- sum((y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec))) * weight_vec_test)
}
}else{
stop("`post_selection` must be set to `TRUE` for `Logistic_FARMM_CV_path()`!")
}

# print(paste(nfold, "-fold CV, FINISHED at ", cv_id, "/", nfold, sep = ""))
Expand Down Expand Up @@ -1330,6 +1372,8 @@ Logistic_FARMM_CV_path_par <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
res$cv_post_id <- lam_post_id
res$loglik_post_mat <- loglik_post_mat
res$post_est <- post_est
}else{
stop("`post_selection` must be set to `TRUE` for `Logistic_FARMM_CV_path_par()`!")
}
return(res)
}

0 comments on commit 6080f90

Please sign in to comment.