Skip to content

Commit

Permalink
Classification single learners (#14)
Browse files Browse the repository at this point in the history
* feat: classif single learners

* test: classif single learners

* fix: remove extra_trees & lda

* test: rename

* test: extra_trees and glmnet
  • Loading branch information
b-zhou authored Sep 13, 2024
1 parent 1421b79 commit eed0eeb
Show file tree
Hide file tree
Showing 15 changed files with 398 additions and 0 deletions.
6 changes: 6 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ RoxygenNote: 7.3.2
Collate:
'aaa.R'
'LearnerClassifAuto.R'
'LearnerClassifAutoCatboost.R'
'LearnerClassifAutoGlmnet.R'
'LearnerClassifAutoKKNN.R'
'LearnerClassifAutoLightGBM.R'
'LearnerClassifAutoNnet.R'
'LearnerClassifAutoRanger.R'
'LearnerClassifAutoSVM.R'
'LearnerClassifAutoXgboost.R'
'LearnerRegrAuto.R'
Expand Down
41 changes: 41 additions & 0 deletions R/LearnerClassifAutoCatboost.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#' @title Classification Gradient Boosted Decision Trees Auto Learner
#'
#' @description
#' Classification auto learner.
#'
#' @template param_id
#'
#' @export
LearnerClassifAutoCatboost = R6Class("LearnerClassifAutoCatboost",
inherit = LearnerClassifAuto,
public = list(

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_catboost") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"catboost_eval_metric",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "catboost")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "catboost")
}
)
)

#' @include aaa.R
learners[["classif.auto_catboost"]] = LearnerClassifAutoCatboost
40 changes: 40 additions & 0 deletions R/LearnerClassifAutoGlmnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#' @title Classification GLM with Elastic Net Regularization Auto Learner
#'
#' @description
#' Classification auto learner.
#'
#' @template param_id
#'
#' @export
LearnerClassifAutoGlmnet = R6Class("LearnerClassifAutoGlmnet",
inherit = LearnerClassifAuto,
public = list(

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_glmnet") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "glmnet")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "glmnet")
}
)
)

#' @include aaa.R
learners[["classif.auto_glmnet"]] = LearnerClassifAutoGlmnet
40 changes: 40 additions & 0 deletions R/LearnerClassifAutoKKNN.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#' @title Classification k-Nearest-Neighbor Auto Learner
#'
#' @description
#' Classification auto learner.
#'
#' @template param_id
#'
#' @export
LearnerClassifAutoKKNN = R6Class("LearnerClassifAutoKKNN",
inherit = LearnerClassifAuto,
public = list(

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_kknn") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "kknn")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "kknn")
}
)
)

#' @include aaa.R
learners[["classif.auto_kknn"]] = LearnerClassifAutoKKNN
41 changes: 41 additions & 0 deletions R/LearnerClassifAutoLightGBM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#' @title Classification LightGBM Auto Learner
#'
#' @description
#' Classification auto learner.
#'
#' @template param_id
#'
#' @export
LearnerClassifAutoLightGBM = R6Class("LearnerClassifAutoLightGBM",
inherit = LearnerClassifAuto,
public = list(

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_lightgbm") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"lightgbm_eval_metric",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "lightgbm")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "lightgbm")
}
)
)

#' @include aaa.R
learners[["classif.auto_lightgbm"]] = LearnerClassifAutoLightGBM
40 changes: 40 additions & 0 deletions R/LearnerClassifAutoNnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#' @title Classification Neural Network Auto Learner
#'
#' @description
#' Classification auto learner.
#'
#' @template param_id
#'
#' @export
LearnerClassifAutoNnet = R6Class("LearnerClassifAutoNnet",
inherit = LearnerClassifAuto,
public = list(

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_nnet") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "nnet")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "nnet")
}
)
)

#' @include aaa.R
learners[["classif.auto_nnet"]] = LearnerClassifAutoNnet
40 changes: 40 additions & 0 deletions R/LearnerClassifAutoRanger.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#' @title Classification Ranger Auto Learner
#'
#' @description
#' Classification auto learner.
#'
#' @template param_id
#'
#' @export
LearnerClassifAutoRanger = R6Class("LearnerClassifAutoRanger",
inherit = LearnerClassifAuto,
public = list(

#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "classif.auto_ranger") {
super$initialize(id = id)

# reduce parameter set to the relevant parameters
private$.param_set = private$.param_set$subset(
c("learner_ids",
"learner_timeout",
"small_data_size",
"small_data_resampling",
"max_cardinality",
"resampling",
"terminator",
"measure",
"lhs_size",
"callbacks",
"store_benchmark_result")
)

self$param_set$set_values(learner_ids = "ranger")
self$packages = c("mlr3tuning", "mlr3learners", "mlr3pipelines", "mlr3mbo", "mlr3automl", "ranger")
}
)
)

