Clear out MyPy's warnings #1581

Merged 1 commit on Jul 23, 2024
6 changes: 3 additions & 3 deletions river/base/drift_detector.py
@@ -22,12 +22,12 @@ class _BaseDriftDetector(base.Base):
     def __init__(self):
         self._drift_detected = False

-    def _reset(self):
+    def _reset(self) -> None:
         """Reset the detector's state."""
         self._drift_detected = False

     @property
-    def drift_detected(self):
+    def drift_detected(self) -> bool:
         """Whether or not a drift is detected following the last update."""
         return self._drift_detected

@@ -57,7 +57,7 @@ class DriftDetector(_BaseDriftDetector):
     """A drift detector."""

     @abc.abstractmethod
-    def update(self, x: int | float) -> DriftDetector:
+    def update(self, x: int | float) -> None:
         """Update the detector with a single data point.

         Parameters
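Note on the `update` signature: with the return annotation switched from `DriftDetector` to `None`, callers update a detector for its side effect and then read the `drift_detected` property rather than chaining calls. A minimal usage sketch, assuming river's `drift.ADWIN` detector and an illustrative stream whose mean shifts halfway through:

```python
from river import drift

detector = drift.ADWIN()
stream = [1.0] * 100 + [5.0] * 100  # illustrative: the mean jumps halfway through

for i, x in enumerate(stream):
    detector.update(x)          # updates the detector in place, returns None
    if detector.drift_detected:
        print(f"Drift detected at index {i}")
```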
2 changes: 1 addition & 1 deletion river/drift/kswin.py
@@ -109,7 +109,7 @@ def _reset(self):
         super()._reset()
         self.p_value = 0
         self.n = 0
-        self.window: collections.deque = collections.deque(maxlen=self.window_size)
+        self.window = collections.deque(maxlen=self.window_size)
         self._rng = random.Random(self.seed)

     def update(self, x):
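The dropped inline annotation on `self.window` keeps the attribute's declared type in one place: annotating an instance attribute once, at its first assignment, is enough for mypy, and later assignments can stay bare. A small sketch of the pattern (illustrative only, not river code):

```python
import collections


class SlidingWindow:
    """Illustrative only; mirrors the pattern used in KSWIN."""

    def __init__(self, window_size: int) -> None:
        self.window_size = window_size
        # Declare the attribute's type once, at its first assignment.
        self.window: collections.deque = collections.deque(maxlen=window_size)

    def _reset(self) -> None:
        # A plain re-assignment; mypy already knows the type from __init__.
        self.window = collections.deque(maxlen=self.window_size)
```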
4 changes: 2 additions & 2 deletions river/forest/aggregated_mondrian_forest.py
@@ -169,7 +169,7 @@ def __init__(
         # memory of the classes
         self._classes: set[base.typing.ClfTarget] = set()

-    def _initialize_trees(self):
+    def _initialize_trees(self) -> None:
         self.data: list[MondrianTreeClassifier] = []
         for _ in range(self.n_estimators):
             tree = MondrianTreeClassifier(
@@ -287,7 +287,7 @@ def __init__(

         self.iteration = 0

-    def _initialize_trees(self):
+    def _initialize_trees(self) -> None:
         """Initialize the forest."""

         self.data: list[MondrianTreeRegressor] = []
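The `-> None` annotations on `_initialize_trees` are not cosmetic: by default mypy skips the body of a function that carries no annotations at all, so even a bare return annotation opts the helper into type checking. A minimal illustration, unrelated to river:

```python
def untyped_helper():
    # Skipped by a default mypy run: the function has no annotations.
    return "three" + 3


def typed_helper() -> None:
    # Now checked: mypy flags the str + int operands (and the stray return value).
    return "three" + 3
```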
3 changes: 2 additions & 1 deletion river/tree/nodes/hatc_nodes.py
@@ -2,6 +2,7 @@

 import math

+from river import base
 from river import stats as st
 from river.utils.norm import normalize_values_in_dict
 from river.utils.random import poisson
@@ -133,7 +134,7 @@ class AdaBranchClassifier(DTBranch):
         Other parameters passed to the split node.
     """

-    def __init__(self, stats, *children, drift_detector, **attributes):
+    def __init__(self, stats: dict, *children, drift_detector: base.DriftDetector, **attributes):
         super().__init__(stats, *children, **attributes)
         self.drift_detector = drift_detector
         self._alternate_tree = None
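For context, the newly typed `drift_detector` argument ultimately comes from the public Hoeffding Adaptive Tree API, which accepts any `base.DriftDetector`. A hedged usage sketch, with parameter names assumed to match recent river releases:

```python
from river import drift, tree

# Illustrative configuration; ADWIN is one of several detectors that satisfy
# the `base.DriftDetector` annotation added above.
model = tree.HoeffdingAdaptiveTreeClassifier(
    grace_period=100,
    drift_detector=drift.ADWIN(),
    seed=42,
)
```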
4 changes: 2 additions & 2 deletions river/tree/nodes/sgt_nodes.py
@@ -31,7 +31,7 @@ class SGTLeaf(Leaf):
         Parameters passed to the feature quantizers.
     """

-    def __init__(self, prediction=0.0, depth=0, split_params=None):
+    def __init__(self, prediction: float = 0.0, depth: int = 0, split_params: dict | None = None):
         super().__init__()
         self._prediction = prediction
         self.depth = depth
@@ -52,7 +52,7 @@ def reset(self):
         self._update_stats = GradHessStats()

     @staticmethod
-    def is_categorical(idx, x_val, nominal_attributes):
+    def is_categorical(idx: str, x_val, nominal_attributes: list[str]) -> bool:
         return not isinstance(x_val, numbers.Number) or idx in nominal_attributes

     def update(self, x: dict, gh: GradHess, sgt, w: float = 1.0):
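The typed `is_categorical` helper simply reports whether a feature should be treated as nominal: either its value is non-numeric or its name appears in `nominal_attributes`. A standalone sketch of that logic with illustrative inputs:

```python
import numbers


def is_categorical(idx: str, x_val, nominal_attributes: list[str]) -> bool:
    # Same test as SGTLeaf.is_categorical: non-numeric values, or features
    # explicitly declared nominal, are treated as categorical.
    return not isinstance(x_val, numbers.Number) or idx in nominal_attributes


print(is_categorical("colour", "blue", []))       # True: non-numeric value
print(is_categorical("weekday", 3, ["weekday"]))  # True: declared nominal
print(is_categorical("age", 27, ["weekday"]))     # False: plain numeric feature
```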
22 changes: 11 additions & 11 deletions river/tree/stochastic_gradient_tree.py
@@ -7,7 +7,7 @@

 from river import base, tree

-from .losses import BinaryCrossEntropyLoss, SquaredErrorLoss
+from .losses import BinaryCrossEntropyLoss, Loss, SquaredErrorLoss
 from .nodes.branch import DTBranch, NominalMultiwayBranch, NumericBinaryBranch
 from .nodes.sgt_nodes import SGTLeaf
 from .utils import BranchFactory, GradHessMerit
@@ -23,15 +23,15 @@ class StochasticGradientTree(base.Estimator, abc.ABC):

     def __init__(
         self,
-        loss_func,
-        delta,
-        grace_period,
-        init_pred,
-        max_depth,
-        lambda_value,
-        gamma,
-        nominal_attributes,
-        feature_quantizer,
+        loss_func: Loss,
+        delta: float,
+        grace_period: int,
+        init_pred: float,
+        max_depth: int | None,
+        lambda_value: float,
+        gamma: float,
+        nominal_attributes: list[str] | None,
+        feature_quantizer: tree.splitter.Quantizer | None,
     ):
         # What really defines how a SGT works is its loss function
         self.loss_func = loss_func
@@ -56,7 +56,7 @@ def __init__(
         self._root: SGTLeaf | DTBranch = SGTLeaf(prediction=self.init_pred)

         # set used to check whether categorical feature has been already split
-        self._split_features = set()
+        self._split_features: set[str] = set()
         self._n_splits = 0
         self._n_node_updates = 0
         self._n_observations = 0
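The annotated constructor makes the SGT configuration explicit: a `Loss`, a significance level `delta`, a `grace_period`, an optional `max_depth`, regularisation terms `lambda_value` and `gamma`, optional `nominal_attributes`, and an optional feature quantizer. A hedged usage sketch through the public `SGTClassifier` wrapper (dataset and hyperparameter values are illustrative, not tuned):

```python
from river import datasets, evaluate, metrics, tree

# Illustrative, untuned hyperparameters; they map onto the annotated
# constructor parameters above.
model = tree.SGTClassifier(
    delta=1e-7,
    grace_period=200,
    lambda_value=0.1,
    gamma=1.0,
)

print(evaluate.progressive_val_score(datasets.Phishing(), model, metrics.Accuracy()))
```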