Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipelines to Imports refactoring #419

Merged
merged 18 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.9
Version: 0.7.0
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -60,6 +60,7 @@ Imports:
distr6 (>= 1.8.4),
ggplot2,
mlr3misc (>= 0.7.0),
mlr3pipelines (>= 0.7.0),
mlr3viz,
paradox (>= 1.0.0),
R6,
Expand All @@ -71,7 +72,6 @@ Suggests:
knitr,
lgr,
lifecycle,
mlr3pipelines (>= 0.3.4),
param6 (>= 0.2.4),
pracma,
rpart,
Expand Down
10 changes: 7 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ S3method(pecs,list)
S3method(plot,LearnerSurv)
S3method(plot,TaskDens)
S3method(plot,TaskSurv)
export(.c_get_unique_times)
export(.c_weight_survival_score)
export(.surv_return)
export(LearnerDens)
Expand Down Expand Up @@ -115,13 +114,18 @@ import(paradox)
importFrom(R6,R6Class)
importFrom(Rcpp,sourceCpp)
importFrom(graphics,plot)
importFrom(mlr3pipelines,"%>>%")
importFrom(mlr3pipelines,Graph)
importFrom(mlr3pipelines,as_graph)
importFrom(mlr3pipelines,gunion)
importFrom(mlr3pipelines,pipeline_greplicate)
importFrom(mlr3pipelines,po)
importFrom(mlr3pipelines,ppl)
importFrom(mlr3viz,fortify)
importFrom(stats,complete.cases)
importFrom(stats,density)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,predict)
importFrom(stats,reformulate)
importFrom(stats,sd)
importFrom(survival,Surv)
importFrom(utils,data)
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# mlr3proba 0.7.0

* Add `mlr3pipelines` to `Imports` and set minimum latest version from CRAN (`0.7.0`)
* Refactor code to minimize namespace calling and imports such as `mlr3pipelines::` or `R6::`
* Add experimental badge in the documentation of a few more PipeOps
* Add argument `scale_lp` for AFT `distrcompose` pipeop + respective pipeline

# mlr3proba 0.6.9

