

PLS regression with mixOmics engine - predictions not extracted from predict object #132

Open
marioem opened this issue Oct 18, 2022 · 5 comments


@marioem

marioem commented Oct 18, 2022

Hi,

vi_permute fails on the PLS predict object.

library(plsmod) # parsnip helper for pls
#> Loading required package: parsnip
library(mixOmics) # for pls regression
#> Loading required package: MASS
#> Loading required package: lattice
#> Loading required package: ggplot2
#> 
#> Loaded mixOmics 6.21.0
#> Thank you for using mixOmics!
#> Tutorials: http://mixomics.org
#> Bookdown vignette: https://mixomicsteam.github.io/Bookdown
#> Questions, issues: Follow the prompts at http://mixomics.org/contact-us
#> Cite us:  citation('mixOmics')
#> 
#> Attaching package: 'mixOmics'
#> The following objects are masked from 'package:parsnip':
#> 
#>     pls, tune
library(tidymodels)

library(doMC)
#> Loading required package: foreach
#> 
#> Attaching package: 'foreach'
#> The following objects are masked from 'package:purrr':
#> 
#>     accumulate, when
#> Loading required package: iterators
#> Loading required package: parallel

registerDoMC(cores = parallel::detectCores() - 1) # Mac and Linux only :-)
tidymodels_prefer()
data(concrete, package = "modeldata")

concrete <- 
  concrete %>% 
  group_by(across(-compressive_strength)) %>% 
  summarize(compressive_strength = mean(compressive_strength),
            .groups = "drop")
nrow(concrete)
#> [1] 992

set.seed(1501)
concrete_split <- initial_split(concrete, strata = compressive_strength)
concrete_train <- training(concrete_split)
concrete_test  <- testing(concrete_split)

set.seed(1502)
# concrete_folds <- 
#   vfold_cv(concrete_train, strata = compressive_strength) # , repeats = 5)

concrete_folds <- 
  bootstraps(concrete_train, times = 5) # 5 times for the sake of time for this reprex

pls_spec <- pls(num_comp = tune()) %>% 
  set_mode("regression") %>% 
  set_engine("mixOmics")

normalized_rec <- 
  recipe(compressive_strength ~ ., data = concrete_train) %>% 
  step_normalize(all_predictors()) 

normalized <-
  workflow_set(
    preproc = list(normalized = normalized_rec),
    models = list(PLS = pls_spec)
  )

all_workflows <- normalized

bayes_ctrl <-
  control_bayes(
    seed = 828, 
    save_pred = TRUE,
    parallel_over = "everything",
    verbose = TRUE,
    save_workflow = TRUE
  )

grid_results <-
  all_workflows %>%
  workflow_map("tune_bayes",
               resamples = concrete_folds,
               metrics = metric_set(rmse, mape, smape, mae, yardstick::ccc, huber_loss),
               iter = 5,
               control = bayes_ctrl)
#> 
#> ❯  Generating a set of 4 initial parameter results
#> ✓ Initialization complete
#> 
#> Optimizing rmse using the expected improvement
#> 
#> ── Iteration 1 ─────────────────────────────────────────────────────────────────
#> 
#> i Current best:      rmse=10.79 (@iter 0)
#> i Gaussian process model
#> ✓ Gaussian process model
#> ! No remaining candidate models
#> x Halting search
#> ✖ Optimization stopped prematurely; returning current results.

grid_results %>% 
  rank_results(rank_metric = "rmse") %>% 
  filter(.metric == "rmse") %>% 
  select(model, wflow_id, .config, rmse = mean, rank)
#> # A tibble: 4 × 5
#>   model wflow_id       .config               rmse  rank
#>   <chr> <chr>          <chr>                <dbl> <int>
#> 1 pls   normalized_PLS Preprocessor1_Model4  10.8     1
#> 2 pls   normalized_PLS Preprocessor1_Model3  10.9     2
#> 3 pls   normalized_PLS Preprocessor1_Model2  11.0     3
#> 4 pls   normalized_PLS Preprocessor1_Model1  11.7     4

best_tuneRmse <- 
  grid_results %>% 
  extract_workflow_set_result("normalized_PLS") %>% 
  select_best(metric = "rmse")

Best_test_resultsRmse <- 
  grid_results %>% 
  extract_workflow("normalized_PLS") %>% 
  finalize_workflow(best_tuneRmse) %>% 
  last_fit(split = concrete_split, metrics = metric_set(rmse, mape, smape, mae, yardstick::ccc, huber_loss))

Best_test_resultsRmse %>% 
  extract_fit_parsnip() %>%
  vip::vip(method = "permute",
           num_features = 30,
           train = normalized_rec %>% prep() %>% bake(new_data = NULL), 
           target = "compressive_strength", 
           metric = "rmse", 
           nsim = 500,
           pred_wrapper = predict, 
           geom = "col", 
           all_permutations = FALSE,
           aesthetics = list(color = "grey35"),
           include_type = TRUE
  ) +
  ggtitle("Predictor importance - PLS")
#> Error in predicted - actual: non-numeric argument to binary operator

Best_test_resultsRmse %>% 
  extract_fit_parsnip() %>%
  vip::vip(method = "permute",
           num_features = 30,
           train = normalized_rec %>% prep() %>% bake(new_data = NULL), 
           target = "compressive_strength", 
           metric = "rmse", 
           nsim = 500,
           pred_wrapper = function(object, newdata) { pred = predict(object, newdata); print(str(pred)); pred}, 
           geom = "col", 
           all_permutations = FALSE,
           aesthetics = list(color = "grey35"),
           include_type = TRUE
  ) +
  ggtitle("Predictor importance - PLS")
#> List of 4
#>  $ predict : num [1:743, 1, 1:4] 17.4 14.9 17.7 19.3 16.9 ...
#>   ..- attr(*, "dimnames")=List of 3
#>   .. ..$ : chr [1:743] "1" "2" "3" "4" ...
#>   .. ..$ : chr "Y"
#>   .. ..$ : chr [1:4] "dim1" "dim2" "dim3" "dim4"
#>  $ variates: num [1:743, 1:4] -1.65 -1.88 -1.62 -1.47 -1.7 ...
#>   ..- attr(*, "dimnames")=List of 2
#>   .. ..$ : chr [1:743] "1" "2" "3" "4" ...
#>   .. ..$ : chr [1:4] "dim1" "dim2" "dim3" "dim4"
#>  $ B.hat   : num [1:8, 1, 1:4] 0.3927 0.1143 -0.0773 -0.2285 0.3045 ...
#>   ..- attr(*, "dimnames")=List of 3
#>   .. ..$ : chr [1:8] "cement" "blast_furnace_slag" "fly_ash" "water" ...
#>   .. ..$ : chr "Y"
#>   .. ..$ : chr [1:4] "dim1" "dim2" "dim3" "dim4"
#>  $ call    : language predict.mixo_spls(object = object, newdata = newdata)
#>  - attr(*, "class")= chr "predict"
#> NULL
#> Error in predicted - actual: non-numeric argument to binary operator

Created on 2022-10-18 with reprex v2.0.2

sessionInfo()
R version 4.2.1 (2022-06-23)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.6

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] mixOmics_6.21.0    MASS_7.3-58.1      plsmod_1.0.0       doMC_1.3.8         iterators_1.0.14   foreach_1.5.2      ranger_0.14.1      Cubist_0.4.0       lattice_0.20-45    xgboost_1.6.0.1    rpart_4.1.16       earth_5.3.1       
[13] plotmo_3.6.2       TeachingDemos_2.12 plotrix_3.8-2      Formula_1.2-4      baguette_1.0.0     rules_1.0.0        forcats_0.5.2      stringr_1.4.1      readr_2.1.3        tidyverse_1.3.2    yardstick_1.1.0    workflowsets_1.0.0
[25] workflows_1.1.0    tune_1.0.0.9000    tidyr_1.2.1        tibble_3.1.8       rsample_1.1.0      recipes_1.0.1      purrr_0.3.4        parsnip_1.0.2      modeldata_1.0.1    infer_1.0.3        ggplot2_3.3.6      dplyr_1.0.10      
[37] dials_1.0.0        scales_1.2.1       broom_1.0.1        tidymodels_1.0.0  

loaded via a namespace (and not attached):
  [1] readxl_1.4.1        backports_1.4.1     igraph_1.3.5        plyr_1.8.7          splines_4.2.1       BiocParallel_1.30.3 listenv_0.8.0       digest_0.6.29       htmltools_0.5.3     fansi_1.0.3         magrittr_2.0.3      memoise_2.0.1      
 [13] googlesheets4_1.0.1 tzdb_0.3.0          globals_0.16.1      modelr_0.1.9        gower_1.0.0         matrixStats_0.62.0  rARPACK_0.11-0      R.utils_2.12.0      hardhat_1.2.0       colorspace_2.0-3    vip_0.3.2           ggrepel_0.9.1      
 [25] rvest_1.0.3         warp_0.2.0          haven_2.5.1         xfun_0.33           callr_3.7.2         crayon_1.5.2        jsonlite_1.8.2      libcoin_1.0-9       survival_3.4-0      glue_1.6.2          gtable_0.3.1        gargle_1.2.1       
 [37] ipred_0.9-13        R.cache_0.16.0      clipr_0.8.0         future.apply_1.9.1  mvtnorm_1.1-3       DBI_1.1.3           Rcpp_1.0.9          GPfit_1.0-8         lava_1.6.10         prodlim_2019.11.13  httr_1.4.4          RColorBrewer_1.1-3 
 [49] ellipsis_0.3.2      pkgconfig_2.0.3     R.methodsS3_1.8.2   nnet_7.3-18         dbplyr_2.2.1        utf8_1.2.2          tidyselect_1.1.2    rlang_1.0.6         DiceDesign_1.9      reshape2_1.4.4      munsell_0.5.0       cellranger_1.1.0   
 [61] tools_4.2.1         cachem_1.0.6        cli_3.4.1           generics_0.1.3      evaluate_0.16       fastmap_1.1.0       yaml_2.3.5          processx_3.7.0      knitr_1.40          fs_1.5.2            future_1.28.0       nlme_3.1-159       
 [73] R.oo_1.25.0         xml2_1.3.3          compiler_4.2.1      rstudioapi_0.14     slider_0.2.2        reprex_2.0.2        lhs_1.1.5           stringi_1.7.8       ps_1.7.1            highr_0.9           RSpectra_0.16-1     Matrix_1.5-1       
 [85] styler_1.7.0        conflicted_1.1.0    vctrs_0.4.2         pillar_1.8.1        lifecycle_1.0.2     furrr_0.3.1         corpcor_1.6.10      data.table_1.14.2   R6_2.5.1            gridExtra_2.3       C50_0.1.6           parallelly_1.32.1  
 [97] codetools_0.2-18    assertthat_0.2.1    withr_2.5.0         hms_1.1.2           grid_4.2.1          timeDate_4021.106   class_7.3-20        rmarkdown_2.16      inum_1.0-4          googledrive_2.0.0   partykit_1.2-16     lubridate_1.8.0    
[109] ellipse_0.4.3 

This can be worked around by setting

pred_wrapper = function(object, newdata) { pred = predict(object, newdata); pred$predict}

but it would be nice if the vi_* functions were aware of this PLS idiosyncrasy and handled it out of the box.

Cheers,

Mariusz
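Per the str() output in the reprex above, `$predict` is not a plain vector but a 743 × 1 × 4 array with one slice per number of components (dim1 to dim4), presumably the predictions using the first 1, 2, 3 and 4 components. Below is a minimal sketch, assuming that layout and a single response, of a wrapper that keeps only the last slice and returns a plain numeric vector. `last_comp_predictions()`, `pls_pred_wrapper()` and the mock object are hypothetical names made up for illustration; the mock stands in for a real mixOmics fit.

```r
# Hypothetical helper (not part of vip or mixOmics): convert the list returned
# by mixOmics' predict() into a numeric vector for a single response.
# Assumes `$predict` is an n x 1 x ncomp array, as in the str() output above.
last_comp_predictions <- function(pred) {
  arr <- pred[["predict"]]
  stopifnot(is.array(arr), length(dim(arr)) == 3L, dim(arr)[2] == 1L)
  ncomp <- dim(arr)[3]
  # Take the slice for the last component; as.vector() drops the row names
  as.vector(arr[, 1, ncomp])
}

# Sketch of a pred_wrapper for vip's permutation importance
pls_pred_wrapper <- function(object, newdata) {
  last_comp_predictions(predict(object, newdata))
}

# Mock predict() result with the same layout: 5 rows, 1 response, 4 components
mock_pred <- list(
  predict = array(
    seq_len(20), dim = c(5, 1, 4),
    dimnames = list(as.character(1:5), "Y", paste0("dim", 1:4))
  )
)
p <- last_comp_predictions(mock_pred)
print(p)  # [1] 16 17 18 19 20
```

With a real fit, `pls_pred_wrapper` would be passed as `pred_wrapper`. Returning the whole `$predict` array instead may not raise an error, since R can recycle the observed values across the extra slices, but the metric would then mix predictions from different numbers of components.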

@brandongreenwell-8451

Hi @marioem, thank you for reaching out and for providing a very detailed example. From an initial look, it seems the prediction wrapper you supplied just needs to be modified. For RMSE, the prediction wrapper needs to return a single vector of predictions, but it looks like calling predict() on this type of object returns a list. If so, maybe you're looking for something closer to predict(object, newdata)$predict (or whatever the component holding the actual predictions is called). Is that the case here?
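The error message in the reprex can be reproduced in base R: an RMSE-style metric computes `predicted - actual`, and R's arithmetic operators reject a list operand. A minimal sketch, where `rmse()` is a hypothetical stand-in written here (not vip's internal metric) and `pred_list` is made-up data shaped like the mixOmics result:

```r
# Stand-ins: observed values and a list-shaped "prediction" like mixOmics returns
actual <- c(10, 20, 30)
pred_list <- list(predict = c(11, 19, 32), variates = matrix(0, 3, 2))

# Hypothetical RMSE, computed the usual way from predicted - actual
rmse <- function(actual, predicted) sqrt(mean((predicted - actual)^2))

# Passing the whole list fails with the same message seen in the reprex
msg <- tryCatch(rmse(actual, pred_list), error = conditionMessage)
print(msg)  # "non-numeric argument to binary operator"

# Passing the numeric component works
r <- rmse(actual, pred_list$predict)
print(r)  # sqrt(2), about 1.414
```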

@marioem

marioem commented Oct 18, 2022

Hi @brandongreenwell-8451,

The first use of vip in the example above passes a bare predict, which resolves to the predict.pls method from mixOmics, with its return value documented here. Since many users of both vip and mixOmics, seeing that a predict.pls method is available, would probably pass it as-is to pred_wrapper, I was wondering whether the vi_* functions could be improved to make this type of return value transparent to the user. The examples in the vi_permute documentation use plain predict, which primed me to think the function would handle any predict method provided by the packages implementing the models.
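One way such transparency could work is a normalizing step applied to whatever the wrapper returns, before the metric is computed. The sketch below is hypothetical (`as_prediction_vector()` is not part of vip), assumes a single numeric response, and handles three shapes: a plain vector (as from predict.lm), a one-column matrix or data frame, and a list carrying a mixOmics-style `predict` array. Anything else is rejected rather than guessed at.

```r
# Hypothetical helper: coerce common predict() return shapes to a numeric
# vector for a single-response regression model.
as_prediction_vector <- function(pred) {
  if (is.data.frame(pred)) {
    # One-column data frame (or tibble) of predictions
    stopifnot(ncol(pred) == 1L)
    pred <- pred[[1]]
  } else if (is.list(pred) && !is.null(pred[["predict"]])) {
    # mixOmics-style list: `predict` is n x 1 x ncomp; keep the last component
    arr <- pred[["predict"]]
    stopifnot(length(dim(arr)) == 3L, dim(arr)[2] == 1L)
    pred <- arr[, 1, dim(arr)[3]]
  } else if (is.matrix(pred)) {
    # One-column matrix of predictions
    stopifnot(ncol(pred) == 1L)
    pred <- pred[, 1]
  }
  if (!is.atomic(pred) || !is.numeric(pred)) {
    stop("cannot coerce predictions to a numeric vector")
  }
  as.vector(pred)  # drops names and any remaining attributes
}

# Usage with a base R model, whose predict() already returns a named vector
fit <- lm(mpg ~ wt, data = mtcars)
head(as_prediction_vector(predict(fit, newdata = mtcars)))
```

The `is.data.frame()` check comes before `is.list()` because a data frame is also a list, and `[["predict"]]` is used instead of `$predict` to avoid R's partial name matching on lists.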

@brandongreenwell-8451

brandongreenwell-8451 commented Oct 18, 2022

Hi @marioem, I'm not sure I understand what you mean? The docs in the link specify the output as a list with four components. The proper prediction wrapper extracts the prediction component and works as expected:

Best_test_resultsRmse %>% 
  extract_fit_parsnip() %>%
  vip::vip(method = "permute",
           num_features = 30,
           train = normalized_rec %>% prep() %>% bake(new_data = NULL), 
           target = "compressive_strength", 
           metric = "rmse", 
           nsim = 500,
           #######
           pred_wrapper = function(object, newdata) { predict(object, newdata)$predict}, 
           #######
           geom = "col", 
           all_permutations = FALSE,
           aesthetics = list(color = "grey35"),
           include_type = TRUE
  ) +
  ggtitle("Predictor importance - PLS")

Or are you suggesting an improvement to the docs to help prevent subtle confusion about which prediction wrapper to supply? If so, the tidymodels team opened a similar issue here. The docs for vi_permute() explain that, for regression problems, the prediction wrapper should return a numeric (and atomic) vector of predictions, which means the user has to be familiar with what the generic predict() returns for their model objects.

I'd be happy to modify the examples, and maybe add a special vignette if you think it'd be helpful for general users?
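A small pre-flight check along the lines the vi_permute() docs describe could catch a wrong wrapper before a long permutation run (`nsim = 500` in the reprex). The sketch below is hypothetical (`check_pred_wrapper()` is not a vip function): it calls the wrapper once and verifies that the result is an atomic numeric vector, without dimensions, with one value per row of `newdata`. It is demonstrated on a base R lm fit, with a deliberately wrong wrapper that returns a list, as a bare mixOmics predict() does.

```r
# Hypothetical pre-flight check for a prediction wrapper
check_pred_wrapper <- function(pred_wrapper, object, newdata) {
  p <- pred_wrapper(object, newdata)
  problems <- character(0)
  if (!is.atomic(p) || !is.numeric(p)) {
    problems <- c(problems, "not an atomic numeric vector")
  }
  if (!is.null(dim(p))) {
    problems <- c(problems, "has dimensions (matrix or array)")
  }
  if (length(p) != nrow(newdata)) {
    problems <- c(problems, "length differs from nrow(newdata)")
  }
  if (length(problems) > 0) stop(paste(problems, collapse = "; "))
  invisible(TRUE)
}

fit <- lm(mpg ~ wt + hp, data = mtcars)

# A wrapper that returns predict.lm's vector passes
good <- function(object, newdata) predict(object, newdata)
ok <- check_pred_wrapper(good, fit, mtcars)

# A wrapper that returns a list is rejected with a descriptive message
bad <- function(object, newdata) list(predict = predict(object, newdata))
msg <- tryCatch(check_pred_wrapper(bad, fit, mtcars), error = conditionMessage)
print(msg)
```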

@marioem
Copy link
Author

marioem commented Oct 18, 2022

Hi @brandongreenwell-8451,

An update to the documentation would be very helpful. We (users) are spoiled by S3 and often automatically assume that every predict method behaves like predict.lm; my assumption here was that pred_wrapper = predict is all vi_permute needs to work with mixOmics methods as well.
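The S3 point can be illustrated in a few lines of base R: predict() is only a generic, and each method decides the shape of what it returns. `toy_model` and `predict.toy_model()` are made up for this sketch; they stand in for a package whose method returns a list, as the mixOmics method does according to the `call` line in the str() output above.

```r
# A base R model: predict.lm returns a named numeric vector
fit_lm <- lm(mpg ~ wt, data = mtcars)
p_lm <- predict(fit_lm, newdata = mtcars)

# A made-up model class whose predict method returns a list instead
toy_fit <- structure(list(coef = 2), class = "toy_model")
predict.toy_model <- function(object, newdata, ...) {
  list(predict = object$coef * newdata$x, note = "extra component")
}
p_toy <- predict(toy_fit, newdata = data.frame(x = 1:3))

# Same generic, different return shapes
c(lm = is.atomic(p_lm), toy = is.atomic(p_toy))  # lm TRUE, toy FALSE
```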

@brandongreenwell-8451

Makes sense to me. I'll put some thought into it, make sure the docs are much clearer, and add an example of when this is not the case! Thanks again for posting the issue; I'll close it once the docs are updated accordingly.
