Skip to content

Latest commit

 

History

History
710 lines (526 loc) · 21.4 KB

train_and_predict.md

File metadata and controls

710 lines (526 loc) · 21.4 KB
# Number of worker cores. Passed to registerDoMC() below and read as a global
# by the model-fitting code later in this file (ranger's num.threads).
N_CORES = 8

# Register the doMC parallel backend; the cv.glmnet() calls later in this file
# are made with parallel = TRUE.
library(doMC)
registerDoMC(cores = N_CORES)

# Modeling, data-handling and reporting packages used throughout the file.
library(glmnet)
library(data.table)
library(ggplot2)
library(Matrix)
library(parallel)
library(boot)
library(ranger)
library(xgboost)
library(knitr)

Make an image of the objects to use on cluster nodes. For each outcome, create a dataset that duplicates the states that also have the other outcome variable, so that we predict each outcome once using the other outcome as a predictor and once without it.

# Load the combined respondent-level prediction dataset.
states_combined_dt = fread('data/combined_pred_data.csv')

# Test sample (REMOVE TO USE FULL DATASET)
states_combined_dt = states_combined_dt[sitecode %in% c("NY","CA","VT","CT")]

# Distinct (sitecode, year) pairs from 2013 onward in which at least one row
# has a non-missing q66 or q67.
state_years_w_responses = unique(states_combined_dt[year>= 2013 & (!is.na(q66) | !is.na(q67)), .(sitecode, year)])
print(nrow(state_years_w_responses))

## [1] 8

# Keep only rows from those state-years (merge() defaults to an inner join).
subset_dt = merge(states_combined_dt, state_years_w_responses, by = c('sitecode','year'))

# Restrict columns to the identifier variables plus every variable listed in
# the variable-importance file that is actually present in the data.
id_vars = c('sitecode', 'census_region', 'census_division', 'year', 'weight')
varimp_inds = fread('data/varimp_v1.csv')
modeling_vars = varimp_inds[, var]
ss_vars = intersect(c(id_vars, modeling_vars),colnames(subset_dt))
subset_dt = subset_dt[, ..ss_vars]
# Predictor candidates: variables flagged pred == "y" in the varimp file.
preds = intersect(varimp_inds[pred == "y", var], colnames(subset_dt))

# States with answers to each question
# Each lookup table carries a constant have_qXX = 1 flag so that a later left
# join can tell matched state-years (1) from unmatched ones (NA).
state_years_w_q66 = unique(subset_dt[!is.na(q66), .(sitecode, year)])
state_years_w_q66[, have_q66 := 1]
state_years_w_q67 = unique(subset_dt[!is.na(q67), .(sitecode, year)])
state_years_w_q67[, have_q67 := 1]
# NOTE: str() prints the structure itself and returns NULL; wrapping it in
# print() is what produces the trailing "NULL" line in the output below.
print(str(state_years_w_q66))

## Classes 'data.table' and 'data.frame':   8 obs. of  3 variables:
##  $ sitecode: chr  "CA" "CA" "CT" "CT" ...
##  $ year    : int  2015 2017 2015 2017 2015 2017 2015 2017
##  $ have_q66: num  1 1 1 1 1 1 1 1
##  - attr(*, ".internal.selfref")=<externalptr> 
##  - attr(*, "sorted")= chr [1:2] "sitecode" "year"
## NULL

print(str(state_years_w_q67))

## Classes 'data.table' and 'data.frame':   8 obs. of  3 variables:
##  $ sitecode: chr  "CA" "CA" "CT" "CT" ...
##  $ year    : int  2015 2017 2015 2017 2015 2017 2015 2017
##  $ have_q67: num  1 1 1 1 1 1 1 1
##  - attr(*, ".internal.selfref")=<externalptr> 
##  - attr(*, "sorted")= chr [1:2] "sitecode" "year"
## NULL

# Return the most frequent value of `v`.
#
# Ties are resolved in favour of the value that appears first in `v`, and the
# result keeps the type of `v`. NA is treated like any other value, so callers
# that want the mode of the observed values must drop NAs first. An empty
# input yields NA.
getmode <- function(v) {
  distinct_values <- unique(v)
  # Position of each element of v within distinct_values, then a count per
  # position; which.max() picks the first position with the highest count.
  occurrence_counts <- tabulate(match(v, distinct_values))
  most_common_idx <- which.max(occurrence_counts)
  distinct_values[most_common_idx]
}