* New `PipeOp`s: `PipeOpTaskSurvClassifIPCW`, `PipeOpPredClassifSurvIPCW`
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerDens.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#' # get a specific learner from mlr_learners:
#' mlr_learners$get("dens.hist")
#' lrn("dens.hist")
LearnerDens = R6::R6Class("LearnerDens",
LearnerDens = R6Class("LearnerDens",
inherit = Learner,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerDensHistogram.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @templateVar caller [graphics::hist()]
#'
#' @export
LearnerDensHistogram = R6::R6Class("LearnerDensHistogram",
LearnerDensHistogram = R6Class("LearnerDensHistogram",
inherit = LearnerDens,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/LearnerDensKDE.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#' `r format_bib("silverman_1986")`
#'
#' @export
LearnerDensKDE = R6::R6Class("LearnerDensKDE",
LearnerDensKDE = R6Class("LearnerDensKDE",
inherit = LearnerDens,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureDensLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#'
#' @family Density estimation measures
#' @export
MeasureDensLogloss = R6::R6Class("MeasureDensLogloss",
MeasureDensLogloss = R6Class("MeasureDensLogloss",
inherit = MeasureDens,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureRegrLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' \deqn{L(f, y) = -\log(f(y))}{L(f, y) = -log(f(y))}
#'
#' @export
MeasureRegrLogloss = R6::R6Class("MeasureRegrLogloss",
MeasureRegrLogloss = R6Class("MeasureRegrLogloss",
inherit = MeasureRegr,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvGraf.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
MeasureSurvGraf = R6::R6Class("MeasureSurvGraf",
MeasureSurvGraf = R6Class("MeasureSurvGraf",
inherit = MeasureSurv,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvIntLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
MeasureSurvIntLogloss = R6::R6Class("MeasureSurvIntLogloss",
MeasureSurvIntLogloss = R6Class("MeasureSurvIntLogloss",
inherit = MeasureSurv,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvLogloss.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
MeasureSurvLogloss = R6::R6Class("MeasureSurvLogloss",
MeasureSurvLogloss = R6Class("MeasureSurvLogloss",
inherit = MeasureSurv,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvMAE.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#'
#' @family response survival measures
#' @export
MeasureSurvMAE = R6::R6Class("MeasureSurvMAE",
MeasureSurvMAE = R6Class("MeasureSurvMAE",
inherit = MeasureSurv,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvMSE.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#'
#' @family response survival measures
#' @export
MeasureSurvMSE = R6::R6Class("MeasureSurvMSE",
MeasureSurvMSE = R6Class("MeasureSurvMSE",
inherit = MeasureSurv,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvRCLL.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
MeasureSurvRCLL = R6::R6Class("MeasureSurvRCLL",
MeasureSurvRCLL = R6Class("MeasureSurvRCLL",
inherit = MeasureSurv,
public = list(
#' @description
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvRMSE.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#'
#' @family response survival measures
#' @export
MeasureSurvRMSE = R6::R6Class("MeasureSurvRMSE",
MeasureSurvRMSE = R6Class("MeasureSurvRMSE",
inherit = MeasureSurv,
public = list(
#' @description Creates a new instance of this [R6][R6::R6Class] class.
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvSchmid.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
#' @family Probabilistic survival measures
#' @family distr survival measures
#' @export
MeasureSurvSchmid = R6::R6Class("MeasureSurvSchmid",
MeasureSurvSchmid = R6Class("MeasureSurvSchmid",
inherit = MeasureSurv,
public = list(
#' @description
Expand Down
4 changes: 1 addition & 3 deletions R/PipeOpBreslow.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@
#' @seealso [pipeline_distrcompositor]
#' @export
#' @family survival compositors
#' @examples
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines"), quietly = TRUE)
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#' task = tsk("rats")
Expand All @@ -57,7 +56,6 @@
#' b$train(list(train_task))
#' p = b$predict(list(test_task))[[1L]]
#' }
#' }
PipeOpBreslow = R6Class("PipeOpBreslow",
inherit = mlr3pipelines::PipeOp,
public = list(
Expand Down
6 changes: 3 additions & 3 deletions R/PipeOpCrankCompositor.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@
#' If `TRUE`, then the `crank` will be overwritten.
#'
#' @seealso [pipeline_crankcompositor]
#' @references
#' `r format_bib("sonabend_2022", "ishwaran_2008")`
#' @family survival compositors
#' @examples
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines"), quietly = TRUE)
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3pipelines)
#' task = tsk("rats")
#'
Expand All @@ -47,7 +48,6 @@
#' poc = po("crankcompose", param_vals = list(overwrite = TRUE))
#' poc$predict(list(pred))[[1L]]
#' }
#' }
#' @export
PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
inherit = mlr3pipelines::PipeOp,
Expand Down
42 changes: 33 additions & 9 deletions R/PipeOpDistrCompositor.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#' @template param_pipelines
#'
#' @description
#' `r lifecycle::badge("experimental")`
#'
#' Estimates (or 'composes') a survival distribution from a predicted baseline
#' survival distribution (`distr`) and a linear predictor (`lp`) from two [PredictionSurv]s.
#'
Expand Down Expand Up @@ -46,21 +48,37 @@
#' nothing and returns the given [PredictionSurv]. If `TRUE`, then the `distr` is overwritten
#' with the `distr` composed from `lp` - this is useful for changing the prediction
#' `distr` from one model form to another.
#' * `scale_lp` :: `logical(1)` \cr
#' This option is only applicable to `form` equal to `"aft"`. If `TRUE`, it
#' min-max scales the linear prediction scores to be in the interval \eqn{[0,1]},
#' avoiding extrapolation of the baseline \eqn{S_0(t)} on the transformed time
#' points \eqn{\frac{t}{\exp(lp)}}, as these will be \eqn{\in [\frac{t}{e}, t]},
#' and so always smaller than the maximum time point for which we have estimated
#' \eqn{S_0(t)}.
#' Note that this is just a **heuristic** to get reasonable results in the
#' case you observe survival predictions to be e.g. constant after the AFT
#' composition and it definitely provides no guarantee for creating calibrated
#' distribution predictions (as none of these methods do). Therefore, it is
#' set to `FALSE` by default.
#'
#' @section Internals:
#' The respective `form`s above have respective survival distributions:
#' \deqn{aft: S(t) = S_0(\frac{t}{exp(lp)})}{aft: S(t) = S0(t/exp(lp))}
#' \deqn{ph: S(t) = S_0(t)^{exp(lp)}}{ph: S(t) = S0(t)^exp(lp)}
#' \deqn{po: S(t) = \frac{S_0(t)}{exp(-lp) + (1-exp(-lp)) S_0(t)}}{po: S(t) = S0(t) / [exp(-lp) + S0(t) (1-exp(-lp))]} # nolint
#' where \eqn{S_0}{S0} is the estimated baseline survival distribution, and \eqn{lp} is the
#' \deqn{aft: S(t) = S_0(\frac{t}{\exp(lp)})}
#' \deqn{ph: S(t) = S_0(t)^{\exp(lp)}}
#' \deqn{po: S(t) = \frac{S_0(t)}{\exp(-lp) + (1-\exp(-lp)) S_0(t)}}
#' where \eqn{S_0} is the estimated baseline survival distribution, and \eqn{lp} is the
#' predicted linear predictor.
#'
#' For an example use of the `"aft"` composition using Kaplan-Meier as a baseline
#' distribution, see Norman et al. (2024).
#'
#' @seealso [pipeline_distrcompositor]
#' @references
#' `r format_bib("norman_2024")`
#' @export
#' @family survival compositors
#' @examples
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines"), quietly = TRUE)
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#' task = tsk("rats")
Expand All @@ -71,7 +89,6 @@
#' pod = po("distrcompose", param_vals = list(form = "aft", overwrite = TRUE))
#' pod$predict(list(base = base, pred = pred))[[1]]
#' }
#' }
PipeOpDistrCompositor = R6Class("PipeOpDistrCompositor",
inherit = mlr3pipelines::PipeOp,
public = list(
Expand All @@ -80,9 +97,10 @@ PipeOpDistrCompositor = R6Class("PipeOpDistrCompositor",
initialize = function(id = "distrcompose", param_vals = list()) {
param_set = ps(
form = p_fct(default = "aft", levels = c("aft", "ph", "po"), tags = "predict"),
overwrite = p_lgl(default = FALSE, tags = "predict")
overwrite = p_lgl(default = FALSE, tags = "predict"),
scale_lp = p_lgl(default = FALSE, tags = "predict")
)
param_set$set_values(form = "aft", overwrite = FALSE)
param_set$set_values(form = "aft", overwrite = FALSE, scale_lp = FALSE)

super$initialize(
id = id,
Expand Down Expand Up @@ -145,6 +163,12 @@ PipeOpDistrCompositor = R6Class("PipeOpDistrCompositor",
if (form == "ph") {
cdf = 1 - (survmat ^ exp(lpmat))
} else if (form == "aft") {
# add heuristic to keep the transformed t/exp(lp) time points within
# the domain of S_0(t)
if (self$param_set$values$scale_lp) {
lpmat = (lpmat - min(lpmat)) / (max(lpmat) - min(lpmat))
}
# calculate cdf = 1 - S_0(t) on the time points t/exp(lp)
mtc = findInterval(timesmat / exp(lpmat), times)
mtc[mtc == 0] = NA
cdf = 1 - matrix(survmat[1L, mtc], nr, nc, FALSE)
Expand Down
3 changes: 1 addition & 2 deletions R/PipeOpPredClassifSurvDiscTime.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredClassifSurvDiscTime = R6Class(
"PipeOpPredClassifSurvDiscTime",
PipeOpPredClassifSurvDiscTime = R6Class("PipeOpPredClassifSurvDiscTime",
inherit = mlr3pipelines::PipeOp,

public = list(
Expand Down
3 changes: 1 addition & 2 deletions R/PipeOpPredClassifSurvIPCW.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
#' @family PipeOps
#' @family Transformation PipeOps
#' @export
PipeOpPredClassifSurvIPCW = R6Class(
"PipeOpPredClassifSurvIPCW",
PipeOpPredClassifSurvIPCW = R6Class("PipeOpPredClassifSurvIPCW",
inherit = mlr3pipelines::PipeOp,

public = list(
Expand Down
4 changes: 1 addition & 3 deletions R/PipeOpPredRegrSurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
#' If `NULL` then assumed no censoring in the dataset. Otherwise should be a vector of `0/1`s
#' of same length as the prediction object, where `1` is dead and `0` censored.
#'
#' @examples
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines"), quietly = TRUE)
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#'
Expand All @@ -45,7 +44,6 @@
#' new_pred = po$predict(list(pred = pred, task = task_surv))[[1]]
#' all.equal(new_pred$truth, task_surv$truth())
#' }
#' }
#' @family PipeOps
#' @family Transformation PipeOps
#' @include PipeOpPredTransformer.R
Expand Down
5 changes: 1 addition & 4 deletions R/PipeOpPredSurvRegr.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@
#' @section State:
#' The `$state` is a named `list` with the `$state` elements inherited from [PipeOpPredTransformer].
#'
#'
#' @examples
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines"), quietly = TRUE)
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#' library(survival)
Expand All @@ -29,7 +27,6 @@
#' new_pred = po$predict(list(pred = pred))[[1]]
#' print(new_pred)
#' }
#' }
#' @family PipeOps
#' @family Transformation PipeOps
#' @include PipeOpPredTransformer.R
Expand Down
9 changes: 4 additions & 5 deletions R/PipeOpProbregrCompositor.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#' @template param_pipelines
#'
#' @description
#' Combines a predicted `reponse` and `se` from [PredictionRegr] with a specified probability
#' `r lifecycle::badge("experimental")`
#'
#' Combines a predicted `response` and `se` from [PredictionRegr] with a specified probability
#' distribution to estimate (or 'compose') a `distr` prediction.
#'
#' @section Dictionary:
Expand Down Expand Up @@ -38,10 +40,8 @@
#' distribution location and scale parameters respectively.
#'
#' @export
#' @examples
#' @examplesIf mlr3misc::require_namespaces(c("mlr3pipelines", "rpart"), quietly = TRUE)
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE) &&
#' requireNamespace("rpart", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#' set.seed(1)
Expand All @@ -61,7 +61,6 @@
#' poc = po("compose_probregr")
#' poc$predict(list(pred_response, pred_se))[[1]]
#' }
#' }
PipeOpProbregr = R6Class("PipeOpProbregr",
inherit = mlr3pipelines::PipeOp,
public = list(
Expand Down
Loading
Loading