Skip to content

Commit

Permalink
Fix a bug in FARMM CV parallel solver
Browse files Browse the repository at this point in the history
  • Loading branch information
fenguoerbian committed Sep 2, 2024
1 parent 8da3368 commit a3f7a73
Showing 1 changed file with 42 additions and 4 deletions.
46 changes: 42 additions & 4 deletions R/mm_path_solver.R
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,25 @@ Logistic_FARMM_CV_path <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
for(lam_id in 1 : lambda_length){
delta_vec <- train_res$delta_path[lam_id, ]
eta_stack_vec <- train_res$eta_stack_path[lam_id, ]
test_pi_vec <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
# test_pi_vec <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
# fixed effect part of pi vector
test_pi_vec1 <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
# random effect part of pi vector
if(ncol(rand_eff_df_test) == 1){
zmat <- matrix(1, nrow = nrow(rand_eff_df_test), ncol = 1)
}else{
zmat <- rand_eff_df_test[, which(colnames(rand_eff_df_test) != "subj_vec_fct"), drop = FALSE]
zmat <- cbind(1, zmat)
}

rand_eff_mat <- data.frame(subj_vec_fct = rand_eff_df_test$subj_vec_fct) %>%
left_join(post_est$rand_eff_est,
by = "subj_vec_fct") %>%
select(-subj_vec_fct) %>%
as.matrix()
test_pi_vec2 <- rowSums(zmat * rand_eff_mat)

test_pi_vec <- test_pi_vec1 + test_pi_vec2
# test_pi_vec <- as.vector(x_mat_test %*% c(delta_vec, eta_stack_vec))
loglik_test_mat[cv_id, lam_id] <- sum((y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec))) * weight_vec_test)
}
Expand Down Expand Up @@ -1269,8 +1287,28 @@ Logistic_FARMM_CV_path_par <- function(y_vec, x_mat, h, kn, p, rand_eff_df,
eta_stack_vec <- train_res$eta_stack_path[lam_id, ]
# test_pi_vec <- as.vector(x_mat_test %*% c(delta_vec, eta_stack_vec))
# loglik_test_mat[1, lam_id] <- sum(y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec)))
test_pi_vec <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
loglik_test_mat[cv_id, lam_id] <- sum((y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec))) * weight_vec_test)
# test_pi_vec <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)

# fixed effect part of pi vector
test_pi_vec1 <- as.vector((x_mat_test[, 1 : h, drop = FALSE] %*% delta_vec) + (x_mat_test[, -(1 : h), drop = FALSE] %*% eta_stack_vec) * logit_weight_vec_test)
# random effect part of pi vector
if(ncol(rand_eff_df_test) == 1){
zmat <- matrix(1, nrow = nrow(rand_eff_df_test), ncol = 1)
}else{
zmat <- rand_eff_df_test[, which(colnames(rand_eff_df_test) != "subj_vec_fct"), drop = FALSE]
zmat <- cbind(1, zmat)
}

rand_eff_mat <- data.frame(subj_vec_fct = rand_eff_df_test$subj_vec_fct) %>%
left_join(post_est$rand_eff_est,
by = "subj_vec_fct") %>%
select(-subj_vec_fct) %>%
as.matrix()
test_pi_vec2 <- rowSums(zmat * rand_eff_mat)

# test_pi_vec <- as.vector(x_mat_test %*% c(delta_vec, eta_stack_vec))
test_pi_vec <- test_pi_vec1 + test_pi_vec2
loglik_test_mat[1, lam_id] <- sum((y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec))) * weight_vec_test)
}

# test on testing set based on post-selection estimation
Expand Down Expand Up @@ -1310,7 +1348,7 @@ Logistic_FARMM_CV_path_par <- function(y_vec, x_mat, h, kn, p, rand_eff_df,

# test_pi_vec <- as.vector(x_mat_test %*% c(delta_vec, eta_stack_vec))
test_pi_vec <- test_pi_vec1 + test_pi_vec2
loglik_post_mat[cv_id, lam_id] <- sum((y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec))) * weight_vec_test)
loglik_post_mat[2, lam_id] <- sum((y_vec_test * test_pi_vec - log(1 + exp(test_pi_vec))) * weight_vec_test)
}
}else{
stop("`post_selection` must be set to `TRUE` for `Logistic_FARMM_CV_path()`!")
Expand Down

0 comments on commit a3f7a73

Please sign in to comment.