# Make training and prediction data for q66 (contact) answer
# Keep only rows from state-years that have at least one q66 answer.
q66_dt = merge(subset_dt, state_years_w_q66, by = c('sitecode','year'))

# Add indicator for whether q67 answer is present
# Left join: state-years with no q67 answers come back with NA, recoded to 0.
q66_dt = merge(q66_dt, state_years_w_q67, by = c('sitecode','year'), all.x=TRUE)
q66_dt[is.na(have_q67), have_q67 := 0]

q66_wo_q67 = q66_dt[have_q67 == 0]
q66_wo_q67[, predset := "Not Using Identity Responses"]

# Duplicate states that have q67 answer as well
# One copy keeps q67 as a predictor, the other has q67 blanked out, so the same
# rows are predicted both with and without the identity response.
q66_w_q67 = q66_dt[have_q67 == 1]
q66_w_q67_null_q67 = copy(q66_w_q67)
q66_w_q67_null_q67[, q67 := NA]

q66_w_q67[, predset := "Using Identity Responses"]
q66_w_q67_null_q67[, predset := "Not Using Identity Responses"]

q66_dup_dt = rbind(q66_wo_q67, q66_w_q67, q66_w_q67_null_q67)

# Encode outcome (same sex contacts)
# Y defaults to 0, so rows with a missing q66 are also coded 0.
# NOTE(review): the meaning of the sex / q66 response codes is not visible in
# this file - confirm against the survey codebook.
q66_dup_dt[, Y := 0]
q66_dup_dt[q66 == 4, Y := 1]
q66_dup_dt[sex == 1 & q66 == 2, Y := 1]
q66_dup_dt[sex == 2 & q66 == 3, Y := 1]

print(q66_dup_dt[, table(Y, useNA="always")])

## Y
##      0      1   <NA> 
## 135482   9428      0

print(q66_dup_dt[, unique(sitecode)])

## [1] "CA" "CT" "NY" "VT"

# Deal with missing values
# Fill nulls with mode and add a separate is_missing indicator
preds_for_q66 = c(preds, 'q67')

# Remove preds without variation
# NOTE(review): unique() counts NA as a value, so a column with a single
# observed value plus NAs passes this filter - confirm that is intended.
preds_for_q66 = Filter(function(p) { length(unique(q66_dup_dt[, get(p)])) > 1 }, preds_for_q66)

for(p in preds_for_q66) {

  # 1/0 flag recording where the original value was missing
  q66_dup_dt[, (paste0(p,"_is_missing")) := ifelse(is.na(get(p)),1,0)]
  # replace missing values with the most frequent observed value of the column
  q66_dup_dt[get(paste0(p,"_is_missing")) == 1, 
            (p) := getmode(q66_dup_dt[!is.na(get(p)),get(p)])]
  
  # make all predictors factors for LASSO
  q66_dup_dt[, (p) := as.factor(get(p))]
}

# make sure continuous predictors are still numeric
# (round-trip through character so the factor labels, not the level codes,
# become the numeric values)
q66_dup_dt[, bmipct := as.numeric(as.character(bmipct))]
q66_dup_dt[, bmi := as.numeric(as.character(bmi))]
q66_dup_dt[, stheight := as.numeric(as.character(stheight))]
q66_dup_dt[, stweight := as.numeric(as.character(stweight))]

# Make year a factor for FEs
q66_dup_dt[, year := as.factor(year)]

# prepare model matrix
# The formula has no intercept (-1) and includes region, division and year
# terms, every retained predictor, and every *_is_missing indicator column.
missing_preds = grep("_is_missing", colnames(q66_dup_dt), value=T)
q66_simp_formula_str = paste(" ~ -1 + census_region + census_division + year +",
                         paste(preds_for_q66, collapse = " + "), "+", 
                         paste(missing_preds, collapse = " + "))

