Skip to content

Commit

Permalink
Merge pull request #90 from mlr-org/tags
Browse files Browse the repository at this point in the history
refactor: add missing train tags and only extract task data once
  • Loading branch information
be-marc authored Sep 12, 2024
2 parents 2e4b002 + d014269 commit bbedd87
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 25 deletions.
2 changes: 1 addition & 1 deletion R/LearnerClustAffinityPropagation.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ LearnerClustAP = R6Class("LearnerClustAP",
},

.predict = function(task) {
pv = self$param_set$get_values()
pv = self$param_set$get_values(tags = "train")
sim_func = pv$s
exemplar_data = attributes(self$model)$exemplar_data

Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustAgnes.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ LearnerClustAgnes = R6Class("LearnerClustAgnes",
tags = "train"
),
trace.lev = p_int(0L, default = 0L, tags = "train"),
k = p_int(1L, default = 2L, tags = "predict"),
k = p_int(1L, default = 2L, tags = c("train", "predict")),
par.method = p_uty(
tags = "train",
depends = quote(method %in% c("flexible", "gaverage")),
Expand Down Expand Up @@ -65,7 +65,7 @@ LearnerClustAgnes = R6Class("LearnerClustAgnes",
),
private = list(
.train = function(task) {
pv = self$param_set$get_values()
pv = self$param_set$get_values(tags = "train")
m = invoke(cluster::agnes,
x = task$data(),
diss = FALSE,
Expand Down
5 changes: 3 additions & 2 deletions R/LearnerClustDBSCAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ LearnerClustDBSCAN = R6Class("LearnerClustDBSCAN",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(dbscan::dbscan, x = task$data(), .args = pv)
m = insert_named(m, list(data = task$data()))
data = task$data()
m = invoke(dbscan::dbscan, x = data, .args = pv)
m = insert_named(m, list(data = data))
if (self$save_assignments) {
self$assignments = m$cluster
}
Expand Down
5 changes: 3 additions & 2 deletions R/LearnerClustDBSCANfpc.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ LearnerClustDBSCANfpc = R6Class("LearnerClustDBSCANfpc",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(fpc::dbscan, data = task$data(), .args = pv)
m = insert_named(m, list(data = task$data()))
data = task$data()
m = invoke(fpc::dbscan, data = data, .args = pv)
m = insert_named(m, list(data = data))
if (self$save_assignments) {
self$assignments = m$cluster
}
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustDiana.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ LearnerClustDiana = R6Class("LearnerClustDiana",
metric = p_fct(default = "euclidean", levels = c("euclidean", "manhattan"), tags = "train"),
stand = p_lgl(default = FALSE, tags = "train"),
trace.lev = p_int(0L, default = 0L, tags = "train"),
k = p_int(1L, default = 2L, tags = "predict")
k = p_int(1L, default = 2L, tags = c("train", "predict"))
)

param_set$set_values(k = 2L)
Expand All @@ -46,7 +46,7 @@ LearnerClustDiana = R6Class("LearnerClustDiana",
),
private = list(
.train = function(task) {
pv = self$param_set$get_values()
pv = self$param_set$get_values(tags = "train")
m = invoke(cluster::diana,
x = task$data(),
diss = FALSE,
Expand Down
5 changes: 3 additions & 2 deletions R/LearnerClustHDBSCAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ LearnerClustHDBSCAN = R6Class("LearnerClustHDBSCAN",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(dbscan::hdbscan, x = task$data(), .args = pv)
m = insert_named(m, list(data = task$data()))
data = task$data()
m = invoke(dbscan::hdbscan, x = data, .args = pv)
m = insert_named(m, list(data = data))

if (self$save_assignments) {
self$assignments = m$cluster
Expand Down
4 changes: 2 additions & 2 deletions R/LearnerClustHclust.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ LearnerClustHclust = R6Class("LearnerClustHclust",
diag = p_lgl(default = FALSE, tags = c("train", "dist")),
upper = p_lgl(default = FALSE, tags = c("train", "dist")),
p = p_dbl(default = 2, tags = c("train", "dist"), depends = quote(distmethod == "minkowski")),
k = p_int(1L, default = 2L, tags = "predict")
k = p_int(1L, default = 2L, tags = c("train", "predict"))
)

param_set$set_values(k = 2L, distmethod = "euclidean")
Expand All @@ -54,7 +54,7 @@ LearnerClustHclust = R6Class("LearnerClustHclust",
),
private = list(
.train = function(task) {
pv = self$param_set$get_values()
pv = self$param_set$get_values(tags = "train")
dist = invoke(stats::dist,
x = task$data(),
method = pv$d %??% "euclidean",
Expand Down
5 changes: 3 additions & 2 deletions R/LearnerClustKKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ LearnerClustKKMeans = R6Class("LearnerClustKKMeans",

c = kernlab::centers(self$model)
K = kernlab::kernelf(self$model)
data = task$data()

# kernel product between each new datapoint and the centers
d_xc = matrix(kernlab::kernelMatrix(K, as.matrix(task$data()), c), ncol = nrow(c))
d_xc = matrix(kernlab::kernelMatrix(K, as.matrix(data), c), ncol = nrow(c))
# kernel product between each new datapoint and itself: rows are identical
d_xx = matrix(
rep(diag(kernlab::kernelMatrix(K, as.matrix(task$data()))), each = ncol(d_xc)),
rep(diag(kernlab::kernelMatrix(K, as.matrix(data))), each = ncol(d_xc)),
ncol = ncol(d_xc), byrow = TRUE
)
# kernel product between each center and itself: columns are identical
Expand Down
10 changes: 6 additions & 4 deletions R/LearnerClustMiniBatchKMeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,20 @@ LearnerClustMiniBatchKMeans = R6Class("LearnerClustMiniBatchKMeans",
stopf("`CENTROIDS` must have same number of rows as `clusters`")
}

m = invoke(ClusterR::MiniBatchKmeans, data = task$data(), .args = pv)
data = task$data()
m = invoke(ClusterR::MiniBatchKmeans, data = data, .args = pv)
if (self$save_assignments) {
self$assignments = as.integer(invoke(predict, m, newdata = task$data()))
self$assignments = as.integer(invoke(predict, m, newdata = data))
}
m
},

.predict = function(task) {
partition = as.integer(invoke(predict, self$model, newdata = task$data()))
data = task$data()
partition = as.integer(invoke(predict, self$model, newdata = data))
prob = NULL
if (self$predict_type == "prob") {
prob = invoke(predict, self$model, newdata = task$data(), fuzzy = TRUE)
prob = invoke(predict, self$model, newdata = data, fuzzy = TRUE)
colnames(prob) = seq_len(ncol(prob))
}
PredictionClust$new(task = task, partition = partition, prob = prob)
Expand Down
5 changes: 3 additions & 2 deletions R/LearnerClustOPTICS.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ LearnerClustOPTICS = R6Class("LearnerClustOPTICS",
private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
m = invoke(dbscan::optics, x = task$data(), .args = remove_named(pv, "eps_cl"))
m = insert_named(m, list(data = task$data()))
data = task$data()
m = invoke(dbscan::optics, x = data, .args = remove_named(pv, "eps_cl"))
m = insert_named(m, list(data = data))
m = invoke(dbscan::extractDBSCAN, object = m, eps_cl = pv$eps_cl)

if (self$save_assignments) {
Expand Down
6 changes: 2 additions & 4 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#' @importFrom stats model.frame terms predict runif dist
"_PACKAGE"

utils::globalVariables("type")

mlr3cluster_tasks = new.env()
mlr3cluster_learners = new.env()

Expand All @@ -28,7 +26,7 @@ register_learner = function(name, constructor) {
register_mlr3 = function() {
# reflections
mlr_reflections = utils::getFromNamespace("mlr_reflections", ns = "mlr3")
mlr_reflections$task_types = mlr_reflections$task_types[type != "clust"]
mlr_reflections$task_types = mlr_reflections$task_types[!"clust"]
mlr_reflections$task_types = setkeyv(rbind(mlr_reflections$task_types, rowwise_table(
~type, ~package, ~task, ~learner, ~prediction, ~prediction_data, ~measure,
"clust", "mlr3cluster", "TaskClust", "LearnerClust", "PredictionClust", "PredictionDataClust", "MeasureClust"
Expand Down Expand Up @@ -70,7 +68,7 @@ register_mlr3 = function() {
walk(names(mlr3cluster_learners), function(id) mlr_learners$remove(id))
walk(names(measures), function(id) mlr_measures$remove(paste("clust", id, sep = ".")))

mlr_reflections$task_types = mlr_reflections$task_types[type != "clust"]
mlr_reflections$task_types = mlr_reflections$task_types[!"clust"]
reflections = c(
"measure_properties", "default_measures", "learner_properties",
"learner_predict_types", "task_properties", "task_col_roles"
Expand Down

0 comments on commit bbedd87

Please sign in to comment.