Skip to content

Commit

Permalink
Add permutation filter and update documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
be-marc committed Apr 13, 2020
1 parent 52748f8 commit 7989c10
Show file tree
Hide file tree
Showing 23 changed files with 281 additions and 0 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Collate:
'FilterMRMR.R'
'FilterNJMIM.R'
'FilterPerformance.R'
'FilterPermutation.R'
'FilterVariance.R'
'flt.R'
'reexports.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export(FilterMIM)
export(FilterMRMR)
export(FilterNJMIM)
export(FilterPerformance)
export(FilterPermutation)
export(FilterVariance)
export(as.data.table)
export(flt)
Expand Down
117 changes: 117 additions & 0 deletions R/FilterPermutation.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#' @title Permutation Filter
#'
#' @name mlr_filters_permutation
#'
#' @description
#' Estimate how important individual features are by contrasting prediction
#' performances. For each feature, its values are permuted, the learner is
#' resampled on the modified task, and the resulting performance is compared
#' to the performance obtained on the unmodified data. The comparison is
#' repeated `nmc` times and the permuted performances are averaged.
#'
#' @section Parameters:
#' \describe{
#' \item{`standardize`}{`logical(1)`\cr
#' Standardize feature importance by maximum score.}
#' \item{`nmc`}{`integer(1)`\cr
#' Number of Monte-Carlo iterations to use in computing the feature
#' importance. Must be at least 1.}
#' }
#'
#' @family Filter
#' @template seealso_filter
#' @export
FilterPermutation = R6Class("FilterPermutation",
  inherit = Filter,
  public = list(

    #' @field learner ([mlr3::Learner])\cr
    #' Learner that is resampled to obtain the performance values.
    learner = NULL,
    #' @field resampling ([mlr3::Resampling])\cr
    #' Resampling strategy passed to every resample call.
    resampling = NULL,
    #' @field measure ([mlr3::Measure])\cr
    #' Measure used to aggregate the resampling results.
    measure = NULL,

    #' @description Create a FilterPermutation object.
    #' @param id (`character(1)`)\cr
    #'   Identifier for the filter.
    #' @param task_type (`character()`)\cr
    #'   Types of the task the filter can operate on. E.g., `"classif"` or
    #'   `"regr"`. Defaults to the task type of `learner`.
    #' @param param_set ([paradox::ParamSet])\cr
    #'   Set of hyperparameters.
    #' @param feature_types (`character()`)\cr
    #'   Feature types the filter operates on.
    #'   Must be a subset of
    #'   [`mlr_reflections$task_feature_types`][mlr3::mlr_reflections].
    #'   Defaults to the feature types of `learner`.
    #' @param learner ([mlr3::Learner])\cr
    #'   [mlr3::Learner] to use for model fitting.
    #' @param resampling ([mlr3::Resampling])\cr
    #'   [mlr3::Resampling] to be used within resampling.
    #' @param measure ([mlr3::Measure])\cr
    #'   [mlr3::Measure] to be used for evaluating the performance.
    initialize = function(id = "permutation",
      task_type = learner$task_type,
      param_set = ParamSet$new(list(
        ParamLgl$new("standardize", default = FALSE),
        ParamInt$new("nmc", lower = 1L, default = 50L))),
      feature_types = learner$feature_types,
      learner = mlr3::lrn("classif.rpart"),
      resampling = mlr3::rsmp("holdout"),
      measure = mlr3::msr("classif.ce")) {

      # `learner` is deliberately re-bound to the cloned learner: the lazily
      # evaluated defaults of `task_type` and `feature_types` refer to it.
      self$learner = learner = assert_learner(as_learner(learner, clone = TRUE))
      self$resampling = assert_resampling(as_resampling(resampling))
      self$measure = assert_measure(as_measure(measure,
        task_type = learner$task_type, clone = TRUE), learner = learner)
      packages = unique(c(self$learner$packages, self$measure$packages))

      super$initialize(
        id = id,
        task_type = task_type,
        feature_types = feature_types,
        packages = packages,
        param_set = param_set,
        man = "mlr3filters::mlr_filters_permutation"
      )
    }
  ),

  private = list(
    # Compute one score per feature of `task`: the difference between the
    # baseline performance and the mean performance over `nmc` resampling runs
    # in which that feature's column has been randomly permuted.
    # `nfeat` is not used by this implementation.
    .calculate = function(task, nfeat) {
      task = task$clone()
      pars = self$param_set$values
      fn = task$feature_names
      pars$standardize = pars$standardize %??% FALSE
      pars$nmc = pars$nmc %??% 50L

      # Guard for user-supplied parameter sets without a lower bound:
      # `seq(0)` would silently run two iterations.
      if (pars$nmc < 1L) {
        stop("Parameter 'nmc' must be at least 1.", call. = FALSE)
      }

      # Reference performance on the unmodified task
      rr = resample(task, self$learner, self$resampling)
      baseline = rr$aggregate(self$measure)

      # One row per Monte-Carlo iteration, one column per feature
      perf = map_dtr(seq_len(pars$nmc), function(i) {
        set_names(map_dtc(fn, function(x) {
          task = task$clone()
          data = task$data()
          column = data[, x, with = FALSE][[1]]
          data[, (x) := column[sample(nrow(data))]]

          # Empty task and fill with shuffled column
          # NOTE(review): relies on no row having id 0 - confirm
          task$filter(rows = 0)
          task$rbind(data)
          rr = resample(task, self$learner, self$resampling)
          rr$aggregate(self$measure)
        }), fn)
      })
      delta = baseline - as.matrix(perf[, lapply(.SD, mean)])[1, ]

      # Flip the sign for measures that are minimized so that a performance
      # drop after permutation results in a positive score.
      if (self$measure$minimize) {
        delta = -delta
      }

      # NOTE(review): if max(delta) is zero or negative, this division yields
      # non-finite values or flips signs - confirm whether that needs handling.
      if (pars$standardize) {
        delta = delta / max(delta)
      }
      set_names(delta, fn)
    }
  )
)

# Register FilterPermutation with `mlr_filters` under the key "permutation".
# The @include tag below makes roxygen collate mlr_filters.R before this file,
# so that `mlr_filters` exists when this line runs.
#' @include mlr_filters.R
mlr_filters$add("permutation", FilterPermutation)
1 change: 1 addition & 0 deletions man/Filter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_anova.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_auc.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_carscore.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_cmim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_correlation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_disr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_find_correlation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_information_gain.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_jmi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_jmim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_kruskal_test.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_mim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_mrmr.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_njmim.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_filters_performance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

1 comment on commit 7989c10

@lintr-bot
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

R/FilterPermutation.R:102:65: style: Commas should always have a space after.

delta = baseline - as.matrix(perf[, lapply(.SD, mean)])[1,]
                                                                ^

R/FilterPermutation.R:108:9: style: Place a space before left parenthesis, except in a function call.

if(pars$standardize) {
        ^

R/FilterPermutation.R:109:24: style: Put spaces around all infix operators.

delta = delta/max(delta)
                      ~^~

Please sign in to comment.