# Design matrix, outcome and survey weights, all in the row order of q66_dup_dt.
X_q66 = sparse.model.matrix(as.formula(q66_simp_formula_str), data=q66_dup_dt)
Y_q66 = q66_dup_dt[, Y]
w_q66 = q66_dup_dt[, weight]

# Row-level keys used later to select training and prediction rows.
dt_q66_fn = q66_dup_dt[, .(sitecode, predset, year)]

print(dim(X_q66))

## [1] 144910    522

# Make training and prediction data for q67 (identity) answer
# Keep only rows from state-years that have at least one q67 answer.
q67_dt = merge(subset_dt, state_years_w_q67, by = c('sitecode','year'))

# Add indicator for whether q66 answer is present
# Left join: state-years with no q66 answers come back with NA, recoded to 0.
q67_dt = merge(q67_dt, state_years_w_q66, by = c('sitecode','year'), all.x=TRUE)
q67_dt[is.na(have_q66), have_q66 := 0]
q67_dt[, table(have_q66)]

## have_q66
##     1 
## 72455

q67_wo_q66 = q67_dt[have_q66 == 0]
q67_wo_q66[, predset := "Not Using Contact Responses"]

# Duplicate states that have q66 answer as well
# One copy keeps q66 as a predictor, the other has q66 blanked out, so the same
# rows are predicted both with and without the contact response.
# (The blanked copy is named *_null_q66 because q66 is the column set to NA.)
q67_w_q66 = q67_dt[have_q66 == 1]
q67_w_q66_null_q66 = copy(q67_w_q66)
q67_w_q66_null_q66[, q66 := NA]

q67_w_q66[, predset := "Using Contact Responses"]
q67_w_q66_null_q66[, predset := "Not Using Contact Responses"]

q67_dup_dt = rbind(q67_wo_q66, q67_w_q66, q67_w_q66_null_q66)

# Encode outcome (LGB identity)
# Y defaults to 0, so rows with a missing q67 are also coded 0.
# NOTE(review): the meaning of q67 response codes 2 and 3 is not visible in
# this file - confirm against the survey codebook.
q67_dup_dt[, Y := 0]
q67_dup_dt[q67 == 2, Y := 1]
q67_dup_dt[q67 == 3, Y := 1]

print(q67_dup_dt[, table(Y, useNA="always")])

## Y
##      0      1   <NA> 
## 131142  13768      0

print(q67_dup_dt[, unique(sitecode)])

## [1] "CA" "CT" "NY" "VT"

# Remove preds without variation
# NOTE(review): unique() counts NA as a value, so a column with a single
# observed value plus NAs passes this filter - confirm that is intended.
preds_for_q67 = c(preds, 'q66')
preds_for_q67 = Filter(function(p) { length(unique(q67_dup_dt[, get(p)])) > 1 }, preds_for_q67)

# Deal with missing values
# Fill nulls with mode and add a separate is_missing indicator
for(p in preds_for_q67) {
  # 1/0 flag recording where the original value was missing
  q67_dup_dt[, (paste0(p,"_is_missing")) := ifelse(is.na(get(p)),1,0)]
  # replace missing values with the most frequent observed value of the column
  q67_dup_dt[get(paste0(p,"_is_missing")) == 1,
            (p) := getmode(q67_dup_dt[!is.na(get(p)),get(p)])]
  # make all predictors factors
  q67_dup_dt[, (p) := as.factor(get(p))]
}

# make sure continuous predictors are still numeric
# (round-trip through character so the factor labels, not the level codes,
# become the numeric values)
q67_dup_dt[, bmipct := as.numeric(as.character(bmipct))]
q67_dup_dt[, bmi := as.numeric(as.character(bmi))]
q67_dup_dt[, stheight := as.numeric(as.character(stheight))]
q67_dup_dt[, stweight := as.numeric(as.character(stweight))]

# Make year a factor for FEs
q67_dup_dt[, year := as.factor(year)]

# prepare model matrix
# The formula has no intercept (-1) and includes region, division and year
# terms, every retained predictor, and every *_is_missing indicator column.
missing_preds = grep("_is_missing", colnames(q67_dup_dt), value=TRUE)
q67_simp_formula_str = paste(" ~ -1 + census_region + census_division + year +",
                         paste(preds_for_q67, collapse = " + "), "+",
                         paste(missing_preds, collapse = " + "))

