From f4530c6a36021da463a33d3c13ab7b15a4f20bac Mon Sep 17 00:00:00 2001 From: Martin Date: Tue, 29 Oct 2024 13:10:36 +0100 Subject: [PATCH 1/5] cat-methods for L0 --- R/process.R | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/R/process.R b/R/process.R index 24434d0..d22265c 100644 --- a/R/process.R +++ b/R/process.R @@ -129,6 +129,35 @@ get_measure_L0 <- function(res.dt,x_explain_mutable,x_sim_mutable){ res.dt[id_explain==i,measure_L0:=value] } } + +get_measure_L0_new <- function(res.dt,x_explain_mutable,x_sim_mutable){ + n_explain <- x_explain_mutable[,.N] + for (i in seq_len(n_explain)){ + org <- x_explain_mutable[id_explain==i,-"id_explain"] + sim <- x_sim_mutable[id_explain==i,-"id_explain"] + value <- apply(X = sim,FUN=function(x) sum(x == org),MARGIN = 1) + + res.dt[id_explain==i,measure_L0:=value] + } +} + +get_measure_L0_new2 <- function(res.dt,x_explain_mutable,x_sim_mutable){ + + combined <- x_explain_mutable[x_sim_mutable, on = "id_explain", nomatch = 0] + + # Identify columns to compare (exclude `id_explain`) + columns_to_compare <- setdiff(names(x_explain_mutable), "id_explain") + + # Calculate the number of identical values in each row + value <- combined[, rowSums(mapply(function(col1, col2) col1 == col2, + .SD[, columns_to_compare, with = FALSE], + .SD[, paste0("i.",columns_to_compare), with = FALSE]))] + + res.dt[,measure_L0 := value] +} + + + get_measure_L1 <- function(res.dt,x_explain_mutable,x_sim_mutable){ n_explain <- x_explain_mutable[,.N] for (i in seq_len(n_explain)){ From 67cf1dd9ab3c2e06aee1b1ca4f0c8b78202c7339 Mon Sep 17 00:00:00 2001 From: Martin Date: Tue, 29 Oct 2024 14:02:16 +0100 Subject: [PATCH 2/5] use new L0 formula --- R/process.R | 32 ++++++++------------------------ 1 file changed, 8 insertions(+), 24 deletions(-) diff --git a/R/process.R b/R/process.R index d22265c..b55ac7d 100644 --- a/R/process.R +++ b/R/process.R @@ -101,16 +101,19 @@ process = function(x_sim, res.dt <- res.dt[measure_validation==1] } - ret_sim0 <- res.dt[,head(.SD,return_best_k),by=id_explain][,.(row_id,counterfactual_rank,pred)] + cols <- c("row_id","counterfactual_rank","pred",paste0("measure_",measures)) + + ret_sim0 <- res.dt[,head(.SD,return_best_k),by=id_explain][,cols,with = FALSE] ret_sim <- x_sim[ret_sim0,on="row_id"] ret_sim[,row_id:=NULL] time_process = difftime(Sys.time(), time_process_start, units = "secs") - data.table::setcolorder(ret_sim,c("id_explain","counterfactual_rank","pred")) + data.table::setcolorder(ret_sim,cols[-1]) - ret <- list(cf=ret_sim, + ret <- list(cf=ret_sim[,-cols[-1],with = FALSE], + cf_measures = ret_sim[,cols[-1],with = FALSE], time_process = time_process) return(ret) @@ -123,25 +126,7 @@ get_measure_validation <- function(res.dt,x_explain_mutable,x_sim_mutable,c_int) get_measure_L0 <- function(res.dt,x_explain_mutable,x_sim_mutable){ - n_explain <- x_explain_mutable[,.N] - for (i in seq_len(n_explain)){ - value <- Rfast::colsums(unlist(x_explain_mutable[id_explain==i,-1])-t(as.matrix(x_sim_mutable[id_explain==i,-1]))==0,parallel=T) - res.dt[id_explain==i,measure_L0:=value] - } -} - -get_measure_L0_new <- function(res.dt,x_explain_mutable,x_sim_mutable){ - n_explain <- x_explain_mutable[,.N] - for (i in seq_len(n_explain)){ - org <- x_explain_mutable[id_explain==i,-"id_explain"] - sim <- x_sim_mutable[id_explain==i,-"id_explain"] - value <- apply(X = sim,FUN=function(x) sum(x == org),MARGIN = 1) - - res.dt[id_explain==i,measure_L0:=value] - } -} - -get_measure_L0_new2 <- function(res.dt,x_explain_mutable,x_sim_mutable){ + n_features <- ncol(x_explain_mutable)-1 combined <- x_explain_mutable[x_sim_mutable, on = "id_explain", nomatch = 0] @@ -153,11 +138,10 @@ get_measure_L0_new2 <- function(res.dt,x_explain_mutable,x_sim_mutable){ .SD[, columns_to_compare, with = FALSE], .SD[, paste0("i.",columns_to_compare), with = FALSE]))] - res.dt[,measure_L0 := value] + res.dt[,measure_L0 := n_features-value] # Number of features changed from original value } - get_measure_L1 <- function(res.dt,x_explain_mutable,x_sim_mutable){ n_explain <- x_explain_mutable[,.N] for (i in seq_len(n_explain)){ From b5289124f001db25291380886e8097735e487fa6 Mon Sep 17 00:00:00 2001 From: Martin Date: Tue, 29 Oct 2024 15:06:17 +0100 Subject: [PATCH 3/5] add gower + new output --- R/explain.R | 1 + R/process.R | 53 +++++++++++++++++--- inst/scripts/devel_model4.R | 98 +++++++++++++++++++++++++++++++++++++ 3 files changed, 144 insertions(+), 8 deletions(-) create mode 100644 inst/scripts/devel_model4.R diff --git a/R/explain.R b/R/explain.R index 5792a6c..925ccf1 100644 --- a/R/explain.R +++ b/R/explain.R @@ -144,6 +144,7 @@ explain_mcce = function(model, x_explain, x_train, predict_model=NULL, process.time=cfs$time_process) ret <- list(cf = cfs$cf[], + cf_measures = cfs$cf_measures[], fixed_features = fit_object$fixed_features, mutable_features = fit_object$mutable_features, time = time_vec) diff --git a/R/process.R b/R/process.R index b55ac7d..f31d709 100644 --- a/R/process.R +++ b/R/process.R @@ -88,7 +88,10 @@ process = function(x_sim, get_measure_L2(res.dt,x_explain_mutable,x_sim_mutable) measure_ordering <- c(measure_ordering,1) } - + if(this_measure=="gower"){ + get_measure_gower(res.dt,x_explain_mutable,x_sim_mutable) + measure_ordering <- c(measure_ordering,1) + } } @@ -101,19 +104,22 @@ process = function(x_sim, res.dt <- res.dt[measure_validation==1] } - cols <- c("row_id","counterfactual_rank","pred",paste0("measure_",measures)) + cols <- c("counterfactual_rank","row_id","pred",paste0("measure_",measures)) ret_sim0 <- res.dt[,head(.SD,return_best_k),by=id_explain][,cols,with = FALSE] ret_sim <- x_sim[ret_sim0,on="row_id"] - ret_sim[,row_id:=NULL] time_process = difftime(Sys.time(), time_process_start, units = "secs") - data.table::setcolorder(ret_sim,cols[-1]) - ret <- list(cf=ret_sim[,-cols[-1],with = FALSE], - cf_measures = ret_sim[,cols[-1],with = FALSE], + data.table::setcolorder(ret_sim,c("id_explain",cols[-2])) + + cols_measure <- c("id_explain",cols[-2]) + cols_cf <- names(ret_sim)[!(names(ret_sim) %in% cols[-1])] + + ret <- list(cf=ret_sim[,cols_cf,with = FALSE], + cf_measures = ret_sim[,cols_measure,with = FALSE], time_process = time_process) return(ret) @@ -133,15 +139,21 @@ get_measure_L0 <- function(res.dt,x_explain_mutable,x_sim_mutable){ # Identify columns to compare (exclude `id_explain`) columns_to_compare <- setdiff(names(x_explain_mutable), "id_explain") + value <- identical_rows(combined,columns_to_compare) + + res.dt[,measure_L0 := n_features-value] # Number of features changed from original value +} + +identical_rows <- function(combined,columns_to_compare){ + # Calculate the number of identical values in each row value <- combined[, rowSums(mapply(function(col1, col2) col1 == col2, .SD[, columns_to_compare, with = FALSE], .SD[, paste0("i.",columns_to_compare), with = FALSE]))] - - res.dt[,measure_L0 := n_features-value] # Number of features changed from original value } + get_measure_L1 <- function(res.dt,x_explain_mutable,x_sim_mutable){ n_explain <- x_explain_mutable[,.N] for (i in seq_len(n_explain)){ @@ -161,3 +173,28 @@ get_measure_L2 <- function(res.dt,x_explain_mutable,x_sim_mutable){ } +get_measure_gower <- function(res.dt,x_explain_mutable,x_sim_mutable){ + + cat_cols <- names(which(sapply(x_explain_mutable[,-1],is.factor))) + num_cols <- names(which(sapply(x_explain_mutable[,-1],is.numeric))) + + combined <- x_explain_mutable[x_sim_mutable, on = "id_explain", nomatch = 0] + + if(length(cat_cols)>0){ + cat_contrib <- identical_rows(combined,cat_cols) + } else { + cat_contrib <- 0 + } + num_contrib <- 0 + if(length(num_cols)>0){ + for(i in seq_along(num_cols)){ + col1 <- unlist(combined[,num_cols[i],with=FALSE]) + col2 <- unlist(combined[,paste0("i.",num_cols[i]),with=FALSE]) + + num_contrib <- num_contrib + abs(col1-col2)/range(col2) # Using the range of the syntehtic data instead of the original for simplicity + } + } + res.dt[,measure_gower:=cat_contrib + num_contrib] +} + + diff --git a/inst/scripts/devel_model4.R b/inst/scripts/devel_model4.R new file mode 100644 index 0000000..e8f15a5 --- /dev/null +++ b/inst/scripts/devel_model4.R @@ -0,0 +1,98 @@ + + +devtools::load_all(".") + +data(iris) + +iris_binary <- as.data.table(iris) +iris_binary[,y:=ifelse(Species=="virginica",0,1)] +iris_binary[,y:=as.factor(y)] +iris_binary[,Sepal.length_cat := as.factor(ifelse(Sepal.Length>5.5,1,0))] + +iris_binary[,Species:=NULL] +iris_binary_test <- iris_binary[c(83,53,70,45,44)] +iris_binary_train <- iris_binary[-c(83,53,70,45,44)] + + +model <- glm(y~.,family = "binomial",data = iris_binary_train) + +predict(object = model,newdata=iris_binary_test,type="response") + +cf <- explain_mcce(model = model, + x_explain = iris_binary_test[,-"y"], + x_train = iris_binary_train[,-"y"], + process.measures = c("validation","L0","gower"), + c_int=c(0,0.5), + fit.seed = 123, + generate.seed = 123) + + + + +#library(xgboost) +data("Boston", package = "MASS") + +x_var <- c("lstat", "rm", "dis", "indus") +y_var <- "medv" +xy_var <- c(x_var,y_var) + + +ind_x_test <- 1:6 +xy_train <-Boston[-ind_x_test, xy_var] +x_train <-Boston[-ind_x_test, x_var] +x_test <- Boston[ind_x_test, x_var] + + +model <- lm(as.formula(paste0(y_var,"~.")),data = xy_train) + +predict(model,x_test) + +### Manually adjust model parameters to match those in python: + +model$coefficients <- c(8.632407012695102,-0.6668123 , 4.54767491, -0.92003323, -0.24707083) + +predict(model,x_test) + + +cf.Boston <- explain_mcce(model = model, + x_explain = x_test, + x_train = x_train, + fixed_features = "lstat", + c_int=c(35,1000),fit.seed = 123,generate.seed = 123) + +cf.Boston$ + + +cf.Boston$cf[] + + +#### Xgbost model on these data as well: +library(xgboost) +data("Boston", package = "MASS") + +x_var <- c("lstat", "rm", "dis", "indus") +y_var <- "medv" + +ind_x_test <- 1:6 +x_train <- as.matrix(Boston[-ind_x_test, x_var]) +y_train <- Boston[-ind_x_test, y_var] +x_test <- as.matrix(Boston[ind_x_test, x_var]) + +# Fitting a basic xgboost model to the training data +model <- xgboost( + data = x_train, + label = y_train, + nround = 20, + verbose = FALSE +) + +predict(model,x_test) + +cf.Boston <- explain_mcce(model = model, + x_explain = x_test, + x_train = x_train, + fixed_features = "lstat", + c_int=c(35,1000)) + + + From 618c07979cfcb2750a8dc303da20aba89bca1af4 Mon Sep 17 00:00:00 2001 From: Martin Date: Tue, 29 Oct 2024 15:07:20 +0100 Subject: [PATCH 4/5] test updates --- tests/testthat/_snaps/output.md | 18 ++++++++++-------- .../_snaps/output/output_lm_numeric_ctree.rds | Bin 383 -> 472 bytes .../_snaps/output/output_lm_numeric_rpart.rds | Bin 381 -> 469 bytes 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/testthat/_snaps/output.md b/tests/testthat/_snaps/output.md index c78567c..ed30616 100644 --- a/tests/testthat/_snaps/output.md +++ b/tests/testthat/_snaps/output.md @@ -3,18 +3,20 @@ Code (out <- code) Output - id_explain counterfactual_rank pred Solar.R Wind Temp Month Day - 1: 1 1 14.051 99 7.4 59 5 4 - 2: 2 1 9.817 224 10.9 59 5 8 - 3: 3 1 10.754 238 15.5 72 9 19 + id_explain counterfactual_rank Solar.R Wind Temp Month Day + + 1: 1 1 284 7.4 59 9 5 + 2: 2 1 224 10.9 57 9 9 + 3: 3 1 44 15.5 72 8 21 # output_lm_numeric_rpart Code (out <- code) Output - id_explain counterfactual_rank pred Solar.R Wind Temp Month Day - 1: 1 1 10.39 92 7.4 59 6 3 - 2: 2 1 10.64 224 10.9 59 5 11 - 3: 3 1 10.75 238 15.5 72 9 19 + id_explain counterfactual_rank Solar.R Wind Temp Month Day + + 1: 1 1 99 7.4 63 9 14 + 2: 2 1 284 10.9 63 9 9 + 3: 3 1 238 15.5 71 9 21 diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds index c014e5e0c769e8be533dcc99c3229f01f895fb29..43c61725d49d32bab691877f22359dc3de7bb4e6 100644 GIT binary patch literal 472 zcmV;}0Vn<+iwFP!000001I1H2PQySD-Pq2DAfgZiG$7GH1qQkx3Q{0M0SzC4Xw1f5 zD@)dP@CQQ2HQ*8)fF=h3QE&koTEqpwtnJwlOb~)5D?QK5oAvDM`1!Di5R#E1OGwUg z7iIVNHWsR!8{>fjHqS8_Rw;`WJSo}(0Twdt#9UDsysqIF+eL}H0L zs~jsFn?2kRpwPo*fg9&7O7Xc$9A|ZgJ&!x(xXFYp8ra9GkL!$lVN4fWBvoL@bvm{|7% z%rqCHyJ=9(EwuY)=9@DKUxj@064KG|GU4ClBZd9d zB8NYSBKMzfp|o5ghF`G=_5zn^XP8mm4s}FZ7cJL3HXx zwURx`4`TfNPw3uMl#p~2ykXtcpd$~cl1bad&`g5;9=2=+@yGw<`|CupbX_P!Vbt@Z OTYdniPH$yj1poje`Qcpv literal 383 zcmV-_0f7D=iwFP!000001B>8dU|?WoU}0orU}gm}8CXL@+;lB~V!~hv2+aY+ybwME zBM>t|*~nrr%m)&4(2{OY=%ni_#L4OG*=S;)@dVvY}Ee1x2YTP&RvTeokVMUJy(wJTng_ z#S)U5TL81vH$Sf=1Il4`NvuS7RB}#YG0ag+Fh`{%mL%$xBqrsgqKKpw0eygG0B2Et zxgN}Cpdess022TI|No1oP*5)u9%}q{*MHx#@W^oC~O8}UBDk)|I004`gpSl15 diff --git a/tests/testthat/_snaps/output/output_lm_numeric_rpart.rds b/tests/testthat/_snaps/output/output_lm_numeric_rpart.rds index 7a3f5fa18e5c53fd1a69956bd77e2f08eabac33d..d7278bfcf85488bec950d3bb90e97474e662bb3b 100644 GIT binary patch literal 469 zcmV;`0V@6M}i*TJJVf(2XLFx z6rU)HUOP4p=*$hN?_#b8l-oYF1vm=w22xH&2X`{dq>UQK%mrkKc=b4kjY(L?-1M<~ z0-^k_$FZ;PfUE5@*OGRXIyKqumgk1Y;HYc(ymK6~lYb~A9zepK+%LG3@?_RITsztr zGbZ~lA77$rW9s2{<>F>o#w*V$X*zZ5d?9rV0HxoGuQpSB8}zaa$HU7+zo(B9b!Uqj zeM5@eU*1Bi`BV;n(4jb}`_w$eoQZOnCleH9WOLyQvY3{Yi~hZg)+CfJ%qm!vuprKr z?$N#>NRL~{s{4AD3X!nn(K LKEl8fUj+aFidE#; literal 381 zcmV-@0fPP?iwFP!000001I1ChPQx%1bsmMbh^nX(8ygu3BSV>lH%zZwkzodteBdDz0)G9m#o%PZ#IT`lILcUmq^FL(jP$!l+!e)h=lBNY(6;5JCKzO&!!k1m>4GLqFilAFGRg!HS`e-yH z5L3nUf+V9Hu?Pk{gT!Z$o+!y4G8C$KdN(1ExjS0g$5c!90L4Rn=~hDdvD#Q0@T_hs zATZUYOl>NJ9J(AH5%i6TnCPHZpv96?S9c~yU|xw&p6BnigL5vHZBa Date: Tue, 29 Oct 2024 15:09:30 +0100 Subject: [PATCH 5/5] style pkg --- R/explain.R | 116 +++++++++++++------------ R/fit.R | 83 +++++++++--------- R/generate.R | 83 +++++++++--------- R/model.R | 10 +-- R/process.R | 159 +++++++++++++++++------------------ R/zzz.R | 4 +- tests/testthat/test-output.R | 40 +++++---- tests/testthat/test-setup.R | 48 ++++++----- 8 files changed, 277 insertions(+), 266 deletions(-) diff --git a/R/explain.R b/R/explain.R index 925ccf1..b3bbb75 100644 --- a/R/explain.R +++ b/R/explain.R @@ -70,20 +70,19 @@ #' #' @export #' -explain_mcce = function(model, x_explain, x_train, predict_model=NULL, - fixed_features = NULL, c_int=c(0.5,1), - featuremodel_object = NULL, - fit.autoregressive_model="ctree", fit.decision = TRUE, fit.seed = NULL, - generate.K = 1000, generate.seed = NULL, - process.measures = c("validation","L0","L1"), - process.return_best_k = 1, - process.remove_invalid = TRUE, - process.sort_by_measures_order = TRUE, - return_featuremodel_object = FALSE, - return_sim_data = FALSE, - timing = TRUE, - ...){ - +explain_mcce <- function(model, x_explain, x_train, predict_model = NULL, + fixed_features = NULL, c_int = c(0.5, 1), + featuremodel_object = NULL, + fit.autoregressive_model = "ctree", fit.decision = TRUE, fit.seed = NULL, + generate.K = 1000, generate.seed = NULL, + process.measures = c("validation", "L0", "L1"), + process.return_best_k = 1, + process.remove_invalid = TRUE, + process.sort_by_measures_order = TRUE, + return_featuremodel_object = FALSE, + return_sim_data = FALSE, + timing = TRUE, + ...) { if (!(is.matrix(x_explain) || is.data.frame(x_explain))) { stop("x_explain should be a matrix or a data.frame/data.table.\n") } else { @@ -92,15 +91,14 @@ explain_mcce = function(model, x_explain, x_train, predict_model=NULL, predict_model <- get_predict_model(predict_model, model) - if(is.null(featuremodel_object)){ - if (! (is.matrix(x_train) || is.data.frame(x_train))) { + if (is.null(featuremodel_object)) { + if (!(is.matrix(x_train) || is.data.frame(x_train))) { stop("x_train should be a matrix or a data.frame/data.table, or 'featuremodel_object' must be passed\n") } else { x_train <- data.table::as.data.table(x_train) } - pred_train <- predict_model(model,x_train) - + pred_train <- predict_model(model, x_train) } else { # If featuremodel_object is passed with don't need the prediction on the training data set pred_train <- NULL } @@ -109,51 +107,61 @@ explain_mcce = function(model, x_explain, x_train, predict_model=NULL, - fit_object <- fit(x_train = x_train, - pred_train = pred_train, - fixed_features = fixed_features, - c_int = c_int, - featuremodel_object = featuremodel_object, - autoregressive_model = fit.autoregressive_model, - decision = fit.decision, - seed = fit.seed, - ...) + fit_object <- fit( + x_train = x_train, + pred_train = pred_train, + fixed_features = fixed_features, + c_int = c_int, + featuremodel_object = featuremodel_object, + autoregressive_model = fit.autoregressive_model, + decision = fit.decision, + seed = fit.seed, + ... + ) - sim_object <- generate(x_explain = x_explain, - fit_object=fit_object, - K = generate.K, - seed=generate.seed) + sim_object <- generate( + x_explain = x_explain, + fit_object = fit_object, + K = generate.K, + seed = generate.seed + ) x_sim <- sim_object$simData - pred_sim <- predict_model(model,x_sim[,-"id_explain"]) - - cfs <- process(x_sim = x_sim, - pred_sim = pred_sim, - x_explain = x_explain, - fit_object = fit_object, - measures = process.measures, # Don't obey this quite yet - remove_invalid = process.remove_invalid, - return_best_k = process.return_best_k, - sort_by_measures_order = process.sort_by_measures_order) - - time_vec <- c(fit.time=fit_object$time_fit, - generate.time=sim_object$time_generate, - process.time=cfs$time_process) - - ret <- list(cf = cfs$cf[], - cf_measures = cfs$cf_measures[], - fixed_features = fit_object$fixed_features, - mutable_features = fit_object$mutable_features, - time = time_vec) - - if(return_featuremodel_object==TRUE){ + pred_sim <- predict_model(model, x_sim[, -"id_explain"]) + + cfs <- process( + x_sim = x_sim, + pred_sim = pred_sim, + x_explain = x_explain, + fit_object = fit_object, + measures = process.measures, # Don't obey this quite yet + remove_invalid = process.remove_invalid, + return_best_k = process.return_best_k, + sort_by_measures_order = process.sort_by_measures_order + ) + + time_vec <- c( + fit.time = fit_object$time_fit, + generate.time = sim_object$time_generate, + process.time = cfs$time_process + ) + + ret <- list( + cf = cfs$cf[], + cf_measures = cfs$cf_measures[], + fixed_features = fit_object$fixed_features, + mutable_features = fit_object$mutable_features, + time = time_vec + ) + + if (return_featuremodel_object == TRUE) { fit_object$time_fit <- NULL # Removed ret$featuremodel_object <- fit_object } - if(return_sim_data==TRUE){ + if (return_sim_data == TRUE) { ret$sim_data <- x_sim } if (timing == FALSE) { diff --git a/R/fit.R b/R/fit.R index a7107c5..6f38e81 100644 --- a/R/fit.R +++ b/R/fit.R @@ -18,15 +18,14 @@ #' #' @export #' -fit = function(x_train, pred_train, fixed_features, c_int=c(mean(pred_train),1),featuremodel_object = NULL, - autoregressive_model="ctree", seed=NULL,decision = TRUE,...){ - - if(!is.null(featuremodel_object)){ - featuremodel_object$time_fit <- as.difftime(0,units = "secs") +fit <- function(x_train, pred_train, fixed_features, c_int = c(mean(pred_train), 1), featuremodel_object = NULL, + autoregressive_model = "ctree", seed = NULL, decision = TRUE, ...) { + if (!is.null(featuremodel_object)) { + featuremodel_object$time_fit <- as.difftime(0, units = "secs") return(featuremodel_object) } - if(!is.null(seed)) set.seed(seed) + if (!is.null(seed)) set.seed(seed) if (!is.matrix(x_train) && !is.data.frame(x_train)) { stop("x_train should be a matrix or a data.frame/data.table.\n") @@ -44,80 +43,82 @@ fit = function(x_train, pred_train, fixed_features, c_int=c(mean(pred_train),1), ) } - if(decision){ - decision0 <- (pred_train>=c_int[1] & pred_train<=c_int[2])*1 + if (decision) { + decision0 <- (pred_train >= c_int[1] & pred_train <= c_int[2]) * 1 # x_train[,decision := decision0] - data.table::set(x_train, i = NULL, j="decision", value=decision0) + data.table::set(x_train, i = NULL, j = "decision", value = decision0) - fixed_features <- c(fixed_features,"decision") + fixed_features <- c(fixed_features, "decision") } - if(length(fixed_features)==0){ - x_train[,dummy:=1] + if (length(fixed_features) == 0) { + x_train[, dummy := 1] mutable_features <- names(x_train)[!(names(x_train) %in% "dummy")] current_x <- "dummy" - } else { mutable_features <- names(x_train)[!(names(x_train) %in% fixed_features)] current_x <- fixed_features } - N_mutable = length(mutable_features) + N_mutable <- length(mutable_features) featuremodel_list <- list() - time_fit_start = Sys.time() + time_fit_start <- Sys.time() # Fit the models - for(i in seq_len(N_mutable)){ - + for (i in seq_len(N_mutable)) { response <- mutable_features[i] features <- current_x - if(autoregressive_model=="ctree"){ - featuremodel_list[[i]] <- model.ctree(response,features,data=x_train,...) - } else if(autoregressive_model%in% c("rpart")){ - featuremodel_list[[i]] <- model.rpart(response,features,data=x_train,...) + if (autoregressive_model == "ctree") { + featuremodel_list[[i]] <- model.ctree(response, features, data = x_train, ...) + } else if (autoregressive_model %in% c("rpart")) { + featuremodel_list[[i]] <- model.rpart(response, features, data = x_train, ...) } else { stop("autoregressive_model argument not recognized.") } current_x <- c(current_x, mutable_features[i]) -# print(i) + # print(i) } - time_fit = difftime(Sys.time(), time_fit_start, units = "secs") - ret <- list(featuremodel_list = featuremodel_list, - time_fit = time_fit, - fixed_features = fixed_features, - mutable_features = mutable_features, - autoregressive_model = autoregressive_model, - decision = decision, - x_train = x_train, - c_int = c_int) + time_fit <- difftime(Sys.time(), time_fit_start, units = "secs") + ret <- list( + featuremodel_list = featuremodel_list, + time_fit = time_fit, + fixed_features = fixed_features, + mutable_features = mutable_features, + autoregressive_model = autoregressive_model, + decision = decision, + x_train = x_train, + c_int = c_int + ) return(ret) } #' @keywords internal -model.ctree <- function(response,features,data,...){ +model.ctree <- function(response, features, data, ...) { formula <- as.formula(paste0(response, "~", paste0(features, collapse = "+"))) - mod <- party::ctree(formula = formula, - data = data, - ...) - + mod <- party::ctree( + formula = formula, + data = data, + ... + ) } #' @keywords internal -model.rpart <- function(response,features,data,...){ +model.rpart <- function(response, features, data, ...) { formula <- as.formula(paste0(response, "~", paste0(features, collapse = "+"))) - mod <- rpart::rpart(formula = formula, - data = data, - ...) - + mod <- rpart::rpart( + formula = formula, + data = data, + ... + ) } diff --git a/R/generate.R b/R/generate.R index 259d77e..9d87548 100644 --- a/R/generate.R +++ b/R/generate.R @@ -16,10 +16,7 @@ #' #' @export #' -generate = function(x_explain, fit_object, K = 1000, seed=NULL){ - - - +generate <- function(x_explain, fit_object, K = 1000, seed = NULL) { mutable_features <- fit_object$mutable_features fixed_features <- fit_object$fixed_features decision <- fit_object$decision @@ -35,15 +32,14 @@ generate = function(x_explain, fit_object, K = 1000, seed=NULL){ } - if(decision){ - set(x_explain, i = NULL, j="decision", value=1) + if (decision) { + set(x_explain, i = NULL, j = "decision", value = 1) } - if(length(fixed_features)==0){ - x_explain[,dummy:=1] - - dup_features <- c("dummy",fixed_features) + if (length(fixed_features) == 0) { + x_explain[, dummy := 1] + dup_features <- c("dummy", fixed_features) } else { dup_features <- fixed_features } @@ -51,72 +47,77 @@ generate = function(x_explain, fit_object, K = 1000, seed=NULL){ explain_vec <- rep(seq_len(n_explain), each = K) - time_generate_start = Sys.time() + time_generate_start <- Sys.time() - simData <- x_explain[explain_vec, dup_features,with=F] + simData <- x_explain[explain_vec, dup_features, with = F] simData <- data.table::setDT(simData) rowno <- seq_len(nrow(x_train)) N_mutable <- length(mutable_features) - if(autoregressive_model=="ctree"){ - predict_node = predict_node.ctree - } else if(autoregressive_model=="rpart"){ - predict_node = predict_node.rpart + if (autoregressive_model == "ctree") { + predict_node <- predict_node.ctree + } else if (autoregressive_model == "rpart") { + predict_node <- predict_node.rpart } - if(!is.null(seed))set.seed(seed) - for(i in seq_len(N_mutable)){ + if (!is.null(seed)) set.seed(seed) + for (i in seq_len(N_mutable)) { this_feature <- mutable_features[i] fit.nodes <- predict_node(model = featuremodel_list[[i]]) pred.nodes <- predict_node(model = featuremodel_list[[i]], newdata = simData) - for (pred_node in unique(pred.nodes)){ + for (pred_node in unique(pred.nodes)) { these_rows <- which(pred.nodes == pred_node) - newrowno <- sample(rowno[fit.nodes == pred_node], length(these_rows),replace = TRUE) - tmp <- unlist(x_train[newrowno, this_feature, with=F]) - data.table::set(simData,i = these_rows,j=this_feature, value = tmp) + newrowno <- sample(rowno[fit.nodes == pred_node], length(these_rows), replace = TRUE) + tmp <- unlist(x_train[newrowno, this_feature, with = F]) + data.table::set(simData, i = these_rows, j = this_feature, value = tmp) } } data.table::setcolorder(simData, neworder = names(x_train)) - data.table::set(simData,j="id_explain", value = explain_vec) + data.table::set(simData, j = "id_explain", value = explain_vec) - if(length(fixed_features)==0){ - simData[,dummy:=NULL] + if (length(fixed_features) == 0) { + simData[, dummy := NULL] } - if(decision){ - data.table::set(simData, i = NULL, j="decision", value=NULL) + if (decision) { + data.table::set(simData, i = NULL, j = "decision", value = NULL) } - data.table::setcolorder(simData,"id_explain") + data.table::setcolorder(simData, "id_explain") - time_generate = difftime(Sys.time(), time_generate_start, units = "secs") + time_generate <- difftime(Sys.time(), time_generate_start, units = "secs") - ret <- list(simData = simData, - time_generate = time_generate) + ret <- list( + simData = simData, + time_generate = time_generate + ) return(ret) } -predict_node.ctree <- function(model,newdata=NULL){ - party::where(object=model,newdata=newdata) +predict_node.ctree <- function(model, newdata = NULL) { + party::where(object = model, newdata = newdata) } -predict_node.rpart <- function(model,newdata=NULL,version = "new"){ - if(version=="new"){ +predict_node.rpart <- function(model, newdata = NULL, version = "new") { + if (version == "new") { # Using a modified version of rpart_leaves from the last answer here: https://stackoverflow.com/questions/17597739/get-id-name-of-rpart-model-nodes # with getFromNamespace instead of ::: to avoid CRAN notes - if(is.null(newdata)){ + if (is.null(newdata)) { return(model$where) } if (is.null(attr(newdata, "terms"))) { Terms <- delete.response(model$terms) - newdata <- model.frame(Terms, newdata, na.action = na.pass, - xlev = attr(model, "xlevels")) - if (!is.null(cl <- attr(Terms, "dataClasses"))) + newdata <- model.frame(Terms, newdata, + na.action = na.pass, + xlev = attr(model, "xlevels") + ) + if (!is.null(cl <- attr(Terms, "dataClasses"))) { .checkMFClasses(cl, newdata, TRUE) + } } newdata <- getFromNamespace("rpart.matrix", ns = "rpart")(newdata) where <- unname(getFromNamespace("pred.rpart", ns = "rpart")(model, newdata)) @@ -124,9 +125,7 @@ predict_node.rpart <- function(model,newdata=NULL,version = "new"){ } else { # Using partykit::as.party(model) # See: "https://stackoverflow.com/questions/36748531/getting-the-observations-in-a-rparts-node-i-e-cart" - where <- predict(object=partykit::as.party(model),newdata=newdata,type="node") + where <- predict(object = partykit::as.party(model), newdata = newdata, type = "node") return(where) } } - - diff --git a/R/model.R b/R/model.R index d26f6ab..5d64f82 100644 --- a/R/model.R +++ b/R/model.R @@ -120,7 +120,6 @@ predict_model.glm <- function(x, newdata) { #' #' @keywords internal get_predict_model <- function(predict_model, model) { - # Checks that predict_model is a proper function (R + py) # Extracts natively supported functions for predict_model if exists and not passed (R only) # Checks that predict_model provide the right output format (R and py) @@ -129,16 +128,16 @@ get_predict_model <- function(predict_model, model) { model_class0 <- class(model)[1] # checks predict_model - if(!(is.function(predict_model)) && - !(is.null(predict_model))){ + if (!(is.function(predict_model)) && + !(is.null(predict_model))) { stop("`predict_model` must be NULL or a function.") } - supported_models <- substring(rownames(attr(methods(predict_model), "info")),first=15) + supported_models <- substring(rownames(attr(methods(predict_model), "info")), first = 15) # Get native predict_model if not passed and exists if (is.null(predict_model)) { - if(model_class0 %in% supported_models){ + if (model_class0 %in% supported_models) { predict_model <- mcceR::predict_model } else { stop( @@ -151,4 +150,3 @@ get_predict_model <- function(predict_model, model) { return(predict_model) } - diff --git a/R/process.R b/R/process.R index f31d709..7e064ab 100644 --- a/R/process.R +++ b/R/process.R @@ -29,14 +29,14 @@ #' #' @export #' -process = function(x_sim, - pred_sim, - x_explain, - fit_object, - measures = c("validation","L0","L1"), # Don't obey this quite yet - remove_invalid = TRUE, - return_best_k = 1, - sort_by_measures_order = TRUE){ # Don't obey this quite yet +process <- function(x_sim, + pred_sim, + x_explain, + fit_object, + measures = c("validation", "L0", "L1"), # Don't obey this quite yet + remove_invalid = TRUE, + return_best_k = 1, + sort_by_measures_order = TRUE) { # Don't obey this quite yet if (!is.matrix(x_sim) && !is.data.frame(x_sim)) { stop("x_sim should be a matrix or a data.frame/data.table.\n") @@ -54,147 +54,146 @@ process = function(x_sim, mutable_features <- fit_object$mutable_features c_int <- fit_object$c_int - time_process_start = Sys.time() + time_process_start <- Sys.time() n_explain <- nrow(x_explain) - mutable_features_plus <- c("id_explain",mutable_features) + mutable_features_plus <- c("id_explain", mutable_features) - x_sim_mutable <- x_sim[,mutable_features_plus,with=F] - x_explain_mutable <- cbind(id_explain=seq_len(n_explain),x_explain[,mutable_features,with=F]) + x_sim_mutable <- x_sim[, mutable_features_plus, with = F] + x_explain_mutable <- cbind(id_explain = seq_len(n_explain), x_explain[, mutable_features, with = F]) - x_sim[,row_id:=1:.N] + x_sim[, row_id := 1:.N] - res.dt <- data.table::data.table(row_id=seq_len(nrow(x_sim)),id_explain=x_sim[,id_explain],pred=pred_sim) + res.dt <- data.table::data.table(row_id = seq_len(nrow(x_sim)), id_explain = x_sim[, id_explain], pred = pred_sim) measure_ordering <- c() - for(this_measure in measures){ - - if(this_measure=="validation"){ - get_measure_validation(res.dt,x_explain_mutable,x_sim_mutable,c_int) - measure_ordering <- c(measure_ordering,-1) + for (this_measure in measures) { + if (this_measure == "validation") { + get_measure_validation(res.dt, x_explain_mutable, x_sim_mutable, c_int) + measure_ordering <- c(measure_ordering, -1) } - if(this_measure=="L0"){ - get_measure_L0(res.dt,x_explain_mutable,x_sim_mutable) - measure_ordering <- c(measure_ordering,1) + if (this_measure == "L0") { + get_measure_L0(res.dt, x_explain_mutable, x_sim_mutable) + measure_ordering <- c(measure_ordering, 1) } - if(this_measure=="L1"){ - get_measure_L1(res.dt,x_explain_mutable,x_sim_mutable) - measure_ordering <- c(measure_ordering,1) + if (this_measure == "L1") { + get_measure_L1(res.dt, x_explain_mutable, x_sim_mutable) + measure_ordering <- c(measure_ordering, 1) } - if(this_measure=="L2"){ - get_measure_L2(res.dt,x_explain_mutable,x_sim_mutable) - measure_ordering <- c(measure_ordering,1) + if (this_measure == "L2") { + get_measure_L2(res.dt, x_explain_mutable, x_sim_mutable) + measure_ordering <- c(measure_ordering, 1) } - if(this_measure=="gower"){ - get_measure_gower(res.dt,x_explain_mutable,x_sim_mutable) - measure_ordering <- c(measure_ordering,1) + if (this_measure == "gower") { + get_measure_gower(res.dt, x_explain_mutable, x_sim_mutable) + measure_ordering <- c(measure_ordering, 1) } } - data.table::setorderv(res.dt,cols=c("id_explain",paste0("measure_",measures)),order = c(1,measure_ordering)) + data.table::setorderv(res.dt, cols = c("id_explain", paste0("measure_", measures)), order = c(1, measure_ordering)) res.dt[, counterfactual_rank := 1:.N, by = id_explain] - if(remove_invalid){ - res.dt <- res.dt[measure_validation==1] + if (remove_invalid) { + res.dt <- res.dt[measure_validation == 1] } - cols <- c("counterfactual_rank","row_id","pred",paste0("measure_",measures)) + cols <- c("counterfactual_rank", "row_id", "pred", paste0("measure_", measures)) - ret_sim0 <- res.dt[,head(.SD,return_best_k),by=id_explain][,cols,with = FALSE] + ret_sim0 <- res.dt[, head(.SD, return_best_k), by = id_explain][, cols, with = FALSE] - ret_sim <- x_sim[ret_sim0,on="row_id"] + ret_sim <- x_sim[ret_sim0, on = "row_id"] - time_process = difftime(Sys.time(), time_process_start, units = "secs") + time_process <- difftime(Sys.time(), time_process_start, units = "secs") - data.table::setcolorder(ret_sim,c("id_explain",cols[-2])) + data.table::setcolorder(ret_sim, c("id_explain", cols[-2])) - cols_measure <- c("id_explain",cols[-2]) + cols_measure <- c("id_explain", cols[-2]) cols_cf <- names(ret_sim)[!(names(ret_sim) %in% cols[-1])] - ret <- list(cf=ret_sim[,cols_cf,with = FALSE], - cf_measures = ret_sim[,cols_measure,with = FALSE], - time_process = time_process) + ret <- list( + cf = ret_sim[, cols_cf, with = FALSE], + cf_measures = ret_sim[, cols_measure, with = FALSE], + time_process = time_process + ) return(ret) } -get_measure_validation <- function(res.dt,x_explain_mutable,x_sim_mutable,c_int){ - res.dt[,measure_validation:= (pred>=c_int[1] & pred<=c_int[2])*1] +get_measure_validation <- function(res.dt, x_explain_mutable, x_sim_mutable, c_int) { + res.dt[, measure_validation := (pred >= c_int[1] & pred <= c_int[2]) * 1] } -get_measure_L0 <- function(res.dt,x_explain_mutable,x_sim_mutable){ - n_features <- ncol(x_explain_mutable)-1 +get_measure_L0 <- function(res.dt, x_explain_mutable, x_sim_mutable) { + n_features <- ncol(x_explain_mutable) - 1 combined <- x_explain_mutable[x_sim_mutable, on = "id_explain", nomatch = 0] # Identify columns to compare (exclude `id_explain`) columns_to_compare <- setdiff(names(x_explain_mutable), "id_explain") - value <- identical_rows(combined,columns_to_compare) + value <- identical_rows(combined, columns_to_compare) - res.dt[,measure_L0 := n_features-value] # Number of features changed from original value + res.dt[, measure_L0 := n_features - value] # Number of features changed from original value } -identical_rows <- function(combined,columns_to_compare){ - +identical_rows <- function(combined, columns_to_compare) { # Calculate the number of identical values in each row - value <- combined[, rowSums(mapply(function(col1, col2) col1 == col2, - .SD[, columns_to_compare, with = FALSE], - .SD[, paste0("i.",columns_to_compare), with = FALSE]))] + value <- combined[, rowSums(mapply( + function(col1, col2) col1 == col2, + .SD[, columns_to_compare, with = FALSE], + .SD[, paste0("i.", columns_to_compare), with = FALSE] + ))] } -get_measure_L1 <- function(res.dt,x_explain_mutable,x_sim_mutable){ - n_explain <- x_explain_mutable[,.N] - for (i in seq_len(n_explain)){ - value <- Rfast::dista(x_explain_mutable[id_explain==i,-1],x_sim_mutable[id_explain==i,-1],trans=F,type = "manhattan") - res.dt[id_explain==i,measure_L1:=value] - #set(res.dt,i=which(res.dt$id_explain==i),j = "measure_manhattan",value = Rfast::dista(x_explain_mutable[id_explain==i,-1],x_sim_mutable[id_explain==i,-1],trans=F,type = "manhattan")) +get_measure_L1 <- function(res.dt, x_explain_mutable, x_sim_mutable) { + n_explain <- x_explain_mutable[, .N] + for (i in seq_len(n_explain)) { + value <- Rfast::dista(x_explain_mutable[id_explain == i, -1], x_sim_mutable[id_explain == i, -1], trans = F, type = "manhattan") + res.dt[id_explain == i, measure_L1 := value] + # set(res.dt,i=which(res.dt$id_explain==i),j = "measure_manhattan",value = Rfast::dista(x_explain_mutable[id_explain==i,-1],x_sim_mutable[id_explain==i,-1],trans=F,type = "manhattan")) } } -get_measure_L2 <- function(res.dt,x_explain_mutable,x_sim_mutable){ - n_explain <- x_explain_mutable[,.N] - for (i in seq_len(n_explain)){ - value <- Rfast::dista(x_explain_mutable[id_explain==i,-1],x_sim_mutable[id_explain==i,-1],trans=F) - res.dt[id_explain==i,measure_L2:=value] -# set(res.dt,i=which(res.dt$id_explain==i),j="measure_euclidean",value = Rfast::dista(x_explain_mutable[id_explain==i,-1],x_sim_mutable[id_explain==i,-1],trans=F)) +get_measure_L2 <- function(res.dt, x_explain_mutable, x_sim_mutable) { + n_explain <- x_explain_mutable[, .N] + for (i in seq_len(n_explain)) { + value <- Rfast::dista(x_explain_mutable[id_explain == i, -1], x_sim_mutable[id_explain == i, -1], trans = F) + res.dt[id_explain == i, measure_L2 := value] + # set(res.dt,i=which(res.dt$id_explain==i),j="measure_euclidean",value = Rfast::dista(x_explain_mutable[id_explain==i,-1],x_sim_mutable[id_explain==i,-1],trans=F)) } } -get_measure_gower <- function(res.dt,x_explain_mutable,x_sim_mutable){ - - cat_cols <- names(which(sapply(x_explain_mutable[,-1],is.factor))) - num_cols <- names(which(sapply(x_explain_mutable[,-1],is.numeric))) +get_measure_gower <- function(res.dt, x_explain_mutable, x_sim_mutable) { + cat_cols <- names(which(sapply(x_explain_mutable[, -1], is.factor))) + num_cols <- names(which(sapply(x_explain_mutable[, -1], is.numeric))) combined <- x_explain_mutable[x_sim_mutable, on = "id_explain", nomatch = 0] - if(length(cat_cols)>0){ - cat_contrib <- identical_rows(combined,cat_cols) + if (length(cat_cols) > 0) { + cat_contrib <- identical_rows(combined, cat_cols) } else { cat_contrib <- 0 } num_contrib <- 0 - if(length(num_cols)>0){ - for(i in seq_along(num_cols)){ - col1 <- unlist(combined[,num_cols[i],with=FALSE]) - col2 <- unlist(combined[,paste0("i.",num_cols[i]),with=FALSE]) + if (length(num_cols) > 0) { + for (i in seq_along(num_cols)) { + col1 <- unlist(combined[, num_cols[i], with = FALSE]) + col2 <- unlist(combined[, paste0("i.", num_cols[i]), with = FALSE]) - num_contrib <- num_contrib + abs(col1-col2)/range(col2) # Using the range of the syntehtic data instead of the original for simplicity + num_contrib <- num_contrib + abs(col1 - col2) / range(col2) # Using the range of the syntehtic data instead of the original for simplicity } } - res.dt[,measure_gower:=cat_contrib + num_contrib] + res.dt[, measure_gower := cat_contrib + num_contrib] } - - diff --git a/R/zzz.R b/R/zzz.R index 05efa88..9f73256 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,5 +1,4 @@ .onLoad <- function(libname = find.package("mcceR"), pkgname = "mcceR") { - # CRAN Note avoidance utils::globalVariables( c( @@ -22,7 +21,8 @@ "dummy", "get_predict_model", "predict_node.rpart", - "pred" + "pred", + "measure_gower" ) ) invisible() diff --git a/tests/testthat/test-output.R b/tests/testthat/test-output.R index 0ab90d4..7b3527d 100644 --- a/tests/testthat/test-output.R +++ b/tests/testthat/test-output.R @@ -4,15 +4,17 @@ test_that("output_lm_numeric_ctree", { set.seed(123) expect_snapshot_rds( { - explain_mcce(model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - c_int = c(-Inf,15), - fixed_features = "Wind", - fit.autoregressive_model = "ctree", - fit.seed = 1, - generate.seed = 2, - timing = FALSE) + explain_mcce( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + c_int = c(-Inf, 15), + fixed_features = "Wind", + fit.autoregressive_model = "ctree", + fit.seed = 1, + generate.seed = 2, + timing = FALSE + ) }, "output_lm_numeric_ctree" ) @@ -22,15 +24,17 @@ test_that("output_lm_numeric_rpart", { set.seed(123) expect_snapshot_rds( { - explain_mcce(model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - c_int = c(-Inf,15), - fixed_features = "Wind", - fit.autoregressive_model = "rpart", - fit.seed = 1, - generate.seed = 2, - timing = FALSE) + explain_mcce( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + c_int = c(-Inf, 15), + fixed_features = "Wind", + fit.autoregressive_model = "rpart", + fit.seed = 1, + generate.seed = 2, + timing = FALSE + ) }, "output_lm_numeric_rpart" ) diff --git a/tests/testthat/test-setup.R b/tests/testthat/test-setup.R index 854c326..834090d 100644 --- a/tests/testthat/test-setup.R +++ b/tests/testthat/test-setup.R @@ -1,27 +1,30 @@ - test_that("Identical results when reusing the featuremodel in a new explain-call", { - explained1 <- explain_mcce(model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - c_int = c(-Inf,15), - fixed_features = "Wind", - fit.autoregressive_model = "ctree", - fit.seed = 1, - generate.seed = 2, - timing = FALSE, - return_featuremodel_object = TRUE) # Returning featuremodel + explained1 <- explain_mcce( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + c_int = c(-Inf, 15), + fixed_features = "Wind", + fit.autoregressive_model = "ctree", + fit.seed = 1, + generate.seed = 2, + timing = FALSE, + return_featuremodel_object = TRUE + ) # Returning featuremodel - explained2 <- explain_mcce(model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - c_int = c(-Inf,15), - featuremodel_object = explained1$featuremodel_object, # Reuse featuremodel - fixed_features = "Wind", - fit.autoregressive_model = "ctree", - fit.seed = 1, - generate.seed = 2, - timing = FALSE, - return_featuremodel_object = TRUE) + explained2 <- explain_mcce( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + c_int = c(-Inf, 15), + featuremodel_object = explained1$featuremodel_object, # Reuse featuremodel + fixed_features = "Wind", + fit.autoregressive_model = "ctree", + fit.seed = 1, + generate.seed = 2, + timing = FALSE, + return_featuremodel_object = TRUE + ) # The entire object is identical @@ -29,5 +32,4 @@ test_that("Identical results when reusing the featuremodel in a new explain-call explained1, explained2 ) - })