Skip to content

Commit

Permalink
Merge pull request #3 from NorskRegnesentral/categorical_fixes
Browse files Browse the repository at this point in the history
Adds gower + fixes L0 for categorical
  • Loading branch information
martinju authored Oct 29, 2024
2 parents 2dcf4ef + ad03f80 commit 6427297
Show file tree
Hide file tree
Showing 12 changed files with 414 additions and 252 deletions.
115 changes: 62 additions & 53 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,19 @@
#'
#' @export
#'
explain_mcce = function(model, x_explain, x_train, predict_model=NULL,
fixed_features = NULL, c_int=c(0.5,1),
featuremodel_object = NULL,
fit.autoregressive_model="ctree", fit.decision = TRUE, fit.seed = NULL,
generate.K = 1000, generate.seed = NULL,
process.measures = c("validation","L0","L1"),
process.return_best_k = 1,
process.remove_invalid = TRUE,
process.sort_by_measures_order = TRUE,
return_featuremodel_object = FALSE,
return_sim_data = FALSE,
timing = TRUE,
...){

explain_mcce <- function(model, x_explain, x_train, predict_model = NULL,
fixed_features = NULL, c_int = c(0.5, 1),
featuremodel_object = NULL,
fit.autoregressive_model = "ctree", fit.decision = TRUE, fit.seed = NULL,
generate.K = 1000, generate.seed = NULL,
process.measures = c("validation", "L0", "L1"),
process.return_best_k = 1,
process.remove_invalid = TRUE,
process.sort_by_measures_order = TRUE,
return_featuremodel_object = FALSE,
return_sim_data = FALSE,
timing = TRUE,
...) {
if (!(is.matrix(x_explain) || is.data.frame(x_explain))) {
stop("x_explain should be a matrix or a data.frame/data.table.\n")
} else {
Expand All @@ -92,15 +91,14 @@ explain_mcce = function(model, x_explain, x_train, predict_model=NULL,

predict_model <- get_predict_model(predict_model, model)

if(is.null(featuremodel_object)){
if (! (is.matrix(x_train) || is.data.frame(x_train))) {
if (is.null(featuremodel_object)) {
if (!(is.matrix(x_train) || is.data.frame(x_train))) {
stop("x_train should be a matrix or a data.frame/data.table, or 'featuremodel_object' must be passed\n")
} else {
x_train <- data.table::as.data.table(x_train)
}

pred_train <- predict_model(model,x_train)

pred_train <- predict_model(model, x_train)
} else { # If featuremodel_object is passed, we don't need the predictions on the training data set
pred_train <- NULL
}
Expand All @@ -109,50 +107,61 @@ explain_mcce = function(model, x_explain, x_train, predict_model=NULL,



fit_object <- fit(x_train = x_train,
pred_train = pred_train,
fixed_features = fixed_features,
c_int = c_int,
featuremodel_object = featuremodel_object,
autoregressive_model = fit.autoregressive_model,
decision = fit.decision,
seed = fit.seed,
...)
fit_object <- fit(
x_train = x_train,
pred_train = pred_train,
fixed_features = fixed_features,
c_int = c_int,
featuremodel_object = featuremodel_object,
autoregressive_model = fit.autoregressive_model,
decision = fit.decision,
seed = fit.seed,
...
)



sim_object <- generate(x_explain = x_explain,
fit_object=fit_object,
K = generate.K,
seed=generate.seed)
sim_object <- generate(
x_explain = x_explain,
fit_object = fit_object,
K = generate.K,
seed = generate.seed
)

x_sim <- sim_object$simData

pred_sim <- predict_model(model,x_sim[,-"id_explain"])

cfs <- process(x_sim = x_sim,
pred_sim = pred_sim,
x_explain = x_explain,
fit_object = fit_object,
measures = process.measures, # Don't obey this quite yet
remove_invalid = process.remove_invalid,
return_best_k = process.return_best_k,
sort_by_measures_order = process.sort_by_measures_order)

time_vec <- c(fit.time=fit_object$time_fit,
generate.time=sim_object$time_generate,
process.time=cfs$time_process)

ret <- list(cf = cfs$cf[],
fixed_features = fit_object$fixed_features,
mutable_features = fit_object$mutable_features,
time = time_vec)

if(return_featuremodel_object==TRUE){
pred_sim <- predict_model(model, x_sim[, -"id_explain"])

cfs <- process(
x_sim = x_sim,
pred_sim = pred_sim,
x_explain = x_explain,
fit_object = fit_object,
measures = process.measures, # Don't obey this quite yet
remove_invalid = process.remove_invalid,
return_best_k = process.return_best_k,
sort_by_measures_order = process.sort_by_measures_order
)

time_vec <- c(
fit.time = fit_object$time_fit,
generate.time = sim_object$time_generate,
process.time = cfs$time_process
)

ret <- list(
cf = cfs$cf[],
cf_measures = cfs$cf_measures[],
fixed_features = fit_object$fixed_features,
mutable_features = fit_object$mutable_features,
time = time_vec
)

if (return_featuremodel_object == TRUE) {
fit_object$time_fit <- NULL # Removed
ret$featuremodel_object <- fit_object
}
if(return_sim_data==TRUE){
if (return_sim_data == TRUE) {
ret$sim_data <- x_sim
}
if (timing == FALSE) {
Expand Down
83 changes: 42 additions & 41 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
#'
#' @export
#'
fit = function(x_train, pred_train, fixed_features, c_int=c(mean(pred_train),1),featuremodel_object = NULL,
autoregressive_model="ctree", seed=NULL,decision = TRUE,...){

if(!is.null(featuremodel_object)){
featuremodel_object$time_fit <- as.difftime(0,units = "secs")
fit <- function(x_train, pred_train, fixed_features, c_int = c(mean(pred_train), 1), featuremodel_object = NULL,
autoregressive_model = "ctree", seed = NULL, decision = TRUE, ...) {
if (!is.null(featuremodel_object)) {
featuremodel_object$time_fit <- as.difftime(0, units = "secs")
return(featuremodel_object)
}

if(!is.null(seed)) set.seed(seed)
if (!is.null(seed)) set.seed(seed)

if (!is.matrix(x_train) && !is.data.frame(x_train)) {
stop("x_train should be a matrix or a data.frame/data.table.\n")
Expand All @@ -44,80 +43,82 @@ fit = function(x_train, pred_train, fixed_features, c_int=c(mean(pred_train),1),
)
}

if(decision){
decision0 <- (pred_train>=c_int[1] & pred_train<=c_int[2])*1
if (decision) {
decision0 <- (pred_train >= c_int[1] & pred_train <= c_int[2]) * 1

# x_train[,decision := decision0]
data.table::set(x_train, i = NULL, j="decision", value=decision0)
data.table::set(x_train, i = NULL, j = "decision", value = decision0)

fixed_features <- c(fixed_features,"decision")
fixed_features <- c(fixed_features, "decision")
}



if(length(fixed_features)==0){
x_train[,dummy:=1]
if (length(fixed_features) == 0) {
x_train[, dummy := 1]

mutable_features <- names(x_train)[!(names(x_train) %in% "dummy")]

current_x <- "dummy"

} else {
mutable_features <- names(x_train)[!(names(x_train) %in% fixed_features)]

current_x <- fixed_features
}

N_mutable = length(mutable_features)
N_mutable <- length(mutable_features)


featuremodel_list <- list()

time_fit_start = Sys.time()
time_fit_start <- Sys.time()
# Fit the models
for(i in seq_len(N_mutable)){

for (i in seq_len(N_mutable)) {
response <- mutable_features[i]
features <- current_x
if(autoregressive_model=="ctree"){
featuremodel_list[[i]] <- model.ctree(response,features,data=x_train,...)
} else if(autoregressive_model%in% c("rpart")){
featuremodel_list[[i]] <- model.rpart(response,features,data=x_train,...)
if (autoregressive_model == "ctree") {
featuremodel_list[[i]] <- model.ctree(response, features, data = x_train, ...)
} else if (autoregressive_model %in% c("rpart")) {
featuremodel_list[[i]] <- model.rpart(response, features, data = x_train, ...)
} else {
stop("autoregressive_model argument not recognized.")
}

current_x <- c(current_x, mutable_features[i])
# print(i)
# print(i)
}

time_fit = difftime(Sys.time(), time_fit_start, units = "secs")
ret <- list(featuremodel_list = featuremodel_list,
time_fit = time_fit,
fixed_features = fixed_features,
mutable_features = mutable_features,
autoregressive_model = autoregressive_model,
decision = decision,
x_train = x_train,
c_int = c_int)
time_fit <- difftime(Sys.time(), time_fit_start, units = "secs")
ret <- list(
featuremodel_list = featuremodel_list,
time_fit = time_fit,
fixed_features = fixed_features,
mutable_features = mutable_features,
autoregressive_model = autoregressive_model,
decision = decision,
x_train = x_train,
c_int = c_int
)

return(ret)
}

#' Fit a conditional inference tree for a single response
#'
#' Builds the formula `response ~ features[1] + features[2] + ...` from the
#' supplied names and passes it, together with `data`, to `party::ctree()`.
#'
#' @param response Name of the response variable; pasted in as the left-hand
#'   side of the formula.
#' @param features Names of the predictor variables; collapsed with `+` to form
#'   the right-hand side of the formula.
#' @param data Data set handed to `party::ctree()` as its `data` argument.
#' @param ... Further arguments forwarded unchanged to `party::ctree()`.
#'
#' @return The object returned by `party::ctree()`. The last expression is an
#'   assignment, so the value is returned invisibly.
#'
#' @keywords internal
model.ctree <- function(response,features,data,...){
model.ctree <- function(response, features, data, ...) {
formula <- as.formula(paste0(response, "~", paste0(features, collapse = "+")))
mod <- party::ctree(formula = formula,
data = data,
...)

mod <- party::ctree(
formula = formula,
data = data,
...
)
}

#' Fit a recursive partitioning tree for a single response
#'
#' Builds the formula `response ~ features[1] + features[2] + ...` from the
#' supplied names and passes it, together with `data`, to `rpart::rpart()`.
#'
#' @param response Name of the response variable; pasted in as the left-hand
#'   side of the formula.
#' @param features Names of the predictor variables; collapsed with `+` to form
#'   the right-hand side of the formula.
#' @param data Data set handed to `rpart::rpart()` as its `data` argument.
#' @param ... Further arguments forwarded unchanged to `rpart::rpart()`.
#'
#' @return The object returned by `rpart::rpart()`. The last expression is an
#'   assignment, so the value is returned invisibly.
#'
#' @keywords internal
model.rpart <- function(response,features,data,...){
model.rpart <- function(response, features, data, ...) {
formula <- as.formula(paste0(response, "~", paste0(features, collapse = "+")))
mod <- rpart::rpart(formula = formula,
data = data,
...)

mod <- rpart::rpart(
formula = formula,
data = data,
...
)
}
Loading

0 comments on commit 6427297

Please sign in to comment.