Skip to content

Commit

Permalink
Merge pull request #88 from mlr-org/empty_predictions
Browse files Browse the repository at this point in the history
support for empty prediction sets (mlr-org/mlr3#1089)
  • Loading branch information
be-marc authored Aug 22, 2024
2 parents 560496d + 7fe1a8d commit 2593441
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Suggests:
RWeka,
stream,
testthat (>= 3.0.0)
Remotes:
mlr-org/mlr3
Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE, r6 = TRUE)
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(as_task_clust,data.frame)
S3method(as_task_clust,formula)
S3method(c,PredictionDataClust)
S3method(check_prediction_data,PredictionDataClust)
S3method(create_empty_prediction_data,TaskClust)
S3method(filter_prediction_data,PredictionDataClust)
S3method(is_missing_prediction_data,PredictionDataClust)
export(LearnerClust)
Expand Down
16 changes: 16 additions & 0 deletions R/PredictionDataClust.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,19 @@ filter_prediction_data.PredictionDataClust = function(pdata, row_ids, ...) {

pdata
}

#' @export
create_empty_prediction_data.TaskClust = function(task, learner) {
predict_types = mlr_reflections$learner_predict_types[["clust"]][[learner$predict_type]]

pdata = list(
row_ids = integer(),
partition = integer()
)

if ("prob" %in% predict_types) {
pdata$prob = matrix(integer())
}

set_class(pdata, "PredictionDataClust")
}
24 changes: 24 additions & 0 deletions tests/testthat/test_PredictionClust.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,27 @@ test_that("filter works", {
expect_set_equal(pdata$row_ids, 1:3)
expect_integer(pdata$partition, len = 3)
})

test_that("construction of empty PredictionDataClust", {
task = tsk("usarrests")

learner = lrn("clust.featureless", predict_type = "partition")
learner$train(task)
pred = learner$predict(task, row_ids = integer())
expect_prediction(pred)
expect_set_equal(pred$predict_types, "partition")
expect_integer(pred$row_ids, len = 0L)
expect_numeric(pred$partition, len = 0L)
expect_null(pred$prob)
expect_data_table(as.data.table(pred), nrows = 0L, ncols = 2L)

learner = lrn("clust.featureless", predict_type = "prob")
learner$train(task)
pred = learner$predict(task, row_ids = integer())
expect_prediction(pred)
expect_set_equal(pred$predict_types, c("partition", "prob"))
expect_integer(pred$row_ids, len = 0L)
expect_numeric(pred$partition, len = 0L)
expect_numeric(pred$prob, len = 0L)
expect_data_table(as.data.table(pred), nrows = 0L, ncols = 3L)
})

0 comments on commit 2593441

Please sign in to comment.