Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catboost #100

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ inst/doc
.DS_Store
logs
derby.log
catboost_info
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -77,5 +77,6 @@ Suggests:
sparklyr (>= 0.8.0),
tinytest,
varImp,
xgboost
xgboost,
catboost
RoxygenNote: 7.1.0
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ S3method(vi,Learner)
S3method(vi,WrappedModel)
S3method(vi,default)
S3method(vi,model_fit)
S3method(vi,workflow)
S3method(vi_firm,default)
S3method(vi_model,C5.0)
S3method(vi_model,H2OBinomialModel)
Expand All @@ -13,6 +14,7 @@ S3method(vi_model,H2ORegressionModel)
S3method(vi_model,Learner)
S3method(vi_model,RandomForest)
S3method(vi_model,WrappedModel)
S3method(vi_model,catboost.Model)
S3method(vi_model,cforest)
S3method(vi_model,constparty)
S3method(vi_model,cubist)
Expand All @@ -39,11 +41,13 @@ S3method(vi_model,randomForest)
S3method(vi_model,ranger)
S3method(vi_model,rpart)
S3method(vi_model,train)
S3method(vi_model,workflow)
S3method(vi_model,xgb.Booster)
S3method(vi_permute,default)
S3method(vi_shap,default)
S3method(vip,default)
S3method(vip,model_fit)
S3method(vip,workflow)
export("%>%")
export("%T>%")
export(add_sparklines)
Expand Down
11 changes: 11 additions & 0 deletions R/get_feature_names.R
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,14 @@ get_feature_names.xgb.Booster <- function(object, ...) {
object$feature_names
}
}

# Package: catboost -------------------------------------------------------------

#' @keywords internal
get_feature_names.catboost.Model <- function(object, ...) {
if (is.null(rownames(fit$feature_importances))) {
get_feature_names.default(object)
} else {
rownames(fit$feature_importances)
}
}
6 changes: 6 additions & 0 deletions R/vi.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ vi.model_fit <- function(object, ...) { # package: parsnip
vi(object$fit, ...)
}

#' @rdname vi
#'
#' @export
vi.workflow <- function(object, ...) { # package: workflows
vi(object$fit$fit$fit, ...)
}

#' @rdname vi
#'
Expand Down
45 changes: 45 additions & 0 deletions R/vi_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,10 @@
#'
#' }}
#'
#' \item{\code{\link[catboost]{catboost}}}{See \code{\link[catboost]{catboost.get_feature_importance}} or visit
#' \url{https://catboost.ai/docs/concepts/r-reference_catboost-get_feature_importance.html}
#' for details.}
#'
#' }
#'
#' @note Inspired by the \code{\link[caret]{varImp}} function.
Expand Down Expand Up @@ -724,6 +728,14 @@ vi_model.model_fit <- function(object, ...) { # package: parsnip
vi_model(object$fit, ...)
}

# Package: parsnip -------------------------------------------------------------

#' @rdname vi_model
#'
#' @export
vi_model.workflow <- function(object, ...) { # package: workflows
vi_model(object$fit$fit$fit, ...)
}

# Package: party ---------------------------------------------------------------

Expand Down Expand Up @@ -1335,3 +1347,36 @@ vi_model.xgb.Booster <- function(object, type = c("gain", "cover", "frequency"),
tib

}

# Package: catboost -------------------------------------------------------------

#' @rdname vi_model
#'
#' @export
vi_model.catboost.Model <- function(object, type = c("FeatureImportance", "PredictionValuesChange", "LossFunctionChange", "Interaction"), ...) {

# Determine which type of variable importance to compute
type <- match.arg(type)

# Construct model-specific variable importance scores
imp <- catboost::catboost.get_feature_importance(model = object, type = type, ...)
var_names <- get_feature_names.catboost.Model(object)

if(type %in% c("LossFunctionChange", "FeatureImportance", "PredictionValuesChange")) {
tib <- tibble::enframe(imp[,1], name = "Variable", value = "Importance")
} else if(type == "Interaction") {
tib <- tibble::as_tibble(imp)
tib <- setNames(tib, c("Variable1", "Variable2", "Importance"))
tib$Variable1 <- var_names[tib$Variable1 + 1]
tib$Variable2 <- var_names[tib$Variable2 + 1]
}

# Add variable importance type attribute
attr(tib, which = "type") <- type

# Add "vi" class
class(tib) <- c("vi", class(tib))

# Return results
tib
}
16 changes: 16 additions & 0 deletions R/vi_shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,19 @@ vi_shap.default <- function(object, feature_names = NULL, train = NULL, ...) {
tib

}

