From 9dc2d19cfe880d6baf9aba3929a8aabae845c53b Mon Sep 17 00:00:00 2001 From: Saulo Martiello Mastelini Date: Fri, 31 May 2024 10:54:07 -0300 Subject: [PATCH] format --- river/model_selection/sspt.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/river/model_selection/sspt.py b/river/model_selection/sspt.py index e164cb1168..c800c38beb 100644 --- a/river/model_selection/sspt.py +++ b/river/model_selection/sspt.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import collections import copy import math @@ -118,7 +120,7 @@ def __init__( self, estimator: base.Estimator, metric: metrics.base.Metric, - params_range: typing.Dict[str, typing.Tuple], + params_range: dict[str, tuple], drift_input: typing.Callable[[float, float], float], grace_period: int = 500, drift_detector: base.DriftDetector = drift.ADWIN(), @@ -145,7 +147,7 @@ def __init__( self._simplex = self._create_simplex(estimator) # Models expanded from the simplex - self._expanded: typing.Optional[typing.Dict] = None + self._expanded: dict | None = None # Convergence criterion self._old_centroid = None @@ -155,7 +157,7 @@ def __init__( if isinstance(border, compose.Pipeline): border = border[-1] - if isinstance(border, (base.Classifier, base.Regressor)): + if isinstance(border, base.Classifier | base.Regressor): self._scorer_name = "predict_one" elif isinstance(border, anomaly.base.AnomalyDetector): self._scorer_name = "score_one" @@ -189,7 +191,7 @@ def __flatten(self, prefix, scaled_hps, hp_data, est_data): def _traverse_hps( self, operation: str, hp_data: dict, est_1, *, func=None, est_2=None, hp_prefix=None, scaled_hps=None - ) -> typing.Optional[typing.Union[dict, numbers.Number]]: + ) -> dict | numbers.Number | None: """Traverse the hyperparameters of the estimator/pipeline and perform an operation. Parameters @@ -291,7 +293,7 @@ def _random_config(self): est_1=self.estimator._get_params() ) - def _create_simplex(self, model) -> typing.List: + def _create_simplex(self, model) -> list: # The simplex is divided in: # * 0: the best model # * 1: the 'good' model @@ -343,7 +345,7 @@ def _gen_new_estimator(self, e1, e2, func): return new - def _nelder_mead_expansion(self) -> typing.Dict: + def _nelder_mead_expansion(self) -> dict: """Create expanded models given the simplex models.""" expanded = {} # Midpoint between 'best' and 'good' @@ -455,7 +457,7 @@ def _models_converged(self) -> bool: ) self._old_centroid = new_centroid ndim = len(scaled_params_b) - r_sphere = max_dist * math.sqrt((ndim / (2 * (ndim + 1)))) + r_sphere = max_dist * math.sqrt(ndim / (2 * (ndim + 1))) if r_sphere < self.convergence_sphere or centroid_distance == 0: return True