diff --git a/R/explain.R b/R/explain.R index c9b9926..b84e28c 100644 --- a/R/explain.R +++ b/R/explain.R @@ -66,6 +66,7 @@ prepare_forestplot <- function(model) { #' #' @return plot (ggplot2 or base) #' + plot_regression <- function(plot_type, fitted_model, explained_instance, scale = NULL) { if(plot_type == "forestplot") { prepare_forestplot(fitted_model) @@ -73,7 +74,9 @@ plot_regression <- function(plot_type, fitted_model, explained_instance, scale = if(scale == "probability") { plot(breakDown::broken(fitted_model, explained_instance, baseline = "intercept"), trans = function(x) exp(x)/(1 + exp(x))) + - scale_y_continuous(limits = c(0, 1), name = "probability", expand = c(0, 0)) + ggplot2::scale_y_continuous(limits = c(0, 1), + name = "probability", + expand = c(0, 0)) } else { plot(breakDown::broken(fitted_model, explained_instance, baseline = "intercept")) diff --git a/vignettes/HR.Rmd b/vignettes/HR.Rmd index 8df5ac8..920628b 100644 --- a/vignettes/HR.Rmd +++ b/vignettes/HR.Rmd @@ -21,7 +21,7 @@ trees <- randomForest(left~., data = HR_data, ntree=1000) similar <- sample_locally(data = HR_data, explained_instance = HR_data[2,], explained_var = "left", - size = 200) + size = 2000) head(similar$data) similar <- add_predictions(HR_Data, similar, black_box_model = trees) @@ -32,8 +32,23 @@ trained <- fit_explanation(live_object = similar, selection = FALSE) # trained plot_explanation(trained, "waterfallplot", - explained_instance = HR_data[1,], + explained_instance = HR_data[2,], scale = "probability") plot_explanation(trained, "forestplot", - explained_instance = HR_data[1,]) + explained_instance = HR_data[2,]) + +HR_data$left <- as.numeric(as.character(HR_data$left)) +trees <- randomForest(left~., data = HR_data, ntree=1000) + +similar2 <- sample_locally(data = HR_data, + explained_instance = HR_data[2,], + explained_var = "left", + size = 2000) +similar2 <- add_predictions(HR_Data, similar2, black_box_model = trees) +glimpse(similar2$data) +trained2 <- fit_explanation(live_object = similar2, + white_box = "regr.lm", + selection = F) +plot_explanation(trained2, "forestplot", + explained_instance = HR_data[2,]) ```