diff --git a/src/glum/_glm.py b/src/glum/_glm.py index 0d962dd1..c3f6e134 100644 --- a/src/glum/_glm.py +++ b/src/glum/_glm.py @@ -27,7 +27,6 @@ import scipy.sparse as sps import scipy.sparse.linalg as splinalg import sklearn as skl -import sklearn.utils.validation import tabmat as tm from formulaic import Formula, FormulaSpec from formulaic.parser import DefaultFormulaParser @@ -75,18 +74,14 @@ if version.parse(skl.__version__).release < (1, 6): keyword_finiteness = "force_all_finite" -else: - keyword_finiteness = "ensure_all_finite" - -if hasattr(sklearn.utils.validation, "validate_data"): - validate_data = sklearn.utils.validation.validate_data -else: + _check_n_features = BaseEstimator._check_n_features validate_data = BaseEstimator._validate_data - -if hasattr(sklearn.utils.validation, "_check_n_features"): - _check_n_features = sklearn.utils.validation._check_n_features else: - _check_n_features = BaseEstimator._check_n_features + keyword_finiteness = "ensure_all_finite" + from sklearn.utils.validation import ( # type: ignore + _check_n_features, + validate_data, + ) _float_itemsize_to_dtype = {8: np.float64, 4: np.float32, 2: np.float16}