Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: check and fix all_sol_json & new pareto_df parameters inputs #843

Merged
merged 5 commits into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions R/R/inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ robyn_inputs <- function(dt_input = NULL,
# Check for no-variance columns on raw data (after removing not-used)
check_novar(select(dt_input, -all_of(unused_vars)))

# Calculate total media spend used to model
paid_media_total <- dt_input[
rollingWindowEndWhich:rollingWindowLength, ] %>%
select(paid_media_vars) %>% sum()

## Collect input
InputCollect <- list(
dt_input = dt_input,
Expand All @@ -294,6 +299,7 @@ robyn_inputs <- function(dt_input = NULL,
paid_media_vars = paid_media_vars,
paid_media_signs = paid_media_signs,
paid_media_spends = paid_media_spends,
paid_media_total = paid_media_total,
mediaVarCount = mediaVarCount,
exposure_vars = exposure_vars,
organic_vars = organic_vars,
Expand All @@ -307,6 +313,7 @@ robyn_inputs <- function(dt_input = NULL,
window_end = window_end,
rollingWindowEndWhich = rollingWindowEndWhich,
rollingWindowLength = rollingWindowLength,
totalObservations = nrow(dt_input),
refreshAddedStart = refreshAddedStart,
adstock = adstock,
hyperparameters = hyperparameters,
Expand Down
32 changes: 21 additions & 11 deletions R/R/json.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' @param select_model Character. Which model ID do you want to export
#' into the JSON file?
#' @param dir Character. Existing directory to export JSON file to.
#' @param all_sol_json Dataframe. Add all pareto solutions to json.
#' @param pareto_df Dataframe. Save all pareto solutions to json file.
#' @param ... Additional parameters.
#' @examples
#' \dontrun{
Expand All @@ -36,7 +36,7 @@ robyn_write <- function(InputCollect,
OutputModels = NULL,
export = TRUE,
quiet = FALSE,
all_sol_json = NULL,
pareto_df = NULL,
...) {
# Checks
stopifnot(inherits(InputCollect, "robyn_inputs"))
Expand Down Expand Up @@ -70,6 +70,7 @@ robyn_write <- function(InputCollect,
outputs <- list()
outputs$select_model <- select_model
outputs$ts_validation <- OutputCollect$OutputModels$ts_validation
outputs$export_timestamp <- Sys.time()
outputs$run_time <- run_time
outputs$outputs_time <- outputs_time
outputs$total_time <- total_time
Expand All @@ -87,6 +88,9 @@ robyn_write <- function(InputCollect,
)
outputs$errors <- filter(OutputCollect$resultHypParam, .data$solID == select_model) %>%
select(starts_with("rsq_"), starts_with("nrmse"), .data$decomp.rssd, .data$mape)
if ("clusters" %in% names(OutputCollect)) {
outputs$clusters <- OutputCollect$clusters$n_clusters
}
outputs$hyper_values <- OutputCollect$resultHypParam %>%
filter(.data$solID == select_model) %>%
select(contains(HYPS_NAMES), dplyr::ends_with("_penalty"), any_of(HYPS_OTHERS)) %>%
Expand All @@ -109,15 +113,21 @@ robyn_write <- function(InputCollect,
attr(ret, "json_file") <- filename
if (export) {
if (!quiet) message(sprintf(">> Exported model %s as %s", select_model, filename))
if (!is.null(all_sol_json)) {
all_c <- unique(all_sol_json$cluster)
all_sol_json <- lapply(all_c, function(x) {
(all_sol_json %>% filter(.data$cluster == x))$solID
})
names(all_sol_json) <- paste0("cluster", all_c)
ret[["InputCollect"]][["total_time"]] <- total_time
ret[["InputCollect"]][["total_iters"]] <- OutputModels$iterations * OutputModels$trials
ret[["OutputCollect"]][["all_sols"]] <- all_sol_json
if (!is.null(pareto_df)) {
if (!all(c("solID", "cluster") %in% names(pareto_df))) {
warning(paste(
"Input 'pareto_df' is not a valid data.frame;",
"must contain 'solID' and 'cluster' columns."))
} else {
all_c <- unique(pareto_df$cluster)
pareto_df <- lapply(all_c, function(x) {
(pareto_df %>% filter(.data$cluster == x))$solID
})
names(pareto_df) <- paste0("cluster", all_c)
ret[["InputCollect"]][["total_time"]] <- total_time
ret[["InputCollect"]][["total_iters"]] <- OutputModels$iterations * OutputModels$trials
ret[["OutputCollect"]][["all_sols"]] <- pareto_df
}
}
write_json(ret, filename, pretty = TRUE, digits = 10)
}
Expand Down
1 change: 1 addition & 0 deletions R/R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ robyn_run <- function(InputCollect = NULL,
attr(OutputModels, "refresh") <- refresh

if (TRUE) {
OutputModels$train_timestamp <- Sys.time()
OutputModels$cores <- cores
OutputModels$iterations <- iterations
OutputModels$trials <- trials
Expand Down
6 changes: 3 additions & 3 deletions R/R/outputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,18 +228,18 @@ robyn_outputs <- function(InputCollect, OutputModels,
}

if (all_sol_json) {
all_sol_json <- OutputCollect$resultHypParam %>%
pareto_df <- OutputCollect$resultHypParam %>%
filter(!is.na(.data$cluster)) %>%
select(c("solID", "cluster", "top_sol")) %>%
arrange(.data$cluster, -.data$top_sol, .data$solID)
} else {
all_sol_json <- NULL
pareto_df <- NULL
}
robyn_write(
InputCollect = InputCollect,
OutputModels = OutputModels,
dir = plot_folder, quiet = quiet,
all_sol_json = all_sol_json
pareto_df = pareto_df, ...
)

# For internal use -> UI Code
Expand Down
4 changes: 2 additions & 2 deletions R/man/robyn_write.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.