From a330f5357cdd3e8f0ce2dae11a2da24dff65a0b1 Mon Sep 17 00:00:00 2001 From: Jason Poulos Date: Thu, 13 Oct 2022 15:07:29 -0400 Subject: [PATCH 1/2] db callbacks error --- R/Lrnr_gru_keras.R | 2 +- R/Lrnr_lstm_keras.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/Lrnr_gru_keras.R b/R/Lrnr_gru_keras.R index 0630316e..a8ef063e 100644 --- a/R/Lrnr_gru_keras.R +++ b/R/Lrnr_gru_keras.R @@ -186,7 +186,7 @@ Lrnr_gru_keras <- R6Class( y = args$y, batch_size = args$batch_size, epochs = args$epochs, - callbacks = callbacks, + callbacks = args$callbacks, verbose = verbose, shuffle = FALSE ) diff --git a/R/Lrnr_lstm_keras.R b/R/Lrnr_lstm_keras.R index ba27992d..91d1c854 100644 --- a/R/Lrnr_lstm_keras.R +++ b/R/Lrnr_lstm_keras.R @@ -184,7 +184,7 @@ Lrnr_lstm_keras <- R6Class( y = args$y, batch_size = args$batch_size, epochs = args$epochs, - callbacks = callbacks, + callbacks = args$callbacks, verbose = verbose, shuffle = FALSE ) From 88e72df54849fca923eb4d1d7041ff5b1bda3b94 Mon Sep 17 00:00:00 2001 From: Jason Poulos Date: Fri, 14 Oct 2022 14:51:03 -0400 Subject: [PATCH 2/2] add validation split option --- R/Lrnr_gru_keras.R | 5 ++++- R/Lrnr_lstm_keras.R | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/R/Lrnr_gru_keras.R b/R/Lrnr_gru_keras.R index a8ef063e..a774fa34 100644 --- a/R/Lrnr_gru_keras.R +++ b/R/Lrnr_gru_keras.R @@ -46,6 +46,7 @@ #' be applied at given stages of the training procedure. Default callback #' function \code{callback_early_stopping} stops training if the validation #' loss does not improve across \code{patience} number of epochs. +#' - \code{validation_split}: Fraction of the training data to be used as validation data. Default is 0 (no validation). #' - \code{...}: Other parameters passed to \code{\link[keras]{keras}}. 
#' #' @examples @@ -74,7 +75,7 @@ #' valid_task <- validation(task, fold = task$folds[[1]]) #' #' # instantiate learner, then fit and predict (simplifed example) -#' gru_lrnr <- Lrnr_gru_keras$new(batch_size = 1, epochs = 200) +#' gru_lrnr <- Lrnr_gru_keras$new(batch_size = 1, epochs = 200, validation_split = 0.2) #' gru_fit <- gru_lrnr$train(train_task) #' gru_preds <- gru_fit$predict(valid_task) #' } @@ -95,6 +96,7 @@ Lrnr_gru_keras <- R6Class( callbacks = list( keras::callback_early_stopping(patience = 10) ), + validation_split = 0, ...) { params <- args_to_list() super$initialize(params = params, ...) @@ -187,6 +189,7 @@ Lrnr_gru_keras <- R6Class( batch_size = args$batch_size, epochs = args$epochs, callbacks = args$callbacks, + validation_split = args$validation_split, verbose = verbose, shuffle = FALSE ) diff --git a/R/Lrnr_lstm_keras.R b/R/Lrnr_lstm_keras.R index 91d1c854..022d59a8 100644 --- a/R/Lrnr_lstm_keras.R +++ b/R/Lrnr_lstm_keras.R @@ -44,6 +44,7 @@ #' be applied at given stages of the training procedure. Default callback #' function \code{callback_early_stopping} stops training if the validation #' loss does not improve across \code{patience} number of epochs. +#' - \code{validation_split}: Fraction of the training data to be used as validation data. Default is 0 (no validation). #' - \code{...}: Other parameters passed to \code{\link[keras]{keras}}. #' #' @examples @@ -72,7 +73,7 @@ #' valid_task <- validation(task, fold = task$folds[[1]]) #' #' # instantiate learner, then fit and predict (simplifed example) -#' lstm_lrnr <- Lrnr_lstm_keras$new(batch_size = 1, epochs = 200) +#' lstm_lrnr <- Lrnr_lstm_keras$new(batch_size = 1, epochs = 200, validation_split = 0.2) #' lstm_fit <- lstm_lrnr$train(train_task) #' lstm_preds <- lstm_fit$predict(valid_task) #' } @@ -93,6 +94,7 @@ Lrnr_lstm_keras <- R6Class( lr = 0.001, layers = 1, callbacks = list(keras::callback_early_stopping(patience = 10)), + validation_split = 0, ...) 
{ params <- args_to_list() super$initialize(params = params, ...) @@ -185,6 +187,7 @@ Lrnr_lstm_keras <- R6Class( batch_size = args$batch_size, epochs = args$epochs, callbacks = args$callbacks, + validation_split = args$validation_split, verbose = verbose, shuffle = FALSE )