#' @include aaa.R
learners[["classif.auto_ranger"]] = LearnerClassifAutoRanger
25 changes: 25 additions & 0 deletions tests/testthat/test_LearnerClassifAutoCatboost.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
test_that("LearnerClassifAutoCatboost is initialized", {
learner = lrn("classif.auto_catboost",
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 10))

expect_null(learner$graph)
expect_null(learner$tuning_space)
})

test_that("LearnerClassifAutoCatboost is trained", {
rush_plan(n_workers = 2)
lgr::get_logger("mlr3automl")$set_threshold("debug")

task = tsk("penguins")
learner = lrn("classif.auto_catboost",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 6)
)

expect_class(learner$train(task), "LearnerClassifAutoCatboost")
expect_equal(learner$graph$param_set$values$branch.selection, "catboost")
expect_equal(learner$model$instance$result$branch.selection, "catboost")
})
25 changes: 25 additions & 0 deletions tests/testthat/test_LearnerClassifAutoGlmnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
test_that("LearnerClassifAutoGlmnet is initialized", {
learner = lrn("classif.auto_glmnet",
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 10))

expect_null(learner$graph)
expect_null(learner$tuning_space)
})

test_that("LearnerClassifAutoGlmnet is trained", {
rush_plan(n_workers = 2)
lgr::get_logger("mlr3automl")$set_threshold("debug")

task = tsk("penguins")
learner = lrn("classif.auto_glmnet",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 6)
)

expect_class(learner$train(task), "LearnerClassifAutoGlmnet")
expect_equal(learner$graph$param_set$values$branch.selection, "glmnet")
expect_equal(learner$model$instance$result$branch.selection, "glmnet")
})
25 changes: 25 additions & 0 deletions tests/testthat/test_LearnerClassifAutoKKNN.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
test_that("LearnerClassifAutoKKNN is initialized", {
learner = lrn("classif.auto_kknn",
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 10))

expect_null(learner$graph)
expect_null(learner$tuning_space)
})

test_that("LearnerClassifAutoKKNN is trained", {
rush_plan(n_workers = 2)
lgr::get_logger("mlr3automl")$set_threshold("debug")

task = tsk("penguins")
learner = lrn("classif.auto_kknn",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 6)
)

expect_class(learner$train(task), "LearnerClassifAutoKKNN")
expect_equal(learner$graph$param_set$values$branch.selection, "kknn")
expect_equal(learner$model$instance$result$branch.selection, "kknn")
})
25 changes: 25 additions & 0 deletions tests/testthat/test_LearnerClassifAutoLightGBM.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
test_that("LearnerClassifAutoLightGBM is initialized", {
learner = lrn("classif.auto_lightgbm",
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 10))

expect_null(learner$graph)
expect_null(learner$tuning_space)
})

test_that("LearnerClassifAutoLightGBM is trained", {
rush_plan(n_workers = 2)
lgr::get_logger("mlr3automl")$set_threshold("debug")

task = tsk("penguins")
learner = lrn("classif.auto_lightgbm",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 6)
)

expect_class(learner$train(task), "LearnerClassifAutoLightGBM")
expect_equal(learner$graph$param_set$values$branch.selection, "lightgbm")
expect_equal(learner$model$instance$result$branch.selection, "lightgbm")
})
25 changes: 25 additions & 0 deletions tests/testthat/test_LearnerClassifAutoNnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
test_that("LearnerClassifAutoNnet is initialized", {
learner = lrn("classif.auto_nnet",
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 10))

expect_null(learner$graph)
expect_null(learner$tuning_space)
})

test_that("LearnerClassifAutoNnet is trained", {
rush_plan(n_workers = 2)
lgr::get_logger("mlr3automl")$set_threshold("debug")

task = tsk("penguins")
learner = lrn("classif.auto_nnet",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 6)
)

expect_class(learner$train(task), "LearnerClassifAutoNnet")
expect_equal(learner$graph$param_set$values$branch.selection, "nnet")
expect_equal(learner$model$instance$result$branch.selection, "nnet")
})
Loading

0 comments on commit eed0eeb

Please sign in to comment.