# Design matrix, outcome and survey weights, all in the row order of q67_dup_dt.
X_q67 = sparse.model.matrix(as.formula(q67_simp_formula_str), data=q67_dup_dt)
Y_q67 = q67_dup_dt[, Y]
w_q67 = q67_dup_dt[, weight]

# Row-level keys used later to select training and prediction rows.
dt_q67_fn = q67_dup_dt[, .(sitecode, predset, year)]
print(dim(X_q67))

## [1] 144910    522

# Make training and prediction data for q66 (contact) answer (male)

# Subset to males
subset_dt_males = subset_dt[sex==2]

# Make prediction data for q66 (contact) answer
# Keep only rows from state-years that have at least one q66 answer.
q66_male_dt = merge(subset_dt_males, state_years_w_q66, by = c('sitecode','year'))

# Add indicator for whether q67 answer is present
# Left join: state-years with no q67 answers come back with NA, recoded to 0.
q66_male_dt = merge(q66_male_dt, state_years_w_q67, by = c('sitecode','year'), all.x=TRUE)
q66_male_dt[is.na(have_q67), have_q67 := 0]

q66_male_wo_q67 = q66_male_dt[have_q67 == 0]
q66_male_wo_q67[, predset := "Not Using Identity Responses"]

# Duplicate states that have q67 answer as well
# One copy keeps q67 as a predictor, the other has q67 blanked out, so the same
# rows are predicted both with and without the identity response.
q66_male_w_q67 = q66_male_dt[have_q67 == 1]
q66_male_w_q67_null_q67 = copy(q66_male_w_q67)
q66_male_w_q67_null_q67[, q67 := NA]

q66_male_w_q67[, predset := "Using Identity Responses"]
q66_male_w_q67_null_q67[, predset := "Not Using Identity Responses"]

q66_male_dup_dt = rbind(q66_male_wo_q67, q66_male_w_q67, q66_male_w_q67_null_q67)

# Encode outcome (same sex contacts)
# Y defaults to 0, so rows with a missing q66 are also coded 0.
# NOTE: the data was subset to sex == 2 above, so the sex == 1 rule below can
# never match here; it mirrors the rule set used for the full-sample outcome.
q66_male_dup_dt[, Y := 0]
q66_male_dup_dt[q66 == 4, Y := 1]
q66_male_dup_dt[sex == 1 & q66 == 2, Y := 1]
q66_male_dup_dt[sex == 2 & q66 == 3, Y := 1]

print(q66_male_dup_dt[, table(Y, useNA="always")])

## Y
##     0     1  <NA> 
## 68342  2702     0

print(q66_male_dup_dt[, unique(sitecode)])

## [1] "CA" "CT" "NY" "VT"

# Deal with missing values
# Fill nulls with mode and add a separate is_missing indicator
preds_for_q66_male = c(preds, 'q67')

# Remove preds without variation
# NOTE(review): unique() counts NA as a value, so a column with a single
# observed value plus NAs passes this filter - confirm that is intended.
preds_for_q66_male = Filter(function(p) { length(unique(q66_male_dup_dt[, get(p)])) > 1 }, preds_for_q66_male)

for(p in preds_for_q66_male) {
  
  # 1/0 flag recording where the original value was missing
  q66_male_dup_dt[, (paste0(p,"_is_missing")) := ifelse(is.na(get(p)),1,0)]
  # replace missing values with the most frequent observed value of the column
  q66_male_dup_dt[get(paste0(p,"_is_missing")) == 1, 
            (p) := getmode(q66_male_dup_dt[!is.na(get(p)),get(p)])]
  
  # make all predictors factors for LASSO
  q66_male_dup_dt[, (p) := as.factor(get(p))]
}

# make sure continuous predictors are still numeric
# (round-trip through character so the factor labels, not the level codes,
# become the numeric values)
q66_male_dup_dt[, bmipct := as.numeric(as.character(bmipct))]
q66_male_dup_dt[, bmi := as.numeric(as.character(bmi))]
q66_male_dup_dt[, stheight := as.numeric(as.character(stheight))]
q66_male_dup_dt[, stweight := as.numeric(as.character(stweight))]

