Speed optimizations for iregnet
Interval regression is a class of machine learning models which is useful when predicted values should be real numbers, but outputs in the training data set may be partially observed. A common example is survival analysis, in which data are patient survival times.
For example, say that Alice and Bob came into the hospital and were treated for cancer on the same day in 2000. Now we are in 2016 and we would like to study the treatment efficacy. Say Alice died in 2010, and Bob is still alive. The survival time for Alice is 10 years, and although we do not know Bob’s survival time, we know it is in the interval (16, Infinity).
Say that we also measured some covariates (input variables) for Alice and Bob (age, sex, gene expression). We can fit an Accelerated Failure Time (AFT) model which takes those input variables and outputs a predicted survival time. L1 regularized AFT models are of interest when there are many input variables and we would like the model to automatically ignore those which are uninformative (do not help predict survival time). Several papers describe L1 regularized AFT models:
- L1 regularization for an AFT model with a weighted square loss function, Huang et al 2005. jian-huang@uiowa.edu writes: “The AFT model is really just a linear regression model, except that the response is a transformation of the survival time (usually a log transform) and the response is usually subject to censoring. Given the values of the predictors, the output is just (transformed) survival time. So it can be used for predicting survival times.”
- L1 regularization for an AFT model with a pairwise loss function, Cai et al 2011. tcai.hsph@gmail.com writes: “we only provide regression coefficients and do not provide actual predicted survival although it could be derived from the model by first estimating the residual distribution as in the standard AFT model.”
Interval regression (or interval censoring) is a generalization in which any kind of interval is an acceptable output in the training data. Any real-valued or positive-valued probability distribution may be used to model the outputs (e.g. normal or logistic if output is real-valued, log-normal or log-logistic if output is positive-valued like a survival time). For more details read this 1-page explanation of un-regularized parametric AFT models.
output | interval | likelihood | censoring |
---|---|---|---|
exactly 10 | (10, 10) | density function | none |
at least 16 | (16, Infinity) | cumulative distribution function | right |
at most 3 | (-Infinity, 3) | cumulative distribution function | left |
between -4 and 5 | (-4, 5) | cumulative distribution function | interval |
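To make the table concrete, here is a minimal sketch of how each kind of interval could contribute to the likelihood, assuming a normal distribution for the output. The function names and the `mean` and `scale` values are hypothetical and chosen only for illustration; this is not the iregnet API. An exactly observed output uses the density, and every censored output uses a difference of cumulative distribution function values, F(upper) - F(lower).

```python
import math


def normal_pdf(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def interval_likelihood(lower, upper, mean=0.0, scale=1.0):
    """Likelihood of one (lower, upper) output under Normal(mean, scale).

    Illustrative helper (an assumption, not part of any package).
    """
    if lower > upper:
        raise ValueError("lower limit must not exceed upper limit")
    if lower == upper:
        # No censoring: use the density at the observed value.
        return normal_pdf((lower - mean) / scale) / scale
    # Right, left, or interval censoring: probability mass of the interval.
    cdf_upper = 1.0 if upper == math.inf else normal_cdf((upper - mean) / scale)
    cdf_lower = 0.0 if lower == -math.inf else normal_cdf((lower - mean) / scale)
    return cdf_upper - cdf_lower


# The four rows of the table, with hypothetical model parameters.
rows = [(10.0, 10.0), (16.0, math.inf), (-math.inf, 3.0), (-4.0, 5.0)]
for lower, upper in rows:
    value = interval_likelihood(lower, upper, mean=5.0, scale=5.0)
    print((lower, upper), value)
```

For a positive-valued output such as a survival time, the same computation would be applied after a log transform of the limits, which corresponds to the log-normal case mentioned above.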
Another application of interval regression is in learning penalty functions for detecting change-points and peaks in genomic data (data viz).
The iregnet package was coded in GSOC2016 by @anujkhare. It is the first R package to support
- general interval output data (including left and interval censoring; not just observed and right-censored data typical of survival analysis),
- elastic net (L1 + L2) regularization, and
- a fast glmnet-like coordinate descent solver.
The coordinate descent solver was coded in C++ by following the mathematics of Simon et al (JSS). However, it is not as fast as the glmnet package. The main goal of this GSOC project will be to make iregnet as fast as glmnet.
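For readers unfamiliar with glmnet-style coordinate descent, the following is a minimal sketch of the idea for the simplest case: squared loss with an elastic net penalty and no intercept. It is not the iregnet C++ solver, which works with the AFT likelihood; the function names, the toy data, and the objective scaling (1/(2n) times the squared error, plus `lam * (alpha * L1 + (1 - alpha) / 2 * L2)`) are assumptions made for illustration. Each coefficient is updated in turn by soft-thresholding its correlation with the partial residual.

```python
def soft_threshold(value, threshold):
    """Shrink value toward zero by threshold; return 0.0 inside [-threshold, threshold]."""
    if value > threshold:
        return value - threshold
    if value < -threshold:
        return value + threshold
    return 0.0


def elastic_net_cd(X, y, lam, alpha=1.0, n_sweeps=100, tol=1e-10):
    """Cyclic coordinate descent for squared loss with an elastic net penalty.

    X is a list of rows, y a list of outputs. Illustrative sketch only.
    """
    n, p = len(X), len(X[0])
    beta = [0.0] * p
    residual = list(y)  # equals y - X beta because beta starts at zero
    for _ in range(n_sweeps):
        max_change = 0.0
        for j in range(p):
            col = [row[j] for row in X]
            col_sq = sum(v * v for v in col) / n
            if col_sq == 0.0:
                continue  # an all-zero column carries no information
            # Correlation of column j with the partial residual (beta[j] removed).
            rho = sum(v * (r + v * beta[j]) for v, r in zip(col, residual)) / n
            new_bj = soft_threshold(rho, lam * alpha) / (col_sq + lam * (1.0 - alpha))
            delta = new_bj - beta[j]
            if delta != 0.0:
                # Keep the residual in sync instead of recomputing it from scratch.
                residual = [r - v * delta for r, v in zip(residual, col)]
                beta[j] = new_bj
                max_change = max(max_change, abs(delta))
        if max_change < tol:
            break
    return beta


# Toy data: the first input is informative, the second is nearly irrelevant.
X = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
y = [2.0, 0.1, -2.0, -0.1]
print(elastic_net_cd(X, y, lam=0.0))  # no penalty: ordinary least squares
print(elastic_net_cd(X, y, lam=0.1))  # L1 penalty: weak coefficient is set to zero
```

Updating the residual incrementally, rather than recomputing it for every coefficient, is one of the main reasons this style of solver is fast, so it is a natural place to look when profiling.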
- AdapEnetClass::WEnetCC.aft (arXiv paper) fits two different models, both with AFT weighted square loss and elastic net regularization.
- glmnet fits models for elastic net regularization with several loss functions, but neither AFT nor interval regression losses are supported.
- interval::icfit and survival::survreg provide solvers for non-regularized interval regression models.
- The PeakSegDP package contains a solver which uses the FISTA algorithm to fit an L1 regularized model for general interval output data. However, there are two issues: (1) it is not as fast as the coordinate descent algorithm implemented in glmnet, and (2) it does not support L2 regularization.
function/pkg | censoring | regularization | loss | algorithm |
---|---|---|---|---|
glmnet | none, right | L1 + L2 | Cox | coordinate descent |
glmnet | none | L1 + L2 | normal, logistic | coordinate descent |
AdapEnetClass | none, right | L1 + L2 | normal | LARS |
coxph | none, right, left, interval | none | Cox | ? |
survreg | none, right, left, interval | none | normal, logistic, Weibull | Newton-Raphson |
PeakSegDP | left, right, interval | L1 | squared hinge, log | FISTA |
iregnet | none, right, left, interval | L1 + L2 | normal, logistic, Weibull | coordinate descent |
The main goal of this GSOC project is to optimize the speed of the iregnet package, so that it is as fast as the glmnet package. The project should start by setting up Rperform for iregnet, then profiling the iregnet C++ code to find the slow parts, then re-coding those parts. If time permits, it would be nice to have a vignette or a blog post with speed comparisons (before and after optimization, glmnet and iregnet).
The iregnet package is already useful for making predictions in data sets with possibly censored observations. After this GSOC project, its model fitting code will be even faster.
- Toby Dylan Hocking <tdhock5@gmail.com> proposed this project and can mentor.
- Anuj Khare <khareanuj18@gmail.com> coded iregnet in GSOC2016 and can mentor.
Students, please complete as many tests as possible before emailing the mentors. If we do not find a student who can complete the Hard test, then we should not approve this GSOC project.
- Easy: perform a side-by-side comparison of iregnet and glmnet for a lasso problem with no censored data. Consider the `prostate` cancer data set, which has no censored data. Use the `microbenchmark` package to time the `iregnet` and `glmnet` functions. Do the two functions return the same result? Which is faster? Make plots of time versus data set size (one plot for rows and one plot for columns). This kind of plot makes it very easy to see the differences in timings.
- Medium: set up Rperform for iregnet. @anujkhare has already implemented some optimizations in the branch called `optimize`. Use `Rperform::plot_branchmetrics` to compare timings for the commits on the `optimize` branch against the `master` branch. Does the `optimize` branch result in speed improvements? For which tests?
- Hard: demonstrate that you know how to do code profiling of C++ code. Run a code profiler on the iregnet R package and explain which parts of the code are taking the most time.
Students, please post a link to your test results here.