-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Classification single learners (#14)
* feat: classif single learners * test: classif single learners * fix: remove extra_trees & lda * test: rename * test: extra_trees and glmnet
- Loading branch information
Showing
15 changed files
with
398 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#' @title Classification Gradient Boosted Decision Trees Auto Learner | ||
#' | ||
#' @description | ||
#' Classification auto learner. | ||
#' | ||
#' @template param_id | ||
#' | ||
#' @export | ||
LearnerClassifAutoCatboost = R6Class("LearnerClassifAutoCatboost", | ||
inherit = LearnerClassifAuto, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function(id = "classif.auto_catboost") { | ||
super$initialize(id = id) | ||
|
||
# reduce parameter set to the relevant parameters | ||
private$.param_set = private$.param_set$subset( | ||
c("learner_ids", | ||
"learner_timeout", | ||
"catboost_eval_metric", | ||
"small_data_size", | ||
"small_data_resampling", | ||
"max_cardinality", | ||
"resampling", | ||
"terminator", | ||
"measure", | ||
"lhs_size", | ||
"callbacks", | ||
"store_benchmark_result") | ||
) | ||
|
||
self$param_set$set_values(learner_ids = "catboost") | ||
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "catboost") | ||
} | ||
) | ||
) | ||
|
||
#' @include aaa.R | ||
learners[["classif.auto_catboost"]] = LearnerClassifAutoCatboost |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#' @title Classification GLM with Elastic Net Regularization Auto Learner | ||
#' | ||
#' @description | ||
#' Classification auto learner. | ||
#' | ||
#' @template param_id | ||
#' | ||
#' @export | ||
LearnerClassifAutoGlmnet = R6Class("LearnerClassifAutoGlmnet", | ||
inherit = LearnerClassifAuto, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function(id = "classif.auto_glmnet") { | ||
super$initialize(id = id) | ||
|
||
# reduce parameter set to the relevant parameters | ||
private$.param_set = private$.param_set$subset( | ||
c("learner_ids", | ||
"learner_timeout", | ||
"small_data_size", | ||
"small_data_resampling", | ||
"max_cardinality", | ||
"resampling", | ||
"terminator", | ||
"measure", | ||
"lhs_size", | ||
"callbacks", | ||
"store_benchmark_result") | ||
) | ||
|
||
self$param_set$set_values(learner_ids = "glmnet") | ||
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "glmnet") | ||
} | ||
) | ||
) | ||
|
||
#' @include aaa.R | ||
learners[["classif.auto_glmnet"]] = LearnerClassifAutoGlmnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#' @title Classification k-Nearest-Neighbor Auto Learner | ||
#' | ||
#' @description | ||
#' Classification auto learner. | ||
#' | ||
#' @template param_id | ||
#' | ||
#' @export | ||
LearnerClassifAutoKKNN = R6Class("LearnerClassifAutoKKNN", | ||
inherit = LearnerClassifAuto, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function(id = "classif.auto_kknn") { | ||
super$initialize(id = id) | ||
|
||
# reduce parameter set to the relevant parameters | ||
private$.param_set = private$.param_set$subset( | ||
c("learner_ids", | ||
"learner_timeout", | ||
"small_data_size", | ||
"small_data_resampling", | ||
"max_cardinality", | ||
"resampling", | ||
"terminator", | ||
"measure", | ||
"lhs_size", | ||
"callbacks", | ||
"store_benchmark_result") | ||
) | ||
|
||
self$param_set$set_values(learner_ids = "kknn") | ||
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "kknn") | ||
} | ||
) | ||
) | ||
|
||
#' @include aaa.R | ||
learners[["classif.auto_kknn"]] = LearnerClassifAutoKKNN |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#' @title Classification LightGBM Auto Learner | ||
#' | ||
#' @description | ||
#' Classification auto learner. | ||
#' | ||
#' @template param_id | ||
#' | ||
#' @export | ||
LearnerClassifAutoLightGBM = R6Class("LearnerClassifAutoLightGBM", | ||
inherit = LearnerClassifAuto, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function(id = "classif.auto_lightgbm") { | ||
super$initialize(id = id) | ||
|
||
# reduce parameter set to the relevant parameters | ||
private$.param_set = private$.param_set$subset( | ||
c("learner_ids", | ||
"learner_timeout", | ||
"lightgbm_eval_metric", | ||
"small_data_size", | ||
"small_data_resampling", | ||
"max_cardinality", | ||
"resampling", | ||
"terminator", | ||
"measure", | ||
"lhs_size", | ||
"callbacks", | ||
"store_benchmark_result") | ||
) | ||
|
||
self$param_set$set_values(learner_ids = "lightgbm") | ||
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "lightgbm") | ||
} | ||
) | ||
) | ||
|
||
#' @include aaa.R | ||
learners[["classif.auto_lightgbm"]] = LearnerClassifAutoLightGBM |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#' @title Classification Neural Network Auto Learner | ||
#' | ||
#' @description | ||
#' Classification auto learner. | ||
#' | ||
#' @template param_id | ||
#' | ||
#' @export | ||
LearnerClassifAutoNnet = R6Class("LearnerClassifAutoNnet", | ||
inherit = LearnerClassifAuto, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function(id = "classif.auto_nnet") { | ||
super$initialize(id = id) | ||
|
||
# reduce parameter set to the relevant parameters | ||
private$.param_set = private$.param_set$subset( | ||
c("learner_ids", | ||
"learner_timeout", | ||
"small_data_size", | ||
"small_data_resampling", | ||
"max_cardinality", | ||
"resampling", | ||
"terminator", | ||
"measure", | ||
"lhs_size", | ||
"callbacks", | ||
"store_benchmark_result") | ||
) | ||
|
||
self$param_set$set_values(learner_ids = "nnet") | ||
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "nnet") | ||
} | ||
) | ||
) | ||
|
||
#' @include aaa.R | ||
learners[["classif.auto_nnet"]] = LearnerClassifAutoNnet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#' @title Classification Ranger Auto Learner | ||
#' | ||
#' @description | ||
#' Classification auto learner. | ||
#' | ||
#' @template param_id | ||
#' | ||
#' @export | ||
LearnerClassifAutoRanger = R6Class("LearnerClassifAutoRanger", | ||
inherit = LearnerClassifAuto, | ||
public = list( | ||
|
||
#' @description | ||
#' Creates a new instance of this [R6][R6::R6Class] class. | ||
initialize = function(id = "classif.auto_ranger") { | ||
super$initialize(id = id) | ||
|
||
# reduce parameter set to the relevant parameters | ||
private$.param_set = private$.param_set$subset( | ||
c("learner_ids", | ||
"learner_timeout", | ||
"small_data_size", | ||
"small_data_resampling", | ||
"max_cardinality", | ||
"resampling", | ||
"terminator", | ||
"measure", | ||
"lhs_size", | ||
"callbacks", | ||
"store_benchmark_result") | ||
) | ||
|
||
self$param_set$set_values(learner_ids = "ranger") | ||
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "ranger") | ||
} | ||
) | ||
) | ||
|
||
#' @include aaa.R | ||
learners[["classif.auto_ranger"]] = LearnerClassifAutoRanger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
test_that("LearnerClassifAutoCatboost is initialized", { | ||
learner = lrn("classif.auto_catboost", | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 10)) | ||
|
||
expect_null(learner$graph) | ||
expect_null(learner$tuning_space) | ||
}) | ||
|
||
test_that("LearnerClassifAutoCatboost is trained", { | ||
rush_plan(n_workers = 2) | ||
lgr::get_logger("mlr3automl")$set_threshold("debug") | ||
|
||
task = tsk("penguins") | ||
learner = lrn("classif.auto_catboost", | ||
small_data_size = 1, | ||
resampling = rsmp("holdout"), | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 6) | ||
) | ||
|
||
expect_class(learner$train(task), "LearnerClassifAutoCatboost") | ||
expect_equal(learner$graph$param_set$values$branch.selection, "catboost") | ||
expect_equal(learner$model$instance$result$branch.selection, "catboost") | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
test_that("LearnerClassifAutoGlmnet is initialized", { | ||
learner = lrn("classif.auto_glmnet", | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 10)) | ||
|
||
expect_null(learner$graph) | ||
expect_null(learner$tuning_space) | ||
}) | ||
|
||
test_that("LearnerClassifAutoGlmnet is trained", { | ||
rush_plan(n_workers = 2) | ||
lgr::get_logger("mlr3automl")$set_threshold("debug") | ||
|
||
task = tsk("penguins") | ||
learner = lrn("classif.auto_glmnet", | ||
small_data_size = 1, | ||
resampling = rsmp("holdout"), | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 6) | ||
) | ||
|
||
expect_class(learner$train(task), "LearnerClassifAutoGlmnet") | ||
expect_equal(learner$graph$param_set$values$branch.selection, "glmnet") | ||
expect_equal(learner$model$instance$result$branch.selection, "glmnet") | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
test_that("LearnerClassifAutoKKNN is initialized", { | ||
learner = lrn("classif.auto_kknn", | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 10)) | ||
|
||
expect_null(learner$graph) | ||
expect_null(learner$tuning_space) | ||
}) | ||
|
||
test_that("LearnerClassifAutoKKNN is trained", { | ||
rush_plan(n_workers = 2) | ||
lgr::get_logger("mlr3automl")$set_threshold("debug") | ||
|
||
task = tsk("penguins") | ||
learner = lrn("classif.auto_kknn", | ||
small_data_size = 1, | ||
resampling = rsmp("holdout"), | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 6) | ||
) | ||
|
||
expect_class(learner$train(task), "LearnerClassifAutoKKNN") | ||
expect_equal(learner$graph$param_set$values$branch.selection, "kknn") | ||
expect_equal(learner$model$instance$result$branch.selection, "kknn") | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
test_that("LearnerClassifAutoLightGBM is initialized", { | ||
learner = lrn("classif.auto_lightgbm", | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 10)) | ||
|
||
expect_null(learner$graph) | ||
expect_null(learner$tuning_space) | ||
}) | ||
|
||
test_that("LearnerClassifAutoLightGBM is trained", { | ||
rush_plan(n_workers = 2) | ||
lgr::get_logger("mlr3automl")$set_threshold("debug") | ||
|
||
task = tsk("penguins") | ||
learner = lrn("classif.auto_lightgbm", | ||
small_data_size = 1, | ||
resampling = rsmp("holdout"), | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 6) | ||
) | ||
|
||
expect_class(learner$train(task), "LearnerClassifAutoLightGBM") | ||
expect_equal(learner$graph$param_set$values$branch.selection, "lightgbm") | ||
expect_equal(learner$model$instance$result$branch.selection, "lightgbm") | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
test_that("LearnerClassifAutoNnet is initialized", { | ||
learner = lrn("classif.auto_nnet", | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 10)) | ||
|
||
expect_null(learner$graph) | ||
expect_null(learner$tuning_space) | ||
}) | ||
|
||
test_that("LearnerClassifAutoNnet is trained", { | ||
rush_plan(n_workers = 2) | ||
lgr::get_logger("mlr3automl")$set_threshold("debug") | ||
|
||
task = tsk("penguins") | ||
learner = lrn("classif.auto_nnet", | ||
small_data_size = 1, | ||
resampling = rsmp("holdout"), | ||
measure = msr("classif.ce"), | ||
terminator = trm("evals", n_evals = 6) | ||
) | ||
|
||
expect_class(learner$train(task), "LearnerClassifAutoNnet") | ||
expect_equal(learner$graph$param_set$values$branch.selection, "nnet") | ||
expect_equal(learner$model$instance$result$branch.selection, "nnet") | ||
}) |
Oops, something went wrong.