Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add randomPlantedForest learners (classif, regr) #304

Merged
merged 33 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
16343ad
Re-add rpf
jemus42 Feb 20, 2023
d88b49d
Add more missing bits
jemus42 Feb 20, 2023
a3c8322
Tweak paramtests
jemus42 Feb 21, 2023
d102698
Merge upstream
jemus42 Oct 17, 2023
66f3a6b
chore: update rpf learners and bibentry
jemus42 Oct 17, 2023
16c5ed6
docs: roxygenise
jemus42 Oct 17, 2023
5c3be24
docs: add to DESCRIPTION
jemus42 Oct 17, 2023
b84bfe7
fix: import randomPlantedForest
jemus42 Oct 17, 2023
5b08622
docs: fix crossref, pkg docs
jemus42 Oct 18, 2023
ec44772
add "threads" tag to rpf ntrheads param
jemus42 Oct 18, 2023
86ad7f0
chore: formatting
jemus42 Oct 18, 2023
c18b671
chore: update rpf paramtests to current template
jemus42 Oct 18, 2023
4b34f20
First round of fixes
jemus42 Oct 20, 2023
6c0e49e
docs: custom params, installation
jemus42 Oct 20, 2023
a56f9fa
Merge branch 'main' into rpf
jemus42 Oct 20, 2023
ca2f42d
Adress more comments
jemus42 Oct 20, 2023
20cfa05
Remove comments
jemus42 Oct 20, 2023
2d6482e
Formatting, remove default for max_interaction_ratio
jemus42 Oct 20, 2023
7d72bcf
Make multiclass prediction output more robust
jemus42 Oct 20, 2023
1fd74cf
Fix derp
jemus42 Oct 20, 2023
a9e0dda
Fix use of <-
jemus42 Oct 23, 2023
a8f52e4
fix: accidentally removed distr6 remote
jemus42 Oct 23, 2023
4a24f24
Merge branch 'main' into rpf
jemus42 Oct 24, 2023
c96cdc2
Initialize max_interaction_limit = Inf, adjust docs
jemus42 Oct 27, 2023
b0e008b
Roxygenmize
jemus42 Oct 27, 2023
df31664
Update R/learner_randomPlantedForest_regr_rpf.R
jemus42 Nov 2, 2023
9b48110
Update R/learner_randomPlantedForest_classif_rpf.R
jemus42 Nov 2, 2023
befc4bc
Update R/learner_randomPlantedForest_classif_rpf.R
jemus42 Nov 2, 2023
f6e9cdb
Update R/learner_randomPlantedForest_classif_rpf.R
jemus42 Nov 2, 2023
d693ca8
Update R/learner_randomPlantedForest_classif_rpf.R
jemus42 Nov 2, 2023
3065c3b
Merge branch 'main' into rpf
jemus42 Nov 2, 2023
ca3828f
Update R/learner_randomPlantedForest_regr_rpf.R
jemus42 Nov 2, 2023
71d502b
Re-add special_vals = list(Inf) for classif.rpf
jemus42 Nov 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ Authors@R: c(
person("Sebastian", "Fischer", , "sebf.fischer@gmail.com", role = c("cre", "aut")),
person("Zezhi", "Wang", ,"homura@mail.ustc.edu.cn", role = "ctb"),
person("John", "Zobolas", ,"bblodfon@gmail.com", role = "ctb",
comment = c(ORCID = "0000-0002-3609-8674"))
comment = c(ORCID = "0000-0002-3609-8674")),
person("Lukas", "Burk", , "burk@leibniz-bips.de", role = "ctb",
comment = c(ORCID = "0000-0001-7528-3795"))
)
Description: Extra learners for use in mlr3.
License: LGPL-3
Expand Down Expand Up @@ -85,6 +87,7 @@ Suggests:
prioritylasso (>= 0.3.1),
pseudo,
randomForest,
randomPlantedForest,
randomForestSRC,
ranger,
remotes,
Expand All @@ -108,6 +111,7 @@ Remotes:
catboost/catboost/catboost/R-package,
mlr-org/mlr3proba,
RaphaelS1/survivalmodels,
PlantedML/randomPlantedForest,
xoopR/distr6,
xoopR/param6,
xoopR/set6
Expand All @@ -119,3 +123,4 @@ NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
Config/Needs/website: rmarkdown

2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export(LearnerClassifPART)
export(LearnerClassifPriorityLasso)
export(LearnerClassifRandomForest)
export(LearnerClassifRandomForestSRC)
export(LearnerClassifRandomPlantedForest)
export(LearnerDensKDEks)
export(LearnerDensLocfit)
export(LearnerDensLogspline)
Expand Down Expand Up @@ -65,6 +66,7 @@ export(LearnerRegrRSM)
export(LearnerRegrRVM)
export(LearnerRegrRandomForest)
export(LearnerRegrRandomForestSRC)
export(LearnerRegrRandomPlantedForest)
export(LearnerSurvAkritas)
export(LearnerSurvAorsf)
export(LearnerSurvBlackBoost)
Expand Down
10 changes: 9 additions & 1 deletion R/bibentries.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ bibentries = c( # nolint start
journal = "The Annals of Applied Statistics"
),