# Make year a factor for FEs
q66_male_dup_dt[, year := as.factor(year)]

# prepare model matrix
# The formula has no intercept (-1) and includes region, division and year
# terms, every retained predictor, and every *_is_missing indicator column.
missing_preds = grep("_is_missing", colnames(q66_male_dup_dt), value=T)
q66_male_simp_formula_str = paste(" ~ -1 + census_region + census_division + year +",
                         paste(preds_for_q66_male, collapse = " + "), "+", 
                         paste(missing_preds, collapse = " + "))

# Design matrix, outcome and survey weights, all in the row order of
# q66_male_dup_dt.
X_q66_male = sparse.model.matrix(as.formula(q66_male_simp_formula_str), data=q66_male_dup_dt)
Y_q66_male = q66_male_dup_dt[, Y]
w_q66_male = q66_male_dup_dt[, weight]

# Row-level keys used later to select training and prediction rows.
dt_q66_male_fn = q66_male_dup_dt[, .(sitecode, predset, year)]

print(dim(X_q66_male))

## [1] 71044   520

Define the main function and save an image with the data

# define fn 
# Leave-one-state-out fit and prediction for a single model type.
#
# Fits `model` on every row whose sitecode differs from `state`, then, for each
# (predset, year) combination present for `state`, predicts the outcome on that
# state's rows and summarises the predictions as survey-weighted means.
#
# Arguments:
#   state  Sitecode to hold out of training and to predict for.
#   X      Model matrix whose rows are aligned with the rows of `dt`; it must
#          contain a "year2017" column (year fixed-effect dummy).
#   Y      Outcome vector aligned with the rows of `X`.
#   w      Weight vector aligned with the rows of `X`. Used only for the
#          weighted means of the predictions; every model is fit unweighted.
#   dt     data.table with columns sitecode, predset and year, aligned with `X`.
#   model  One of "logit", "lasso", "lassolog", "ridge", "ridgelog", "ols",
#          "rf", "gbrt". Anything else raises an error.
#
# Returns a data.table with one row per (predset, year) of the held-out state:
#   real_prop       weighted mean of the observed outcome
#   pred_prop       weighted mean of the predictions
#   pred_prop_2017  weighted mean of the predictions with the year fixed effect
#                   switched to 2017
#   predset, year, pred_state, model, n (number of rows predicted)
#
# N_CORES is not an argument: it is looked up in the calling environment and
# must be defined before the "rf" model is used.
#
# Kept as a single self-contained function because it is stored by name in the
# saved workspace images; helpers defined outside it would not travel with it.
gen_preds_for_state = function(state, X, Y, w, dt, model) {

  message(state)
  message(model)

  # filter data: train on every row outside the held-out state.
  # drop = FALSE keeps X a matrix even if a single row is selected.
  is_train_row = dt[, sitecode] != state
  dt_filtered = dt[sitecode != state]
  X_filtered = X[is_train_row, , drop = FALSE]
  Y_filtered = Y[is_train_row]

  # Fit Models

  if(model %in% c("logit", "lasso", "lassolog", "ridge", "ridgelog")) {

    # Cross-validation folds are built from whole states: states are numbered
    # in factor order and grouped STATES_PER_FOLD at a time, so all rows of a
    # state always fall in the same fold.
    sitecode_ids = as.numeric(as.factor(dt_filtered$sitecode))
    STATES_PER_FOLD = 6
    sitecode_ids = ceiling(sitecode_ids/STATES_PER_FOLD)

    if(model == "logit") {
      m = glmnet(
        X_filtered,
        Y_filtered,
        family = "binomial",
        type.logistic = "modified.Newton",
        lambda.min.ratio = 1e-99 # approximating zero lambda (regular logit)
      )

    } else if(model == "lasso") {

      m = cv.glmnet(
        X_filtered,
        Y_filtered,
        parallel = TRUE,
        foldid = sitecode_ids,
        nfolds = max(sitecode_ids),
        type.measure = "mse",
        nlambda = 10
      )

    } else if(model == "lassolog") {

      m = cv.glmnet(
        X_filtered,
        Y_filtered,
        parallel = TRUE,
        foldid = sitecode_ids,
        nfolds = max(sitecode_ids),
        type.measure = "mse",
        family = "binomial",
        type.logistic = "modified.Newton",
        nlambda = 10)

    } else if(model == "ridge") {

      m = cv.glmnet(
        X_filtered,
        Y_filtered,
        parallel = TRUE,
        foldid = sitecode_ids,
        nfolds = max(sitecode_ids),
        type.measure = "mse",
        alpha = 0,
        nlambda = 10
      )

    } else if(model == "ridgelog") {

      m = cv.glmnet(
        X_filtered,
        Y_filtered,
        parallel = TRUE,
        foldid = sitecode_ids,
        nfolds = max(sitecode_ids),
        type.measure = "mse",
        alpha = 0,
        family = "binomial",
        type.logistic = "modified.Newton",
        nlambda = 10)

    } else { stop("invalid model within glmnet") }

  } else if (model == "ols") {
      m = lm.fit(X_filtered, Y_filtered)

  } else if (model == "rf") {
      # Region and year columns are always offered as split candidates.
      region_feat_names = grep("census_region", colnames(X), value = TRUE)
      year_feat_names = grep("year", colnames(X), value = TRUE)

      m = ranger(
        num.trees = 90,
        mtry = round(ncol(X_filtered)/3),
        min.node.size = 10,
        max.depth = 14,
        oob.error = FALSE,
        num.threads = N_CORES,
        verbose = FALSE,
        seed = 13,
        classification = FALSE,
        x = X_filtered,
        y = Y_filtered,
        always.split.variables = c(region_feat_names, year_feat_names)
      )

  } else if (model == "gbrt") {

      m = xgb.train(
        params = list(
          objective = "binary:logistic",
          max_depth = 5,
          eta = 0.10
        ),
        data = xgb.DMatrix(X_filtered, label = Y_filtered),
        nrounds = 189
      )

  } else { stop("invalid model argument")}

  # Make predictions

  predset_years_to_iterate_over = unique(dt[sitecode == state, .(predset,year)])

  pred_vals = rbindlist(lapply(seq_len(nrow(predset_years_to_iterate_over)), function(i) {

    pred_predset = predset_years_to_iterate_over$predset[i]
    pred_year = predset_years_to_iterate_over$year[i]
    message(state)
    message(pred_predset)
    message(pred_year)

    # Rows of the held-out state for this predset/year (computed once and
    # reused for X, w and Y so the three stay aligned).
    is_pred_row = dt[, sitecode == state & predset == pred_predset & year == pred_year]
    X_pred = X[is_pred_row, , drop = FALSE]
    w_pred = w[is_pred_row]
    Y_pred = Y[is_pred_row]

    # Change year FE to 2017
    # Every year dummy is cleared, not only year2015: otherwise rows from any
    # other survey year would keep their own year dummy switched on next to
    # year2017.
    if(!("year2017" %in% colnames(X_pred))) {
      stop("X has no 'year2017' column; cannot build the 2017 counterfactual")
    }
    X_pred_2017 = X_pred
    other_year_cols = setdiff(
      grep("^year[0-9]+$", colnames(X_pred_2017), value = TRUE),
      "year2017")
    if(length(other_year_cols) > 0) X_pred_2017[, other_year_cols] = 0
    X_pred_2017[, "year2017"] = 1

    if(model %in% c("lasso", "lassolog", "ridge", "ridgelog")) {
      preds = predict(m, newx=X_pred, s = "lambda.min", type = "response")[,1]
      preds_2017 = predict(m, newx=X_pred_2017, s = "lambda.min", type = "response")[,1]

    } else if(model == "logit") {
      # s = 0 requests the (approximately) unpenalised end of the path
      preds = predict(m, newx=X_pred, s = 0, type = "response")[,1]
      preds_2017 = predict(m, newx=X_pred_2017, s = 0, type = "response")[,1]

    } else if(model == "ols") {

      # lm.fit() reports NA for coefficients it could not estimate; treat
      # them as 0 so they drop out of the linear predictor.
      coefs = m$coefficients
      coefs = ifelse(is.na(coefs), 0, coefs)
      preds = (X_pred %*% coefs)[,1]
      preds_2017 = (X_pred_2017 %*% coefs)[,1]

    } else if (model == "rf") {
      preds = predict(m, data=X_pred, num.threads = N_CORES)$predictions
      preds_2017 = predict(m, data=X_pred_2017, num.threads = N_CORES)$predictions

    } else if (model == "gbrt") {
      preds = predict(m, X_pred)
      preds_2017 = predict(m, X_pred_2017)

    } else { stop("invalid model at prediction step") }

    wm_dt = data.table(
      prediction = preds,
      prediction_2017 = preds_2017,
      real_outcome = Y_pred,
      weight = w_pred)

    prev_preds = data.table(
        real_prop = wm_dt[,weighted.mean(real_outcome, weight, na.rm=TRUE)],
        pred_prop = wm_dt[,weighted.mean(prediction, weight)],
        pred_prop_2017 = wm_dt[,weighted.mean(prediction_2017, weight)],
        predset = pred_predset,
        year = pred_year,
        pred_state = state,
        model = model,
        n = length(Y_pred)
    )

    return(prev_preds)
  }))

  return(pred_vals)
}