#' @rdname vi_shap
#'
#' @export
vi_shap.catboost.Model <- function(object, feature_names = NULL, train = NULL, ...) {
# Try to extract feature names if not supplied
if (is.null(feature_names)) {
feature_names <- get_feature_names(object)
}

# catboost do not give access to the training data directly from the model object.
if (is.null(train)) {
stop("Please provide a `catboost.Pool` object to the train argument. See `catboost::catboost.load_pool()`.")
}

}
7 changes: 6 additions & 1 deletion R/vip.R
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,16 @@ vip.default <- function(

}


#' @rdname vip
#'
#' @export
vip.model_fit <- function(object, ...) {
vip(object$fit, ...)
}

#' @rdname vip
#'
#' @export
vip.workflow <- function(object, ...) {
vip(object$fit$fit$fit, ...)
}
60 changes: 60 additions & 0 deletions inst/tinytest/test_pkg_catboost.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Exits
if (!requireNamespace("catboost", quietly = TRUE)) {
exit_file("Package catboost missing")
}

# # Load required packages
# suppressMessages({
# library(catboost)
# })

# Generate Friedman benchmark data
friedman1 <- gen_friedman(seed = 101)

# Fit model(s)
set.seed(101)
fit <- catboost::catboost.train(
learn_pool = catboost::catboost.load_pool(friedman1[,-1], friedman1[,1, drop = TRUE]),
params = list(logging_level = "Silent", iterations = 10)
)

# Compute VI scores
vis_FeatureImportance_default <- vi_model(fit)
vis_FeatureImportance <- vi_model(fit, type = "FeatureImportance")
vis_PredictionValuesChange <- vi_model(fit, type = "PredictionValuesChange")
vis_LossFunctionChange <- vi_model(fit, type = "LossFunctionChange", pool = catboost::catboost.load_pool(friedman1[,-1], friedman1[,1, drop = TRUE]))
vis_Interaction <- vi_model(fit, type = "Interaction")

# Expectations for `vi_model()`
expect_identical()
expect_identical()
expect_identical()

# Expectations for `get_training_data()`
expect_error(vip:::get_training_data.default(fit))

# Expectations for `get_feature_names()`
expect_identical(
current = vip:::get_feature_names.catboost.Model(fit),
target = paste0("x", 1L:10L)
)

# Call `vip::vip()` directly
p <- vip(fit, method = "model", include_type = TRUE)

# Expect `p` to be a `"gg" "ggplot"` object
expect_identical(
current = class(p),
target = c("gg", "ggplot")
)

# Display VIPs side by side
grid.arrange(
vip(vis_FeatureImportance_default, include_type = TRUE),
vip(vis_FeatureImportance, include_type = TRUE),
vip(vis_PredictionValuesChange, include_type = TRUE),
vip(vis_LossFunctionChange, include_type = TRUE),
# vip(vis_Interaction, include_type = TRUE),
p,
nrow = 1
)
59 changes: 59 additions & 0 deletions inst/tinytest/test_pkg_workflows.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Exits
if (!requireNamespace("workflows", quietly = TRUE)) {
exit_file("Package workflows missing")
}

# Load required packages
suppressMessages({
library(workflows)
})

# Generate Friedman benchmark data
friedman1 <- gen_friedman(seed = 101)

# Fit a linear model
lin <- parsnip::linear_reg() %>%
parsnip::set_engine("lm")

wf <- workflows::workflow() %>%
workflows::add_model(lin) %>%
workflows::add_formula(y ~ .)

lin_fit <- wf %>%
parsnip::fit(data = friedman1)

# Compute model-based VI scores
vis <- vi(lin_fit, scale = TRUE)

# Expect `vi()` and `vi_model()` to both work
expect_identical(
current = vi(lin_fit, sort = FALSE),
target = vi_model(lin_fit)
)

# Check class
expect_identical(class(vis), target = c("vi", "tbl_df", "tbl", "data.frame"))

# Check dimensions (should be one row for each feature)
expect_identical(ncol(friedman1) - 1L, target = nrow(vis))

# Display VIP
vip(vis, geom = "point")

# Try permutation importance
set.seed(953) # for reproducibility
p <- vip(
object = lin_fit,
method = "permute",
train = friedman1,
target = "y",
pred_wrapper = predict,
metric = "rmse",
nsim = 30,
geom = "violin",
jitter = TRUE,
all_permutation = TRUE,
mapping = ggplot2::aes(color = Variable)
)
expect_true(inherits(p, what = "ggplot"))
p # display VIP
3 changes: 3 additions & 0 deletions man/vi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 18 additions & 1 deletion man/vi_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/vip.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.