Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Saulo Martiello Mastelini committed May 31, 2024
1 parent 828d7f5 commit 9dc2d19
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions river/model_selection/sspt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import collections
import copy
import math
Expand Down Expand Up @@ -118,7 +120,7 @@ def __init__(
self,
estimator: base.Estimator,
metric: metrics.base.Metric,
params_range: typing.Dict[str, typing.Tuple],
params_range: dict[str, tuple],
drift_input: typing.Callable[[float, float], float],
grace_period: int = 500,
drift_detector: base.DriftDetector = drift.ADWIN(),
Expand All @@ -145,7 +147,7 @@ def __init__(
self._simplex = self._create_simplex(estimator)

# Models expanded from the simplex
self._expanded: typing.Optional[typing.Dict] = None
self._expanded: dict | None = None

# Convergence criterion
self._old_centroid = None
Expand All @@ -155,7 +157,7 @@ def __init__(
if isinstance(border, compose.Pipeline):
border = border[-1]

if isinstance(border, (base.Classifier, base.Regressor)):
if isinstance(border, base.Classifier | base.Regressor):
self._scorer_name = "predict_one"
elif isinstance(border, anomaly.base.AnomalyDetector):
self._scorer_name = "score_one"
Expand Down Expand Up @@ -189,7 +191,7 @@ def __flatten(self, prefix, scaled_hps, hp_data, est_data):

def _traverse_hps(
self, operation: str, hp_data: dict, est_1, *, func=None, est_2=None, hp_prefix=None, scaled_hps=None
) -> typing.Optional[typing.Union[dict, numbers.Number]]:
) -> dict | numbers.Number | None:
"""Traverse the hyperparameters of the estimator/pipeline and perform an operation.
Parameters
Expand Down Expand Up @@ -291,7 +293,7 @@ def _random_config(self):
est_1=self.estimator._get_params()
)

def _create_simplex(self, model) -> typing.List:
def _create_simplex(self, model) -> list:
# The simplex is divided in:
# * 0: the best model
# * 1: the 'good' model
Expand Down Expand Up @@ -343,7 +345,7 @@ def _gen_new_estimator(self, e1, e2, func):

return new

def _nelder_mead_expansion(self) -> typing.Dict:
def _nelder_mead_expansion(self) -> dict:
"""Create expanded models given the simplex models."""
expanded = {}
# Midpoint between 'best' and 'good'
Expand Down Expand Up @@ -455,7 +457,7 @@ def _models_converged(self) -> bool:
)
self._old_centroid = new_centroid
ndim = len(scaled_params_b)
r_sphere = max_dist * math.sqrt((ndim / (2 * (ndim + 1))))
r_sphere = max_dist * math.sqrt(ndim / (2 * (ndim + 1)))

if r_sphere < self.convergence_sphere or centroid_distance == 0:
return True
Expand Down

0 comments on commit 9dc2d19

Please sign in to comment.