hiabu_2023 = bibentry("article",
title = "Random Planted Forest: a directly interpretable tree ensemble",
author = "Hiabu, Munir and Mammen, Enno and Meyer, Joseph T.",
journal = "arXiv preprint arXiv:2012.14563",
doi = "10.48550/ARXIV.2012.14563",
year = "2023"
),

hothorn_2015 = bibentry("article",
author = "Torsten Hothorn and Achim Zeileis",
title = "partykit: A Modular Toolkit for Recursive Partytioning in R",
Expand Down Expand Up @@ -574,4 +582,4 @@ bibentries = c( # nolint start
month = "01",
journal = "University of California, Berkeley"
)
) # nolint end
) # nolint end
126 changes: 126 additions & 0 deletions R/learner_randomPlantedForest_classif_rpf.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#' @title Classification Random Planted Forest Learner
#' @author jemus42
#' @name mlr_learners_classif.rpf
#'
#' @description
#' Random Planted Forest: A directly interpretable tree ensemble.
#'
#' Calls [randomPlantedForest::rpf()] from 'randomPlantedForest'.
#'
#' @section Initial parameter values:
#' - `loss`:
#' - Actual default: `"L2"`.
#' - Initial value: `"exponential"`.
#' - Reason for change: Using `"L2"` (or `"L1"`) loss does not guarantee predictions are valid
#' probabilities and more akin to the linear predictor of a GLM.

#' @section Custom mlr3 parameters:
#' - `max_interaction`:
jemus42 marked this conversation as resolved.
Show resolved Hide resolved
#' - This hyperparameter can alternatively be set via `max_interaction_ratio` as
#' `max_interaction = max(ceiling(max_interaction_ratio * n_features), 1)`.
#' The parameter `max_interaction_limit` can optionally be set as an upper bound, such that
#' `max_interaction_ratio * min(n_features, max_interaction_limit)` is used instead.
#' This is analogous to `mtry.ratio` in [`classif.ranger`][mlr3learners::mlr_learners_classif.ranger], with
#' `max_interaction_limit` as an additional constraint.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mention that max_interaction_limit is initialized to Inf

jemus42 marked this conversation as resolved.
Show resolved Hide resolved
#' The parameter `max_interaction_limit` is initialized to `Inf`.
#'
#' @templateVar id classif.rpf
#' @template learner
#'
#' @section Installation:
#' Package 'randomPlantedForest' is not on CRAN and has to be installed from GitHub via
#' `remotes::install_github("PlantedML/randomPlantedForest")`.
#'
#' @references
#' `r format_bib("hiabu_2023")`
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerClassifRandomPlantedForest = R6Class("LearnerClassifRandomPlantedForest",
inherit = LearnerClassif,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
max_interaction = p_int(lower = 0, upper = Inf, default = 1, tags = "train"),
max_interaction_ratio = p_dbl(lower = 0, upper = 1, tags = "train"),
max_interaction_limit = p_int(lower = 1, upper = Inf, tags = c("required", "train")),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
max_interaction_limit = p_int(lower = 1, upper = Inf, tags = c("required", "train")),
max_interaction_limit = p_int(lower = 1, upper = Inf, tags = c("required", "train"), special_vals = list(Inf)),

ntrees = p_int(lower = 1, upper = Inf, default = 50, tags = "train"),
splits = p_int(lower = 1, upper = Inf, default = 30, tags = "train"),
split_try = p_int(lower = 1, upper = Inf, default = 10, tags = "train"),
t_try = p_dbl(lower = 0, upper = 1, default = 0.4, tags = "train"),
loss = p_fct(c("L1", "L2", "logit", "exponential"), default = "L2", tags = "train"),
delta = p_dbl(lower = 0, upper = 1, default = 1, tags = "train"),
epsilon = p_dbl(lower = 0, upper = 1, default = 0.1, tags = "train"),
deterministic = p_lgl(default = FALSE, tags = "train"),
nthreads = p_int(lower = 1, upper = Inf, default = 1, tags = c("train", "threads")),
cv = p_lgl(default = FALSE, tags = "train"),
purify = p_lgl(default = FALSE, tags = "train")
)

param_set$values = list(loss = "exponential", max_interaction_limit = Inf)

