diff --git a/river/base/drift_detector.py b/river/base/drift_detector.py index c5972728f6..9e241cd362 100644 --- a/river/base/drift_detector.py +++ b/river/base/drift_detector.py @@ -22,12 +22,12 @@ class _BaseDriftDetector(base.Base): def __init__(self): self._drift_detected = False - def _reset(self): + def _reset(self) -> None: """Reset the detector's state.""" self._drift_detected = False @property - def drift_detected(self): + def drift_detected(self) -> bool: """Whether or not a drift is detected following the last update.""" return self._drift_detected @@ -57,7 +57,7 @@ class DriftDetector(_BaseDriftDetector): """A drift detector.""" @abc.abstractmethod - def update(self, x: int | float) -> DriftDetector: + def update(self, x: int | float) -> None: """Update the detector with a single data point. Parameters diff --git a/river/drift/kswin.py b/river/drift/kswin.py index 82547faa52..f2c799b67a 100644 --- a/river/drift/kswin.py +++ b/river/drift/kswin.py @@ -109,7 +109,7 @@ def _reset(self): super()._reset() self.p_value = 0 self.n = 0 - self.window: collections.deque = collections.deque(maxlen=self.window_size) + self.window = collections.deque(maxlen=self.window_size) self._rng = random.Random(self.seed) def update(self, x): diff --git a/river/forest/aggregated_mondrian_forest.py b/river/forest/aggregated_mondrian_forest.py index fc250a2734..12601d0cd5 100644 --- a/river/forest/aggregated_mondrian_forest.py +++ b/river/forest/aggregated_mondrian_forest.py @@ -169,7 +169,7 @@ def __init__( # memory of the classes self._classes: set[base.typing.ClfTarget] = set() - def _initialize_trees(self): + def _initialize_trees(self) -> None: self.data: list[MondrianTreeClassifier] = [] for _ in range(self.n_estimators): tree = MondrianTreeClassifier( @@ -287,7 +287,7 @@ def __init__( self.iteration = 0 - def _initialize_trees(self): + def _initialize_trees(self) -> None: """Initialize the forest.""" self.data: list[MondrianTreeRegressor] = [] diff --git a/river/tree/nodes/hatc_nodes.py b/river/tree/nodes/hatc_nodes.py index bab558192d..e85cbc4334 100644 --- a/river/tree/nodes/hatc_nodes.py +++ b/river/tree/nodes/hatc_nodes.py @@ -2,6 +2,7 @@ import math +from river import base from river import stats as st from river.utils.norm import normalize_values_in_dict from river.utils.random import poisson @@ -133,7 +134,7 @@ class AdaBranchClassifier(DTBranch): Other parameters passed to the split node. """ - def __init__(self, stats, *children, drift_detector, **attributes): + def __init__(self, stats: dict, *children, drift_detector: base.DriftDetector, **attributes): super().__init__(stats, *children, **attributes) self.drift_detector = drift_detector self._alternate_tree = None diff --git a/river/tree/nodes/sgt_nodes.py b/river/tree/nodes/sgt_nodes.py index e5a0542999..b649b1ac51 100644 --- a/river/tree/nodes/sgt_nodes.py +++ b/river/tree/nodes/sgt_nodes.py @@ -31,7 +31,7 @@ class SGTLeaf(Leaf): Parameters passed to the feature quantizers. """ - def __init__(self, prediction=0.0, depth=0, split_params=None): + def __init__(self, prediction: float = 0.0, depth: int = 0, split_params: dict | None = None): super().__init__() self._prediction = prediction self.depth = depth @@ -52,7 +52,7 @@ def reset(self): self._update_stats = GradHessStats() @staticmethod - def is_categorical(idx, x_val, nominal_attributes): + def is_categorical(idx: str, x_val, nominal_attributes: list[str]) -> bool: return not isinstance(x_val, numbers.Number) or idx in nominal_attributes def update(self, x: dict, gh: GradHess, sgt, w: float = 1.0): diff --git a/river/tree/stochastic_gradient_tree.py b/river/tree/stochastic_gradient_tree.py index 4a3821d6ca..a92598872b 100644 --- a/river/tree/stochastic_gradient_tree.py +++ b/river/tree/stochastic_gradient_tree.py @@ -7,7 +7,7 @@ from river import base, tree -from .losses import BinaryCrossEntropyLoss, SquaredErrorLoss +from .losses import BinaryCrossEntropyLoss, Loss, SquaredErrorLoss from .nodes.branch import DTBranch, NominalMultiwayBranch, NumericBinaryBranch from .nodes.sgt_nodes import SGTLeaf from .utils import BranchFactory, GradHessMerit @@ -23,15 +23,15 @@ class StochasticGradientTree(base.Estimator, abc.ABC): def __init__( self, - loss_func, - delta, - grace_period, - init_pred, - max_depth, - lambda_value, - gamma, - nominal_attributes, - feature_quantizer, + loss_func: Loss, + delta: float, + grace_period: int, + init_pred: float, + max_depth: int | None, + lambda_value: float, + gamma: float, + nominal_attributes: list[str] | None, + feature_quantizer: tree.splitter.Quantizer | None, ): # What really defines how a SGT works is its loss function self.loss_func = loss_func @@ -56,7 +56,7 @@ def __init__( self._root: SGTLeaf | DTBranch = SGTLeaf(prediction=self.init_pred) # set used to check whether categorical feature has been already split - self._split_features = set() + self._split_features: set[str] = set() self._n_splits = 0 self._n_node_updates = 0 self._n_observations = 0