Regularized interval regression
Interval regression is a class of machine learning models which is useful when predicted values should be real numbers, but outputs in the training data set may be only partially observed. A common example is survival analysis, in which the outputs are patient survival times, some of which are censored (only partially observed).
For example, say that Alice and Bob came into the hospital and were treated for cancer on the same day in 2000. Now we are in 2016 and we would like to study the treatment efficacy. Say Alice died in 2010, and Bob is still alive. The survival time for Alice is 10 years, and although we do not know Bob’s survival time, we know it is in the interval (16, Infinity).
Say that we also measured some covariates (input variables) for Alice and Bob (age, sex, gene expression). We can fit an Accelerated Failure Time (AFT) model, which takes those input variables and outputs a predicted survival time. L1 regularized AFT models are of interest when there are many input variables and we would like the model to automatically ignore those which are uninformative (do not help predict survival time). Several papers describe L1 regularized AFT models:
- L1 regularization for an AFT model with a weighted square loss function (Huang et al 2005); jian-huang@uiowa.edu writes: “The AFT model is really just a linear regression model, except that the response is a transformation of the survival time (usually a log transform) and the response is usually subject to censoring. Given the values of the predictors, the output is just (transformed) survival time. So it can be used for predicting survival times.”
- L1 regularization for an AFT model with a pairwise loss function (Cai et al 2011); tcai.hsph@gmail.com writes: “we only provide regression coefficients and do not provide actual predicted survival although it could be derived from the model by first estimating the residual distribution as in the standard AFT model.”
Interval regression (or interval censoring) is a generalization in which any kind of interval is an acceptable output in the training data. Any real-valued or positive-valued probability distribution may be used to model the outputs (e.g. normal or logistic if output is real-valued, log-normal or log-logistic if output is positive-valued like a survival time). For more details read this 1-page explanation of unregularized parametric AFT models.
output | interval | likelihood | censoring |
---|---|---|---|
exactly 10 | (10, 10) | density function: f(10) | none |
at least 16 | (16, Infinity) | cumulative distribution function: 1 - F(16) | right |
at most 3 | (-Infinity, 3) | cumulative distribution function: F(3) | left |
between -4 and 5 | (-4, 5) | cumulative distribution function: F(5) - F(-4) | interval |
Another application of interval regression is in learning penalty functions for detecting change-points and peaks in genomic data (data viz).
- AdapEnetClass::WEnetCC.aft (arXiv paper) fits two different models, both with AFT weighted square loss and elastic net regularization.
- glmnet fits models for elastic net regularization with several loss functions, but neither AFT nor interval regression losses are supported.
- interval::icfit and survival::survreg provide solvers for non-regularized interval regression models.
- The PeakSegDP package contains a solver which uses the FISTA algorithm to fit an L1 regularized model for general interval output data. However, there are two issues: (1) it is not as fast as the coordinate descent algorithm implemented in glmnet, and (2) it does not support L2 regularization.
Implement the first R package to support
- general interval output data (including left and interval censoring; not just observed and right-censored data typical of survival analysis),
- elastic net (L1 + L2) regularization, and
- a fast glmnet-like coordinate descent solver.
function/pkg | censoring | regularization | loss | algorithm |
---|---|---|---|---|
glmnet | none, right | L1 + L2 | Cox | coordinate descent |
glmnet | none | L1 + L2 | normal, logistic regression (binomial) | coordinate descent |
AdapEnetClass | none, right | L1 + L2 | normal | LARS |
coxph | none, right | none | Cox | Newton-Raphson |
survreg | none, right, left, interval | none | normal, logistic, Weibull | Newton-Raphson |
PeakSegDP | left, right, interval | L1 | squared hinge, log | FISTA |
THIS PROJECT | none, right, left, interval | L1 + L2 | normal, logistic, Weibull | coordinate descent |
There are two possible coding strategies:
- Fork the solver from the glmnet source code and adapt it to work with the interval regression loss. This should be possible if you understand its FORTRAN code.
- Read Simon et al (JSS) and implement a coordinate descent solver from scratch in C code. (This coding strategy is preferred.)
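Both strategies share one building block. In the glmnet and coxnet papers, each outer iteration replaces the loss by a quadratic approximation, leaving a weighted least-squares elastic net problem that cyclic coordinate descent solves with a closed-form soft-thresholding update per coefficient. The following is a minimal C sketch of that inner solver only, not the glmnet source. Assumptions: no intercept, no standardization, no active sets or warm starts, and hypothetical function names. For interval regression the weights w and working responses z would come from the quadratic approximation of the interval loss.

```c
#include <math.h>
#include <stddef.h>

/* Soft-thresholding operator S(v, g) = sign(v) * max(|v| - g, 0). */
static double soft_threshold(double v, double g) {
    if (v > g) return v - g;
    if (v < -g) return v + g;
    return 0.0;
}

/* Cyclic coordinate descent for
     (1 / 2n) sum_i w_i (z_i - x_i' beta)^2
       + lambda * (alpha * ||beta||_1 + (1 - alpha) / 2 * ||beta||_2^2).
   X is n-by-p, column-major (X[i + j * n]). beta holds starting values and
   is updated in place; r (length n) is workspace for residuals z - X beta.
   Returns the number of full passes over the coefficients. */
int enet_coordinate_descent(size_t n, size_t p, const double *X,
                            const double *z, const double *w,
                            double lambda, double alpha, double *beta,
                            double *r, int max_iter, double tol) {
    for (size_t i = 0; i < n; i++) {
        r[i] = z[i];
        for (size_t j = 0; j < p; j++) r[i] -= X[i + j * n] * beta[j];
    }
    int pass = 0;
    while (pass < max_iter) {
        pass++;
        double max_change = 0.0;
        for (size_t j = 0; j < p; j++) {
            const double *xj = X + j * n;
            double num = 0.0, den = 0.0;
            for (size_t i = 0; i < n; i++) {
                /* r_i + x_ij * beta_j is the partial residual without feature j. */
                num += w[i] * xj[i] * (r[i] + xj[i] * beta[j]);
                den += w[i] * xj[i] * xj[i];
            }
            num /= (double)n;
            den = den / (double)n + lambda * (1.0 - alpha);
            double new_beta = (den > 0.0)
                ? soft_threshold(num, lambda * alpha) / den : 0.0;
            double delta = new_beta - beta[j];
            if (delta != 0.0) {
                for (size_t i = 0; i < n; i++) r[i] -= xj[i] * delta;
                beta[j] = new_beta;
            }
            if (fabs(delta) > max_change) max_change = fabs(delta);
        }
        if (max_change < tol) break;
    }
    return pass;
}
```

Keeping the full residual vector r up to date makes each coordinate update cost O(n), which is the main reason coordinate descent is fast for this problem.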
Project goals for the end of summer:
- Implement the Log-Normal loss for general interval outputs, and at least one other of the following AFT loss functions: Log-Logistic, exponential, Weibull.
- Package tests to make sure the global optimum is found in real data sets such as AdapEnetClass::MCLcleaned.
- Interface similar to glmnet. Outputs y should be a two-column matrix (first column: real-valued lower limit, possibly negative or -Inf; second column: real-valued upper limit, possibly Inf).

```r
iregnet(X, y,
        family=c("weibull", "exponential", "lognormal", "loglogistic"),
        alpha=1)
```
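Two small pieces of that interface can be sketched in C, assuming the two columns of y arrive as separate arrays (both function names are hypothetical). The check rejects NaN, rows with lower > upper, a lower limit of +Inf, or an upper limit of -Inf. The penalty follows the glmnet parameterization, in which alpha = 1 gives the pure L1 (lasso) penalty and alpha = 0 gives ridge.

```c
#include <math.h>
#include <stddef.h>

/* Hypothetical validity check for the two columns of y.
   Returns the index of the first invalid row, or -1 if all rows are valid. */
long first_invalid_row(const double *lower, const double *upper, size_t n) {
    for (size_t i = 0; i < n; i++) {
        double lo = lower[i], hi = upper[i];
        if (isnan(lo) || isnan(hi) || lo > hi ||
            (isinf(lo) && lo > 0) || (isinf(hi) && hi < 0)) {
            return (long)i;
        }
    }
    return -1;
}

/* Elastic net penalty, glmnet parameterization:
   lambda * (alpha * ||beta||_1 + (1 - alpha) / 2 * ||beta||_2^2). */
double elastic_net_penalty(const double *beta, size_t p,
                           double lambda, double alpha) {
    double l1 = 0.0, l2 = 0.0;
    for (size_t j = 0; j < p; j++) {
        l1 += fabs(beta[j]);
        l2 += beta[j] * beta[j];
    }
    return lambda * (alpha * l1 + 0.5 * (1.0 - alpha) * l2);
}
```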
Would be nice, but not necessary before the end of summer:
- Vignette which compares computation time and solution accuracy with glmnet, survreg, icfit, and/or AdapEnetClass::WEnetCC.aft.
- Interface which accepts survival::Surv objects as outputs y, for compatibility with survreg.
- Optimizations for sparse input matrices (Matrix package).
The iregnet package will be useful for making accurate predictions in data sets with possibly censored observations. As explained above, there is currently no other R package with support for (1) four types of censoring, (2) elastic net regularization, and (3) a fast coordinate descent algorithm like glmnet. This package will thus provide a fast algorithm for fitting a class of models that is not yet possible in R.
- Toby Dylan Hocking <tdhock5@gmail.com> proposed this project, would be a user of this package, and could mentor.
- Jelle Goeman <j.j.goeman@lumc.nl> maintains the penalized package and can co-mentor.
Please do not bother the following people, who have already said they are too busy to co-mentor this project.
- Noah Simon <nrsimon@u.washington.edu> implemented the elastic net for the Cox model, and said he could help out informally, but he can NOT commit to formal co-mentoring.
- Trevor Hastie <hastie@stanford.edu> maintains glmnet but said that he is too busy to mentor.
- Terry M Therneau <therneau.terry@mayo.edu> maintains survival and said he is too busy to mentor, but he gave us the following advice:
For creating code like this, I would recommend using the overall structure of the survreg routine for the “normal” part of the likelihood. This allows for a large number of distributions. By this I mean use of the “survreg.distributions” object to lay out the likelihood. The math framework follows chapter 2 of Kalbfleisch and Prentice, “The statistical analysis of failure time data”. Sections 6.7-6.9 of the survival package manual contain the gradient computations and may be useful.
You are free to make use of any portions or code from my package, without license or attribution. In particular use the front part that handles the formula. There is no good reason to write a new R routine which forces users to pre-create their own X matrix, unless your goal is to have very few users. The central portion of the maximization will need to be completely different than mine, of course, due to the penalty.
For the gbm package, which I’m currently advising wrt survival, the first derivative of the loglik can be written as X’m where “m” is a residual, and is the partial of the loglik with respect to eta (see page 81 of the survival package manual). For ordinary linear regression m is the ordinary residual, for a Cox model it is the martingale residual, etc. The maximizer only needs m and the loglik in order to adapt it to a new distribution. I suspect that same might be true for coordinatewise descent, or something similar.
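The advice above can be made concrete: the gradient of the log-likelihood with respect to the coefficients is X'm, where m_i is the partial derivative of observation i's log-likelihood with respect to its linear predictor eta_i. A minimal C sketch with hypothetical names follows. Only the function computing m would change for a new distribution; the residuals shown are for uncensored outputs under a normal model, where m_i = (y_i - eta_i) / sigma^2 reduces to the ordinary residual when sigma = 1.

```c
#include <stddef.h>

/* grad[j] = sum_i X[i][j] * m[i], i.e. grad = X' m.
   X is n-by-p, column-major (X[i + j * n]). */
void loglik_gradient(size_t n, size_t p, const double *X, const double *m,
                     double *grad) {
    for (size_t j = 0; j < p; j++) {
        grad[j] = 0.0;
        for (size_t i = 0; i < n; i++) grad[j] += X[i + j * n] * m[i];
    }
}

/* Residuals for uncensored outputs under a normal model with scale sigma:
   d/d eta of log f((y - eta) / sigma) equals (y - eta) / sigma^2. */
void normal_residuals(size_t n, const double *y, const double *eta,
                      double sigma, double *m) {
    for (size_t i = 0; i < n; i++) {
        m[i] = (y[i] - eta[i]) / (sigma * sigma);
    }
}
```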
Students, please complete as many tests as possible before emailing the mentors. If we do not find a student who can complete the Hard test, then we should not approve this GSOC project.
- Easy: write a knitr document in which you perform cross-validation to compare the predictions of at least 2 models on at least 2 data sets. Fit each model to each train set, and compute the prediction error of each model with respect to the test set. Which model is most accurate?
- Models: survival::survreg, AdapEnetClass::cv.AWEnet, AdapEnetClass::cv.AWEnetCC.
- Data sets: survival::lung, survival::ovarian, utils::data(MCLcleaned, package="AdapEnetClass").
- Test error: the mean squared prediction error cannot be computed since some observations are censored. Instead, compute a zero-one loss:
- For censored observations, e.g. patient lived at least 10 years, count 1 test error if the prediction is less than 10 years.
- For un-censored observations, e.g. patient lived exactly 6 years, count 1 test error if the prediction is off by a factor of two (less than 3 years or more than 12 years).
- Medium: show that you know how to include FORTRAN/C code in an R package.
- Hard: write down the mathematical optimization problem for elastic net regularized interval regression using the loss function which corresponds to a log-logistic AFT model. Output data in the train set can be any of the four censoring types described above (none, left, right, interval). Write the subdifferential optimality condition for this optimization problem. Using arguments similar to those in the glmnet/coxnet papers, derive the coordinate descent update rule and a stopping criterion. Hint: for the coordinate descent update rule, take a look at
- glmnet paper, Friedman et al JSS, section 3,
- coxnet paper, Simon et al JSS, section 2.2
The papers write a quadratic approximation of the loss (\ell_Q in the papers) because that sub-problem can be solved exactly in closed form (that is the coordinate descent update). Hint: something about the log-logistic distribution should appear in the weights.
Students, please post a link to your test results here.