super$initialize(
id = "classif.rpf",
packages = "randomPlantedForest",
feature_types = c("integer", "numeric", "factor", "ordered", "logical"),
predict_types = c("response", "prob"),
param_set = param_set,
properties = c("twoclass", "multiclass"),
man = "mlr3extralearners::mlr_learners_classif.rpf",
label = "Random Planted Forest"
)
}
),
private = list(
.train = function(task) {
# get parameters for training
pars = self$param_set$get_values(tags = "train")
# max_interaction_limit is needed but must not be passed to rpf(),
# while convert_ratio automatically removes max_interaction_ratio.
max_interaction_limit = pars[["max_interaction_limit"]]
pars[["max_interaction_limit"]] = NULL
n_features = length(task$feature_names)

pars = convert_ratio(
pars, "max_interaction", "max_interaction_ratio",
min(n_features, max_interaction_limit)
jemus42 marked this conversation as resolved.
Show resolved Hide resolved
)

invoke(
randomPlantedForest::rpf,
x = task$data(cols = task$feature_names),
y = task$data(cols = task$target_names),
.args = pars
)
},
.predict = function(task) {
pars = self$param_set$get_values(tags = "predict")
newdata = ordered_features(task, self)

if (self$predict_type == "response") {
pred = invoke(
predict, self$model, new_data = newdata,
type = "class", .args = pars
)
list(response = pred[[".pred_class"]])
} else {
pred = invoke(
predict, self$model, new_data = newdata,
type = "prob", .args = pars
)
# Result will be a df with one column per variable with names '.pred_<level>'
# we want the names without ".pred"
xn = names(pred)
xn[which(xn == paste0(".pred_", task$class_names))] = task$class_names
names(pred) = xn

list(prob = as.matrix(pred))
}
}
)
)

.extralrns_dict$add("classif.rpf", LearnerClassifRandomPlantedForest)
90 changes: 90 additions & 0 deletions R/learner_randomPlantedForest_regr_rpf.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#' @title Regression Random Planted Forest Learner
#' @author jemus42
#' @name mlr_learners_regr.rpf
#'
#' @description
#' Random Planted Forest: A directly interpretable tree ensemble.
#'
#' Calls [randomPlantedForest::rpf()] from 'randomPlantedForest'.
#'
#' @inheritSection mlr_learners_classif.rpf Custom mlr3 parameters
#' @templateVar id regr.rpf
#' @template learner
#' @inheritSection mlr_learners_classif.rpf Installation
#'
#' @references
#' `r format_bib("hiabu_2023")`
#'
#' @template seealso_learner
#' @template example
#' @export
LearnerRegrRandomPlantedForest = R6Class("LearnerRegrRandomPlantedForest",
inherit = LearnerRegr,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
max_interaction = p_int(lower = 0, upper = Inf, default = 1, tags = "train"),
max_interaction_ratio = p_dbl(lower = 0, upper = 1, tags = "train"),
max_interaction_limit = p_int(lower = 1, upper = Inf, tags = c("required", "train")),
jemus42 marked this conversation as resolved.
Show resolved Hide resolved
ntrees = p_int(lower = 1, upper = Inf, default = 50, tags = "train"),
splits = p_int(lower = 1, upper = Inf, default = 30, tags = "train"),
split_try = p_int(lower = 1, upper = Inf, default = 10, tags = "train"),
t_try = p_dbl(lower = 0, upper = 1, default = 0.4, tags = "train"),
deterministic = p_lgl(default = FALSE, tags = "train"),
nthreads = p_int(lower = 1, upper = Inf, default = 1, tags = c("train", "threads")),
sebffischer marked this conversation as resolved.
Show resolved Hide resolved
cv = p_lgl(default = FALSE, tags = "train"),
purify = p_lgl(default = FALSE, tags = "train")
)

param_set$values = list(max_interaction_limit = Inf)

super$initialize(
id = "regr.rpf",
packages = "randomPlantedForest",
feature_types = c("integer", "numeric", "factor", "ordered", "logical"),
predict_types = "response",
param_set = param_set,
properties = character(0),
man = "mlr3extralearners::mlr_learners_regr.rpf",
label = "Random Planted Forest"
)
}
),
private = list(
.train = function(task) {
# get parameters for training
pars = self$param_set$get_values(tags = "train")
# max_interaction_limit is needed but must not be passed to rpf(),
# while convert_ratio automatically removes max_interaction_ratio.
max_interaction_limit = pars[["max_interaction_limit"]]
pars[["max_interaction_limit"]] = NULL
n_features = length(task$feature_names)

pars = convert_ratio(
pars, "max_interaction", "max_interaction_ratio",
min(n_features, max_interaction_limit)
jemus42 marked this conversation as resolved.
Show resolved Hide resolved
)

invoke(
randomPlantedForest::rpf,
x = task$data(cols = task$feature_names),
y = task$data(cols = task$target_names),
.args = pars
)
},
.predict = function(task) {
pars = self$param_set$get_values(tags = "predict")
newdata = ordered_features(task, self)

pred = invoke(
predict, self$model, new_data = newdata,
type = "numeric", .args = pars
)
list(response = pred[[".pred"]])
}
)
)

.extralrns_dict$add("regr.rpf", LearnerRegrRandomPlantedForest)
1 change: 1 addition & 0 deletions man/mlr3extralearners-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading