Skip to content

Commit

Permalink
Merge pull request #504 from sebp/sklearn-1-6
Browse files Browse the repository at this point in the history
Add support for scikit-learn 1.6
  • Loading branch information
sebp authored Jan 12, 2025
2 parents 5ab06da + 83618d4 commit adf382d
Show file tree
Hide file tree
Showing 37 changed files with 217 additions and 167 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ repos:
hooks:
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
rev: v0.8.4
hooks:
- id: ruff
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Requirements
- numpy
- osqp
- pandas 1.4.0 or later
- scikit-learn 1.4 or 1.5
- scikit-learn 1.6
- scipy
- C/C++ compiler

Expand Down
2 changes: 1 addition & 1 deletion ci/appveyor/py310.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.10.*"
$env:CI_PANDAS_VERSION="1.5.*"
$env:CI_NUMPY_VERSION="1.25.*"
$env:CI_SKLEARN_VERSION="1.4.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/appveyor/py311.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.11.*"
$env:CI_PANDAS_VERSION="2.0.*"
$env:CI_NUMPY_VERSION="1.26.*"
$env:CI_SKLEARN_VERSION="1.5.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/appveyor/py312.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.12.*"
$env:CI_PANDAS_VERSION="2.2.*"
$env:CI_NUMPY_VERSION="2.0.*"
$env:CI_SKLEARN_VERSION="1.5.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/appveyor/py313.ps1
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
$env:CI_PYTHON_VERSION="3.13.*"
$env:CI_PANDAS_VERSION="2.2.*"
$env:CI_NUMPY_VERSION="2.1.*"
$env:CI_SKLEARN_VERSION="1.5.*"
$env:CI_SKLEARN_VERSION="1.6.*"
2 changes: 1 addition & 1 deletion ci/deps/py310.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.10.*'
export CI_PANDAS_VERSION='1.5.*'
export CI_NUMPY_VERSION='1.25.*'
export CI_SKLEARN_VERSION='1.4.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=false
2 changes: 1 addition & 1 deletion ci/deps/py311.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.11.*'
export CI_PANDAS_VERSION='2.0.*'
export CI_NUMPY_VERSION='1.26.*'
export CI_SKLEARN_VERSION='1.5.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=true
2 changes: 1 addition & 1 deletion ci/deps/py312.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.12.*'
export CI_PANDAS_VERSION='2.2.*'
export CI_NUMPY_VERSION='2.0.*'
export CI_SKLEARN_VERSION='1.5.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=true
2 changes: 1 addition & 1 deletion ci/deps/py313.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
export CI_PYTHON_VERSION='3.13.*'
export CI_PANDAS_VERSION='2.2.*'
export CI_NUMPY_VERSION='2.1.*'
export CI_SKLEARN_VERSION='1.5.*'
export CI_SKLEARN_VERSION='1.6.*'
export CI_NO_SLOW=false
2 changes: 1 addition & 1 deletion doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@
}

intersphinx_mapping = {
"sklearn": ("https://scikit-learn.org/1.5", None),
"sklearn": ("https://scikit-learn.org/1.6", None),
"cython": ("https://cython.readthedocs.io/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
Expand Down
2 changes: 1 addition & 1 deletion doc/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,6 @@ The current minimum dependencies to run scikit-survival are:
- numpy
- osqp
- pandas 1.4.0 or later
- scikit-learn 1.4 or 1.5
- scikit-learn 1.6
- scipy
- C/C++ compiler
19 changes: 13 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ requires = [
"numpy>=2.0.0",

# scikit-learn requirements
"scikit-learn~=1.4.0; python_version<='3.12'",
"scikit-learn~=1.5.0; python_version=='3.13'",
"scikit-learn~=1.6.1; python_version<='3.13'",
"scikit-learn; python_version>'3.13'",
]
build-backend = "setuptools.build_meta"
Expand Down Expand Up @@ -51,7 +50,7 @@ dependencies = [
"osqp !=0.6.0,!=0.6.1",
"pandas >=1.4.0",
"scipy >=1.3.2",
"scikit-learn >=1.4.0,<1.6",
"scikit-learn >=1.6.1,<1.7",
]
dynamic = ["version"]

Expand Down Expand Up @@ -188,13 +187,21 @@ target-version = "py310"
ignore = ["C408"]
ignore-init-module-imports = true
select = [
"C4",
"C9",
# pycodestyle
"E",
"W",
# mccabe
"C90",
# pyflakes
"F",
# isort
"I",
# flake8-builtins
"A",
# flake8-comprehensions
"C4",
# flake8-pytest-style
"PT",
"W",
]

[tool.ruff.lint.flake8-pytest-style]
Expand Down
2 changes: 1 addition & 1 deletion sksurv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def show_versions():
max(map(len, deps)),
max(map(len, sys_info.keys())),
)
fmt = "{0:<%ds}: {1}" % minwidth
fmt = f"{{0:<{minwidth}s}}: {{1}}"

print("SYSTEM")
print("------")
Expand Down
6 changes: 4 additions & 2 deletions sksurv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,7 @@ def score(self, X, y):
result = concordance_index_censored(y[name_event], y[name_time], risk_score)
return result[0]

def _more_tags(self):
return {"requires_y": True}
def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.target_tags.required = True
return tags
4 changes: 4 additions & 0 deletions sksurv/bintrees/_binarytrees.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
cimport cython
from libcpp cimport bool
from libcpp.cast cimport dynamic_cast

Expand Down Expand Up @@ -76,18 +77,21 @@ cdef class BaseTree:
return self.count_larger(key)


@cython.final
cdef class RBTree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
raise ValueError('size must be greater zero')
self.treeptr = new rbtree(size)

@cython.final
cdef class AVLTree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
raise ValueError('size must be greater zero')
self.treeptr = dynamic_cast[rbtree_ptr](new avl(size))

@cython.final
cdef class AATree(BaseTree):
def __cinit__(self, int size):
if size <= 0:
Expand Down
30 changes: 18 additions & 12 deletions sksurv/ensemble/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree._tree import DTYPE
from sklearn.utils import check_random_state
from sklearn.utils._param_validation import Interval, StrOptions
from sklearn.utils.extmath import squared_norm
from sklearn.utils.validation import _check_sample_weight, check_array, check_is_fitted
from sklearn.utils.validation import (
_check_sample_weight,
check_array,
check_is_fitted,
check_random_state,
validate_data,
)

from ..base import SurvivalAnalysisMixin
from ..linear_model.coxph import BreslowEstimator
Expand Down Expand Up @@ -389,7 +394,7 @@ def fit(self, X, y, sample_weight=None):
if not self.warm_start:
self._clear_state()

X = self._validate_data(X, ensure_min_samples=2)
X = validate_data(self, X, ensure_min_samples=2)
event, time = check_array_survival(X, y)

sample_weight = _check_sample_weight(sample_weight, X)
Expand All @@ -398,7 +403,7 @@ def fit(self, X, y, sample_weight=None):
Xi = np.column_stack((np.ones(n_samples), X))

self._loss = LOSS_FUNCTIONS[self.loss]()
if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)

if not self._is_fitted():
Expand Down Expand Up @@ -470,7 +475,7 @@ def predict(self, X):
Predicted risk scores.
"""
check_is_fitted(self, "estimators_")
X = self._validate_data(X, reset=False)
X = validate_data(self, X, reset=False)

return self._predict(X)

Expand Down Expand Up @@ -957,7 +962,7 @@ def _set_max_features(self):
max_features = max(1, int(np.log2(self.n_features_in_)))
elif self.max_features is None:
max_features = self.n_features_in_
elif isinstance(self.max_features, (numbers.Integral, np.integer)):
elif isinstance(self.max_features, numbers.Integral):
max_features = self.max_features
else: # float
max_features = max(1, int(self.max_features * self.n_features_in_))
Expand Down Expand Up @@ -1234,7 +1239,8 @@ def fit(self, X, y, sample_weight=None, monitor=None):
if not self.warm_start:
self._clear_state()

X = self._validate_data(
X = validate_data(
self,
X,
ensure_min_samples=2,
order="C",
Expand All @@ -1256,7 +1262,7 @@ def fit(self, X, y, sample_weight=None, monitor=None):
# self.loss is guaranteed to be a string
self._loss = self._get_loss(sample_weight=sample_weight)

if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
if isinstance(self._loss, CensoredSquaredLoss | IPCWLeastSquaresError):
time = np.log(time)

if self.n_iter_no_change is not None:
Expand Down Expand Up @@ -1315,13 +1321,13 @@ def fit(self, X, y, sample_weight=None, monitor=None):
begin_at_stage = self.estimators_.shape[0]
# The requirements of _raw_predict
# are more constrained than fit. It accepts only CSR
# matrices. Finite values have already been checked in _validate_data.
# matrices. Finite values have already been checked in validate_data.
X_train = check_array(
X_train,
dtype=DTYPE,
order="C",
accept_sparse="csr",
force_all_finite=False,
ensure_all_finite=False,
)
raw_predictions = self._raw_predict(X_train)
self._resize_state()
Expand Down Expand Up @@ -1390,7 +1396,7 @@ def _dropout_raw_predict(self, X):
return raw_predictions

def _dropout_staged_raw_predict(self, X):
X = self._validate_data(X, dtype=DTYPE, order="C", accept_sparse="csr")
X = validate_data(self, X, dtype=DTYPE, order="C", accept_sparse="csr")
raw_predictions = self._raw_predict_init(X)

n_estimators, K = self.estimators_.shape
Expand Down Expand Up @@ -1438,7 +1444,7 @@ def predict(self, X):
"""
check_is_fitted(self, "estimators_")