Save images with data to save time

# Each image bundles the row keys, weights, model matrix and outcome for one
# target together with gen_preds_for_state, written uncompressed.

# q66 image
save(
  file = "data/test/yrbs_image_20220808_q66.RData",
  list = c("dt_q66_fn", "w_q66", "X_q66", "Y_q66", "gen_preds_for_state"),
  compress = FALSE
)

# q67 image
save(
  file = "data/test/yrbs_image_20220808_q67.RData",
  list = c("dt_q67_fn", "w_q67", "X_q67", "Y_q67", "gen_preds_for_state"),
  compress = FALSE
)

# q66 (male) image
save(
  file = "data/test/yrbs_image_20220808_q66_male.RData",
  list = c("dt_q66_male_fn", "w_q66_male", "X_q66_male", "Y_q66_male", "gen_preds_for_state"),
  compress = FALSE
)

Function to run on cluster

# Run every outstanding model for one (question, state) pair and write one CSV
# of predictions per model to OUT_PATH.
#
# Arguments:
#   q         Which saved image to use: "q66", "q67" or "q66m".
#   state     Sitecode to hold out and predict for.
#   OUT_PATH  Output directory, including the trailing "/". Result files are
#             named <q>_<model>_<state>.csv; a model whose file already exists
#             is skipped.
#
# Returns 0 when no model is left to run; otherwise called for its side effect
# of writing the result files.
#
# N_CORES is not an argument: it is looked up in the calling environment, and
# the saved images do not contain it, so it must be defined on the node before
# this function is called.
fn_to_run_on_nodes = function(q, state, OUT_PATH) {

  library(glmnet)
  library(doMC)
  registerDoMC(cores = N_CORES)
  library(data.table)
  library(ranger)
  library(Matrix)
  library(xgboost)

  # Callers build (q, state) with expand.grid(), which turns strings into
  # factors; work with plain strings for comparisons and file names.
  q = as.character(q)
  state = as.character(state)

  # RUNNING JUST ONE MODEL FOR TESTING. UNCOMMENT FULL VECTOR TO RUN ALL MODELS
  models = c("rf")
  # models = c("rf", "gbrt", "ols", "lasso", "ridge", "logit", "lassolog", "ridgelog")

  # Exclude models already run
  # Only files shaped like <q>_<model>_<state>.csv are parsed, so unrelated
  # files in OUT_PATH cannot break the three-column split below.
  result_files = grep("^[^_]+_[^_]+_[^_]+\\.csv$", list.files(OUT_PATH), value = TRUE)
  run_dt = as.data.table(data.table::transpose(strsplit(result_files, "_")))
  if(nrow(run_dt)>0) {
    setnames(run_dt, c("qs","model","st"))
    # strip the literal ".csv" suffix (escaped and anchored)
    run_dt[, st := sub("\\.csv$", "", st)]
    models_left_to_run = setdiff(models, run_dt[qs == q & st == state, unique(model)])
    message(paste(models_left_to_run, collapse = ", "))
    if(length(models_left_to_run) == 0) {
      message("No models left to run")
      return(0)
    }
  } else {
    models_left_to_run = models
  }

  # load workspace and define arguments
  # The model matrix is converted to a dense matrix before being passed on.
  if(q == "q66") {
    load("data/test/yrbs_image_20220808_q66.RData")
    X = as.matrix(X_q66)
    Y = Y_q66
    w = w_q66
    dt = dt_q66_fn

  } else if (q == "q67") {
    load("data/test/yrbs_image_20220808_q67.RData")
    X = as.matrix(X_q67)
    Y = Y_q67
    w = w_q67
    dt = dt_q67_fn

  } else if (q == "q66m") {
    load("data/test/yrbs_image_20220808_q66_male.RData")
    X = as.matrix(X_q66_male)
    Y = Y_q66_male
    w = w_q66_male
    dt = dt_q66_male_fn

  } else {stop("invalid argument for q")}

  # Make sure the output directory exists before any model is fit, so a
  # missing directory cannot make fwrite() fail after the fit has finished.
  if(!dir.exists(OUT_PATH)) {
    dir.create(OUT_PATH, recursive = TRUE, showWarnings = FALSE)
  }

  # run model and predictions, write to file

  for(m in models_left_to_run) {
    res = gen_preds_for_state(state, X, Y, w, dt,  m)
    fwrite(res, paste0(OUT_PATH, q, "_", m, "_", state, ".csv"))
    gc() # release memory between models
  }

}

