Skip to content

Commit

Permalink
fix: ts_validation on refresh #960
Browse files Browse the repository at this point in the history
  • Loading branch information
laresbernardo committed Apr 30, 2024
1 parent d252591 commit 5158e1a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion R/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: Robyn
Type: Package
Title: Semi-Automated Marketing Mix Modeling (MMM) from Meta Marketing Science
Version: 3.11.0
Version: 3.10.6.9005
Authors@R: c(
person("Gufeng", "Zhou", , "gufeng@meta.com", c("cre","aut")),
person("Bernardo", "Lares", , "laresbernardo@gmail.com", c("aut")),
Expand Down
12 changes: 10 additions & 2 deletions R/R/refresh.R
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,15 @@ robyn_refresh <- function(json_file = NULL,
}

## Refresh hyperparameter bounds
ts_validation <- ifelse(
"ts_validation" %in% names(list(...)),
isTRUE(list(...)[["ts_validation"]]),
isTRUE(Robyn$listInit$OutputCollect$OutputModels$ts_validation))
InputCollectRF$hyperparameters <- refresh_hyps(
initBounds = Robyn$listInit$OutputCollect$hyper_updated,
listOutputPrev, refresh_steps,
rollingWindowLength = InputCollectRF$rollingWindowLength
rollingWindowLength = InputCollectRF$rollingWindowLength,
ts_validation = ts_validation
)

## Feature engineering for refreshed data
Expand All @@ -289,6 +294,7 @@ robyn_refresh <- function(json_file = NULL,
trials = refresh_trials,
refresh = TRUE,
add_penalty_factor = listOutputPrev[["add_penalty_factor"]],
ts_validation = ts_validation,
...
)
OutputCollectRF <- robyn_outputs(
Expand Down Expand Up @@ -527,7 +533,8 @@ Models (IDs):
#' @export
plot.robyn_refresh <- function(x, ...) plot((x$refresh$plots[[1]] / x$refresh$plots[[2]]), ...)

refresh_hyps <- function(initBounds, listOutputPrev, refresh_steps, rollingWindowLength) {
refresh_hyps <- function(initBounds, listOutputPrev, refresh_steps,
rollingWindowLength, ts_validation = FALSE) {
initBoundsDis <- unlist(lapply(initBounds, function(x) ifelse(length(x) == 2, x[2] - x[1], 0)))
newBoundsFreedom <- refresh_steps / rollingWindowLength
message(">>> New bounds freedom: ", round(100 * newBoundsFreedom, 2), "%")
Expand Down Expand Up @@ -559,5 +566,6 @@ refresh_hyps <- function(initBounds, listOutputPrev, refresh_steps, rollingWindo
hyper_updated_prev[hn][[1]] <- getRange
}
}
if (!ts_validation) hyper_updated_prev[["train_size"]] <- NULL
return(hyper_updated_prev)
}

0 comments on commit 5158e1a

Please sign in to comment.