X = self._validate_data(X, reset=False, order="C", accept_sparse="csr", dtype=DTYPE)
X = validate_data(self, X, reset=False, order="C", accept_sparse="csr", dtype=DTYPE)
return self._predict(X)

def staged_predict(self, X):
Expand Down
38 changes: 25 additions & 13 deletions sksurv/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
_parallel_build_trees,
)
from sklearn.tree._tree import DTYPE
from sklearn.utils._tags import _safe_tags
from sklearn.utils.validation import check_is_fitted, check_random_state
from sklearn.utils._tags import get_tags
from sklearn.utils.validation import check_is_fitted, check_random_state, validate_data

from ..base import SurvivalAnalysisMixin
from ..metrics import concordance_index_censored
Expand All @@ -29,18 +29,20 @@
MAX_INT = np.iinfo(np.int32).max


def _more_tags_patch(self):
# BaseForest._more_tags calls
def _sklearn_tags_patch(self):
# BaseForest.__sklearn_tags__ calls
# type(self.estimator)(criterion=self.criterions),
# which is incompatible with LogrankCriterion
if isinstance(self, _BaseSurvivalForest):
estimator = type(self.estimator)()
else:
estimator = type(self.estimator)(criterion=self.criterion)
return {"allow_nan": _safe_tags(estimator, key="allow_nan")}
tags = super(BaseForest, self).__sklearn_tags__()
tags.input_tags.allow_nan = get_tags(estimator).input_tags.allow_nan
return tags


BaseForest._more_tags = _more_tags_patch
BaseForest.__sklearn_tags__ = _sklearn_tags_patch


class _BaseSurvivalForest(BaseForest, metaclass=ABCMeta):
Expand Down Expand Up @@ -104,7 +106,7 @@ def fit(self, X, y, sample_weight=None):
"""
self._validate_params()

X = self._validate_data(X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, force_all_finite=False)
X = validate_data(self, X, dtype=DTYPE, accept_sparse="csc", ensure_min_samples=2, ensure_all_finite=False)
event, time = check_array_survival(X, y)

# _compute_missing_values_in_feature_mask checks if X has missing values and
Expand All @@ -115,7 +117,7 @@ def fit(self, X, y, sample_weight=None):
X, estimator_name=self.__class__.__name__
)

self.n_features_in_ = X.shape[1]
self._n_samples, self.n_features_in_ = X.shape
time = time.astype(np.float64)
self.unique_times_, self.is_event_time_ = get_unique_times(time, event)
self.n_outputs_ = self.unique_times_.shape[0]
Expand All @@ -125,7 +127,18 @@ def fit(self, X, y, sample_weight=None):
y_numeric[:, 1] = event.astype(np.float64)

# Get bootstrap sample size
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
if not self.bootstrap and self.max_samples is not None: # pylint: disable=no-else-raise
raise ValueError(
"`max_sample` cannot be set if `bootstrap=False`. "
"Either switch to `bootstrap=True` or set "
"`max_sample=None`."
)
elif self.bootstrap:
n_samples_bootstrap = _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples)
else:
n_samples_bootstrap = None

self._n_samples_bootstrap = n_samples_bootstrap

# Check parameters
self._validate_estimator()
Expand All @@ -141,13 +154,12 @@ def fit(self, X, y, sample_weight=None):

n_more_estimators = self.n_estimators - len(self.estimators_)

if n_more_estimators < 0:
if n_more_estimators < 0: # pylint: disable=no-else-raise
raise ValueError(
f"n_estimators={self.n_estimators} must be larger or equal to "
f"len(estimators_)={len(self.estimators_)} when warm_start==True"
)

if n_more_estimators == 0:
elif n_more_estimators == 0:
warnings.warn("Warm-start fitting without increasing n_estimators does not fit new trees.", stacklevel=2)
else:
if self.warm_start and len(self.estimators_) > 0:
Expand Down Expand Up @@ -442,7 +454,7 @@ class RandomSurvivalForest(SurvivalAnalysisMixin, _BaseSurvivalForest):
`min_impurity_decrease` or `min_impurity_split` are absent.
In addition, the `feature_importances_` attribute is not available.
It is recommended to estimate feature importances via
`permutation-based methods <https://eli5.readthedocs.io>`_.
:func:`sklearn.inspection.permutation_importance`.
The features are always randomly permuted at each split. Therefore,
the best found split may vary, even with the same training data,
Expand Down
2 changes: 1 addition & 1 deletion sksurv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
from sklearn.utils import check_consistent_length
from sklearn.utils.validation import check_consistent_length

__all__ = ["StepFunction"]

Expand Down
Loading

0 comments on commit adf382d

Please sign in to comment.