diff --git a/R/CATE_count.R b/R/CATE_count.R index 15d248a..df3437d 100644 --- a/R/CATE_count.R +++ b/R/CATE_count.R @@ -249,10 +249,15 @@ intxcount <- function(y, trt, x.cate, x.ps, time, data0 <- datatot_train[trt_train == 0, ] if (initial.predictor.method == "boosting") { + # if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130 + cate.cvfold <- if(ncol(x.cate) == 1){0} else {5} + fit1.boosting <- gbm(y ~ . - time + offset(time), data = data1, distribution = "poisson", + interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...) + best1.iter <- max(10, gbm.perf(fit1.boosting, method = "cv", plot.it = plot.gbmperf)) + - # TODO: if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130 fit1.boosting <- gbm(y ~ . - time + offset(time), data = data1, distribution = "poisson", - interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...) + interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...) best1.iter <- max(10, gbm.perf(fit1.boosting, method = "cv", plot.it = plot.gbmperf)) withCallingHandlers({ f1.predictcv[index == k] <- predict(object = fit1.boosting, newdata = datatot_valid, n.trees = best1.iter, type = "response") @@ -263,7 +268,7 @@ intxcount <- function(y, trt, x.cate, x.ps, time, }) fit0.boosting <- gbm(y ~ . - time + offset(time), data = data0, distribution = "poisson", - interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...) + interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...) best0.iter <- max(10, gbm.perf(fit0.boosting, method = "cv", plot.it = plot.gbmperf)) withCallingHandlers({ f0.predictcv[index == k] <- predict(object = fit0.boosting, newdata = datatot_valid, n.trees = best0.iter, type = "response") @@ -275,6 +280,7 @@ intxcount <- function(y, trt, x.cate, x.ps, time, best.iter <- max(best.iter, best1.iter, best0.iter) + } else if (initial.predictor.method == "poisson") { fit1.pois <- glm(y ~ . - time + offset(time), data = data1, family = "poisson") @@ -330,15 +336,19 @@ intxcount <- function(y, trt, x.cate, x.ps, time, if ("boosting" %in% score.method) { ## Boosting method based on the entire data (score 1) data1 <- datatot[trt == 1, ] - # TODO: if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130 + # if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130 + cate.cvfold <- if(ncol(x.cate) == 1){0} else {5} + cate.gbmmethod <- if(ncol(x.cate) == 1){"OOB"} else {"cv"} + if(cate.gbmmethod == "OOB") warning("If the model has a single predictor, GBM must use OOB") + fit1.boosting <- gbm(y ~ . - time + offset(time), data = data1, distribution = "poisson", - interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...) - best1.iter <- max(10, gbm.perf(fit1.boosting, method = "cv", plot.it = plot.gbmperf)) + interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...) + best1.iter <- max(10, gbm.perf(fit1.boosting, method = cate.gbmmethod, plot.it = plot.gbmperf)) data0 <- datatot[trt == 0, ] fit0.boosting <- gbm(y ~ . - time + offset(time), data = data0, distribution = "poisson", - interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...) - best0.iter <- max(10, gbm.perf(fit0.boosting, method = "cv", plot.it = plot.gbmperf)) + interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...) + best0.iter <- max(10, gbm.perf(fit0.boosting, method = cate.gbmmethod, plot.it = plot.gbmperf)) result$result.boosting <- list(fit0.boosting = fit0.boosting, best0.iter = best0.iter, fit1.boosting = fit1.boosting, best1.iter = best1.iter) best.iter <- max(best.iter, best1.iter, best0.iter)