# Run Q66 locally
# One fn_to_run_on_nodes() call per state present in the q66 image.

OUT_PATH <- "~/YRBS_predictions/data/test/test_output/"

load("data/test/yrbs_image_20220808_q66.RData")
q66_models_states <- as.data.table(
  expand.grid(q = "q66", state = dt_q66_fn[, unique(sitecode)])
)

lapply(seq_len(nrow(q66_models_states)), function(row_idx) {
  grid_row <- q66_models_states[row_idx]
  fn_to_run_on_nodes(grid_row$q, grid_row$state, OUT_PATH)
})

## 

## No models left to run

## 

## No models left to run

## 

## No models left to run

## 

## No models left to run

## [[1]]
## [1] 0
## 
## [[2]]
## [1] 0
## 
## [[3]]
## [1] 0
## 
## [[4]]
## [1] 0

# Run Q66 male locally
# One fn_to_run_on_nodes() call per state present in the q66 male image,
# walking the state grid from its last row to its first.

OUT_PATH <- "~/YRBS_predictions/data/test/test_output/"

load("data/test/yrbs_image_20220808_q66_male.RData")
q66_male_models_states <- as.data.table(
  expand.grid(q = "q66m", state = dt_q66_male_fn[, unique(sitecode)])
)

lapply(rev(seq_len(nrow(q66_male_models_states))), function(row_idx) {
  grid_row <- q66_male_models_states[row_idx]
  fn_to_run_on_nodes(grid_row$q, grid_row$state, OUT_PATH)
})

## 

## No models left to run

## 

## No models left to run

## 

## No models left to run

## 

## No models left to run

## [[1]]
## [1] 0
## 
## [[2]]
## [1] 0
## 
## [[3]]
## [1] 0
## 
## [[4]]
## [1] 0

# Run Q67 locally
# One fn_to_run_on_nodes() call per state present in the q67 image,
# walking the state grid from its last row to its first.

OUT_PATH <- "~/YRBS_predictions/data/test/test_output/"

load("data/test/yrbs_image_20220808_q67.RData")
q67_models_states <- as.data.table(
  expand.grid(q = "q67", state = dt_q67_fn[, unique(sitecode)])
)

lapply(rev(seq_len(nrow(q67_models_states))), function(row_idx) {
  grid_row <- q67_models_states[row_idx]
  fn_to_run_on_nodes(grid_row$q, grid_row$state, OUT_PATH)
})

## 

## No models left to run

## 

## No models left to run

## 

## No models left to run

## 

## No models left to run

## [[1]]
## [1] 0
## 
## [[2]]
## [1] 0
## 
## [[3]]
## [1] 0
## 
## [[4]]
## [1] 0