-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
415 additions
and
196 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
#' Compute feature importance in a model | ||
#' | ||
#' Creates a \code{data.table} of feature importances in a model. | ||
#' | ||
#' @param model object of class \code{lgb.Booster}. | ||
#' @param percentage whether to show importance in relative percentage. | ||
#' | ||
#' @return | ||
#' | ||
#' For a tree model, a \code{data.table} with the following columns: | ||
#' \itemize{ | ||
#' \item \code{Feature} Feature names in the model. | ||
#' \item \code{Gain} The total gain of this feature's splits. | ||
#' \item \code{Cover} The number of observation related to this feature. | ||
#' \item \code{Frequency} The number of times a feature splited in trees. | ||
#' } | ||
#' | ||
#' @examples | ||
#' | ||
#' data(agaricus.train, package = 'lightgbm') | ||
#' train <- agaricus.train | ||
#' dtrain <- lgb.Dataset(train$data, label = train$label) | ||
#' | ||
#' params = list(objective = "binary", | ||
#' learning_rate = 0.01, num_leaves = 63, max_depth = -1, | ||
#' min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1) | ||
#' model <- lgb.train(params, dtrain, 20) | ||
#' model <- lgb.train(params, dtrain, 20) | ||
#' | ||
#' tree_imp1 <- lgb.importance(model, percentage = TRUE) | ||
#' tree_imp2 <- lgb.importance(model, percentage = FALSE) | ||
#' | ||
#' @importFrom magrittr %>% %T>% | ||
#' @importFrom data.table := | ||
#' @export | ||
|
||
lgb.importance <- function(model, percentage = TRUE) { | ||
if (!any(class(model) == "lgb.Booster")) { | ||
stop("'model' has to be an object of class lgb.Booster") | ||
} | ||
tree_dt <- lgb.model.dt.tree(model) | ||
tree_imp <- tree_dt %>% | ||
magrittr::extract(., | ||
i = is.na(split_index) == FALSE, | ||
j = .(Gain = sum(split_gain), Cover = sum(internal_count), Frequency = .N), | ||
by = "split_feature") %T>% | ||
data.table::setnames(., old = "split_feature", new = "Feature") %>% | ||
magrittr::extract(., i = order(Gain, decreasing = TRUE)) | ||
if (percentage) { | ||
tree_imp[, ":="(Gain = Gain / sum(Gain), | ||
Cover = Cover / sum(Cover), | ||
Frequency = Frequency / sum(Frequency))] | ||
} | ||
return(tree_imp) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#' Parse a LightGBM model json dump | ||
#' | ||
#' Parse a LightGBM model json dump into a \code{data.table} structure. | ||
#' | ||
#' @param model object of class \code{lgb.Booster} | ||
#' | ||
#' @return | ||
#' A \code{data.table} with detailed information about model trees' nodes and leafs. | ||
#' | ||
#' The columns of the \code{data.table} are: | ||
#' | ||
#' \itemize{ | ||
#' \item \code{tree_index}: ID of a tree in a model (integer) | ||
#' \item \code{split_index}: ID of a node in a tree (integer) | ||
#' \item \code{split_feature}: for a node, it's a feature name (character); | ||
#' for a leaf, it simply labels it as \code{'NA'} | ||
#' \item \code{node_parent}: ID of the parent node for current node (integer) | ||
#' \item \code{leaf_index}: ID of a leaf in a tree (integer) | ||
#' \item \code{leaf_parent}: ID of the parent node for current leaf (integer) | ||
#' \item \code{split_gain}: Split gain of a node | ||
#' \item \code{threshold}: Spliting threshold value of a node | ||
#' \item \code{decision_type}: Decision type of a node | ||
#' \item \code{internal_value}: Node value | ||
#' \item \code{internal_count}: The number of observation collected by a node | ||
#' \item \code{leaf_value}: Leaf value | ||
#' \item \code{leaf_count}: The number of observation collected by a leaf | ||
#' } | ||
#' | ||
#' @examples | ||
#' | ||
#' data(agaricus.train, package = 'lightgbm') | ||
#' train <- agaricus.train | ||
#' dtrain <- lgb.Dataset(train$data, label = train$label) | ||
#' | ||
#' params = list(objective = "binary", | ||
#' learning_rate = 0.01, num_leaves = 63, max_depth = -1, | ||
#' min_data_in_leaf = 1, min_sum_hessian_in_leaf = 1) | ||
#' model <- lgb.train(params, dtrain, 20) | ||
#' model <- lgb.train(params, dtrain, 20) | ||
#' | ||
#' tree_dt <- lgb.model.dt.tree(model) | ||
#' | ||
#' @importFrom magrittr %>% | ||
#' @importFrom data.table := | ||
#' @export | ||
|
||
lgb.model.dt.tree <- function(model, num_iteration = NULL) { | ||
json_model <- lgb.dump(model, num_iteration = num_iteration) | ||
parsed_json_model <- jsonlite::fromJSON(json_model, | ||
simplifyVector = TRUE, | ||
simplifyDataFrame = FALSE, | ||
simplifyMatrix = FALSE, | ||
flatten = FALSE) | ||
tree_list <- lapply(parsed_json_model$tree_info, single.tree.parse) | ||
tree_dt <- data.table::rbindlist(tree_list, use.names = TRUE) | ||
tree_dt[, split_feature := Lookup(split_feature, | ||
seq(0, parsed_json_model$max_feature_idx, by = 1), | ||
parsed_json_model$feature_names)] | ||
return(tree_dt) | ||
} | ||
|
||
single.tree.parse <- function(lgb_tree) { | ||
single_tree_dt <- data.table::data.table(tree_index = integer(0), | ||
split_index = integer(0), split_feature = integer(0), node_parent = integer(0), | ||
leaf_index = integer(0), leaf_parent = integer(0), | ||
split_gain = numeric(0), threshold = numeric(0), decision_type = character(0), | ||
internal_value = integer(0), internal_count = integer(0), | ||
leaf_value = integer(0), leaf_count = integer(0)) | ||
pre_order_traversal <- function(tree_node_leaf, parent_index = NA) { | ||
if (!is.null(tree_node_leaf$split_index)) { | ||
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt, | ||
c(tree_node_leaf[c("split_index", "split_feature", | ||
"split_gain", "threshold", "decision_type", | ||
"internal_value", "internal_count")], | ||
"node_parent" = parent_index)), | ||
use.names = TRUE, fill = TRUE) | ||
pre_order_traversal(tree_node_leaf$left_child, parent_index = tree_node_leaf$split_index) | ||
pre_order_traversal(tree_node_leaf$right_child, parent_index = tree_node_leaf$split_index) | ||
} else if (!is.null(tree_node_leaf$leaf_index)) { | ||
single_tree_dt <<- data.table::rbindlist(l = list(single_tree_dt, | ||
tree_node_leaf[c("leaf_index", "leaf_parent", | ||
"leaf_value", "leaf_count")]), | ||
use.names = TRUE, fill = TRUE) | ||
} | ||
} | ||
pre_order_traversal(lgb_tree$tree_structure) | ||
single_tree_dt[, tree_index := lgb_tree$tree_index] | ||
return(single_tree_dt) | ||
} | ||
|
||
Lookup <- function(key, key_lookup, value_lookup, missing = NA) { | ||
match(key, key_lookup) %>% | ||
magrittr::extract(value_lookup, .) %>% | ||
magrittr::inset(. , is.na(.), missing) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.