Skip to content

Commit

Permalink
fixed error when model has single predictor in boosting
Browse files Browse the repository at this point in the history
  • Loading branch information
StanWijn committed Sep 28, 2022
1 parent 8663fca commit eb93cb2
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions R/CATE_count.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,15 @@ intxcount <- function(y, trt, x.cate, x.ps, time,
data0 <- datatot_train[trt_train == 0, ]

if (initial.predictor.method == "boosting") {
# if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130
cate.cvfold <- if(ncol(x.cate) == 1){0} else {5}
fit1.boosting <- gbm(y ~ . - time + offset(time), data = data1, distribution = "poisson",
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...)
best1.iter <- max(10, gbm.perf(fit1.boosting, method = "cv", plot.it = plot.gbmperf))


# TODO: if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130
fit1.boosting <- gbm(y ~ . - time + offset(time), data = data1, distribution = "poisson",
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...)
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...)
best1.iter <- max(10, gbm.perf(fit1.boosting, method = "cv", plot.it = plot.gbmperf))
withCallingHandlers({
f1.predictcv[index == k] <- predict(object = fit1.boosting, newdata = datatot_valid, n.trees = best1.iter, type = "response")
Expand All @@ -263,7 +268,7 @@ intxcount <- function(y, trt, x.cate, x.ps, time,
})

fit0.boosting <- gbm(y ~ . - time + offset(time), data = data0, distribution = "poisson",
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...)
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...)
best0.iter <- max(10, gbm.perf(fit0.boosting, method = "cv", plot.it = plot.gbmperf))
withCallingHandlers({
f0.predictcv[index == k] <- predict(object = fit0.boosting, newdata = datatot_valid, n.trees = best0.iter, type = "response")
Expand All @@ -275,6 +280,7 @@ intxcount <- function(y, trt, x.cate, x.ps, time,

best.iter <- max(best.iter, best1.iter, best0.iter)


} else if (initial.predictor.method == "poisson") {

fit1.pois <- glm(y ~ . - time + offset(time), data = data1, family = "poisson")
Expand Down Expand Up @@ -330,15 +336,19 @@ intxcount <- function(y, trt, x.cate, x.ps, time,
if ("boosting" %in% score.method) {
## Boosting method based on the entire data (score 1)
data1 <- datatot[trt == 1, ]
# TODO: if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130
# if model has a single predictor, GBM must have cv.folds = 0 https://github.com/zoonproject/zoon/issues/130
cate.cvfold <- if(ncol(x.cate) == 1){0} else {5}
cate.gbmmethod <- if(ncol(x.cate) == 1){"OOB"} else {"cv"}
if(cate.gbmmethod == "OOB") warning("If the model has a single predictor, GBM must use OOB")

fit1.boosting <- gbm(y ~ . - time + offset(time), data = data1, distribution = "poisson",
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...)
best1.iter <- max(10, gbm.perf(fit1.boosting, method = "cv", plot.it = plot.gbmperf))
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...)
best1.iter <- max(10, gbm.perf(fit1.boosting, method = cate.gbmmethod, plot.it = plot.gbmperf))

data0 <- datatot[trt == 0, ]
fit0.boosting <- gbm(y ~ . - time + offset(time), data = data0, distribution = "poisson",
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = 5, ...)
best0.iter <- max(10, gbm.perf(fit0.boosting, method = "cv", plot.it = plot.gbmperf))
interaction.depth = tree.depth, n.trees = n.trees.boosting, cv.folds = cate.cvfold, ...)
best0.iter <- max(10, gbm.perf(fit0.boosting, method = cate.gbmmethod, plot.it = plot.gbmperf))

result$result.boosting <- list(fit0.boosting = fit0.boosting, best0.iter = best0.iter, fit1.boosting = fit1.boosting, best1.iter = best1.iter)
best.iter <- max(best.iter, best1.iter, best0.iter)
Expand Down

0 comments on commit eb93cb2

Please sign in to comment.