MaxHalford committed Oct 30, 2023
2 parents b72077d + 8a7cae7 commit c4afffc
Showing 5 changed files with 180 additions and 87 deletions.
13 changes: 9 additions & 4 deletions docs/releases/unreleased.md
@@ -8,6 +8,10 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
 - Made `score_one` method of `anomaly.LocalOutlierFactor` stateless
 - Defined default score for uninitialized detector
 
+## covariance
+
+- Added `_from_state` method to `covariance.EmpiricalCovariance` to warm start from previous knowledge.
+
 ## clustering
 
 - Add fixes to `cluster.DBSTREAM` algorithm, including:
@@ -22,6 +26,10 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
 
 - Added `datasets.WebTraffic`, which is a dataset that counts the occurrences of events on a website. It is a multi-output regression dataset with two outputs.
 
+## drift
+
+- Add `drift.NoDrift` to allow disabling the drift detection capabilities of models. This detector does nothing and always returns `False` when queried whether or not a concept drift was detected.
+
 ## evaluate
 
 - Added a `yield_predictions` parameter to `evaluate.iter_progressive_val_score`, which allows including predictions in the output.
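
For context, a minimal sketch of how the new `yield_predictions` flag could be used. Only the parameter name comes from this diff; the dataset, model, and metric are illustrative assumptions:

    import itertools

    from river import datasets, evaluate, linear_model, metrics

    # With yield_predictions=True, each yielded checkpoint should also carry
    # the model's prediction (illustrative; the exact keys are not shown in this diff).
    steps = evaluate.iter_progressive_val_score(
        dataset=datasets.Phishing(),
        model=linear_model.LogisticRegression(),
        metric=metrics.Accuracy(),
        yield_predictions=True,
    )
    for state in itertools.islice(steps, 3):
        print(state)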
@@ -30,17 +38,14 @@ River's mini-batch methods now support pandas v2. In particular, River conforms
 
 - Simplify the inner structures of `forest.ARFClassifier` and `forest.ARFRegressor` by removing redundant class hierarchy. Simplify how concept drift logging can be accessed in individual trees and in the forest as a whole.
 
-## covariance
-
-- Added `_from_state` method to `covariance.EmpiricalCovariance` to warm start from previous knowledge.
-
 ## proba
 
 - Added `_from_state` method to `proba.MultivariateGaussian` to warm start from previous knowledge.
 
 ## tree
 
 - Fix a bug in `tree.splitter.NominalSplitterClassif` that generated a mismatch between the number of existing tree branches and the number of tracked branches.
+- Fix a bug in `tree.ExtremelyFastDecisionTreeClassifier` where the split re-evaluation failed when the current branch's feature was not available as a split option. The fix also enables the tree to pre-prune a leaf via the tie-breaking mechanism.
 
 ## utils
 
2 changes: 2 additions & 0 deletions river/drift/__init__.py
@@ -12,6 +12,7 @@
 from .adwin import ADWIN
 from .dummy import DummyDriftDetector
 from .kswin import KSWIN
+from .no_drift import NoDrift
 from .page_hinkley import PageHinkley
 from .retrain import DriftRetrainingClassifier
 
@@ -22,6 +23,7 @@
     "DriftRetrainingClassifier",
     "DummyDriftDetector",
     "KSWIN",
+    "NoDrift",
     "PageHinkley",
     "PeriodicTrigger",
 ]
76 changes: 76 additions & 0 deletions river/drift/no_drift.py
@@ -0,0 +1,76 @@
+from __future__ import annotations
+
+from river import base
+from river.base.drift_detector import DriftDetector
+
+
+class NoDrift(base.DriftDetector):
+    """Dummy class used to turn off concept drift detection capabilities of adaptive models.
+
+    It always signals that no concept drift was detected.
+
+    Examples
+    --------
+    >>> from river import drift
+    >>> from river import evaluate
+    >>> from river import forest
+    >>> from river import metrics
+    >>> from river.datasets import synth
+
+    >>> dataset = synth.ConceptDriftStream(
+    ...     seed=8,
+    ...     position=500,
+    ...     width=40,
+    ... ).take(700)
+
+    We can turn off the warning detection capabilities of Adaptive Random Forest (ARF) or
+    other similar models. Thus, the base models will reset immediately after identifying a drift,
+    bypassing the background model building phase:
+
+    >>> adaptive_model = forest.ARFClassifier(
+    ...     leaf_prediction="mc",
+    ...     warning_detector=drift.NoDrift(),
+    ...     seed=8
+    ... )
+
+    We can also turn off the concept drift handling capabilities completely:
+
+    >>> stationary_model = forest.ARFClassifier(
+    ...     leaf_prediction="mc",
+    ...     warning_detector=drift.NoDrift(),
+    ...     drift_detector=drift.NoDrift(),
+    ...     seed=8
+    ... )
+
+    Let's put that to test:
+
+    >>> for x, y in dataset:
+    ...     adaptive_model = adaptive_model.learn_one(x, y)
+    ...     stationary_model = stationary_model.learn_one(x, y)
+
+    The adaptive model:
+
+    >>> adaptive_model.n_drifts_detected()
+    2
+
+    >>> adaptive_model.n_warnings_detected()
+    0
+
+    The stationary one:
+
+    >>> stationary_model.n_drifts_detected()
+    0
+
+    >>> stationary_model.n_warnings_detected()
+    0
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def update(self, x: int | float) -> DriftDetector:
+        return self
+
+    @property
+    def drift_detected(self):
+        return False
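
As an illustration (not part of this diff), the detector can also be driven directly through the standard detector API defined above:

    from river import drift

    detector = drift.NoDrift()
    for value in (0.1, 0.5, 2.0, 100.0):
        detector = detector.update(value)  # no-op: update simply returns self
        assert detector.drift_detected is False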
57 changes: 27 additions & 30 deletions river/forest/adaptive_random_forest.py
@@ -10,7 +10,7 @@
 import numpy as np
 
 from river import base, metrics, stats
-from river.drift import ADWIN
+from river.drift import ADWIN, NoDrift
 from river.tree.hoeffding_tree_classifier import HoeffdingTreeClassifier
 from river.tree.hoeffding_tree_regressor import HoeffdingTreeRegressor
 from river.tree.nodes.arf_htc_nodes import (
@@ -32,8 +32,8 @@ def __init__(
         n_models: int,
         max_features: bool | str | int,
         lambda_value: int,
-        drift_detector: base.DriftDetector | None,
-        warning_detector: base.DriftDetector | None,
+        drift_detector: base.DriftDetector,
+        warning_detector: base.DriftDetector,
         metric: metrics.base.MultiClassMetric | metrics.base.RegressionMetric,
         disable_weighted_vote,
         seed,
@@ -50,31 +50,32 @@
 
         self._rng = random.Random(self.seed)
 
-        self._warning_detectors: list[base.DriftDetector] = (
-            None  # type: ignore
-            if self.warning_detector is None
-            else [self.warning_detector.clone() for _ in range(self.n_models)]
-        )
-        self._drift_detectors: list[base.DriftDetector] = (
-            None  # type: ignore
-            if self.drift_detector is None
-            else [self.drift_detector.clone() for _ in range(self.n_models)]
-        )
+        self._warning_detectors: list[base.DriftDetector]
+        self._warning_detection_disabled = True
+        if not isinstance(self.warning_detector, NoDrift):
+            self._warning_detectors = [self.warning_detector.clone() for _ in range(self.n_models)]
+            self._warning_detection_disabled = False
+
+        self._drift_detectors: list[base.DriftDetector]
+        self._drift_detection_disabled = True
+        if not isinstance(self.drift_detector, NoDrift):
+            self._drift_detectors = [self.drift_detector.clone() for _ in range(self.n_models)]
+            self._drift_detection_disabled = False
 
         # The background models
         self._background: list[BaseTreeClassifier | BaseTreeRegressor | None] = (
-            None if self.warning_detector is None else [None] * self.n_models  # type: ignore
+            None if self._warning_detection_disabled else [None] * self.n_models  # type: ignore
         )
 
         # Performance metrics used for weighted voting/aggregation
         self._metrics = [self.metric.clone() for _ in range(self.n_models)]
 
         # Drift and warning logging
         self._warning_tracker: dict = (
-            collections.defaultdict(int) if self.warning_detector is not None else None  # type: ignore
+            collections.defaultdict(int) if not self._warning_detection_disabled else None  # type: ignore
         )
         self._drift_tracker: dict = (
-            collections.defaultdict(int) if self.drift_detector is not None else None  # type: ignore
+            collections.defaultdict(int) if not self._drift_detection_disabled else None  # type: ignore
         )
@property
@@ -101,12 +102,10 @@ def _drift_detector_input(
     def _new_base_model(self) -> BaseTreeClassifier | BaseTreeRegressor:
         raise NotImplementedError
 
-    def n_warnings_detected(self, tree_id: int | None = None) -> int | None:
+    def n_warnings_detected(self, tree_id: int | None = None) -> int:
         """Get the total number of concept drift warnings detected, or the number on an individual
         tree basis (optionally).
 
-        If warning detection is disabled, will return `None`.
-
         Parameters
         ----------
         tree_id
@@ -119,20 +118,18 @@ def n_warnings_detected(self, tree_id: int | None = None) -> int:
         """
 
-        if self.warning_detector is None:
-            return None
+        if self._warning_detection_disabled:
+            return 0
 
         if tree_id is None:
             return sum(self._warning_tracker.values())
 
         return self._warning_tracker[tree_id]
 
-    def n_drifts_detected(self, tree_id: int | None = None) -> int | None:
+    def n_drifts_detected(self, tree_id: int | None = None) -> int:
         """Get the total number of concept drifts detected, or such number on an individual
         tree basis (optionally).
 
-        If drift detection is disabled, will return `None`.
-
         Parameters
         ----------
         tree_id
@@ -145,8 +142,8 @@ def n_drifts_detected(self, tree_id: int | None = None) -> int:
         """
 
-        if self.drift_detector is None:
-            return None
+        if self._drift_detection_disabled:
+            return 0
 
         if tree_id is None:
             return sum(self._drift_tracker.values())
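
Taken together with the constructor changes, disabled detection now reports zero counts instead of None. A minimal check (not part of this diff):

    from river import drift, forest

    stationary = forest.ARFClassifier(
        drift_detector=drift.NoDrift(),
        warning_detector=drift.NoDrift(),
        seed=1,
    )
    assert stationary.n_drifts_detected() == 0
    assert stationary.n_warnings_detected() == 0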
@@ -171,13 +168,13 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
 
             k = poisson(rate=self.lambda_value, rng=self._rng)
             if k > 0:
-                if self.warning_detector is not None and self._background[i] is not None:
+                if not self._warning_detection_disabled and self._background[i] is not None:
                     self._background[i].learn_one(x=x, y=y, sample_weight=k)  # type: ignore
 
                 model.learn_one(x=x, y=y, sample_weight=k)
 
                 drift_input = None
-                if self.drift_detector is not None and self.warning_detector is not None:
+                if not self._warning_detection_disabled:
                     drift_input = self._drift_detector_input(i, y, y_pred)
                     self._warning_detectors[i].update(drift_input)
 
@@ -189,7 +186,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
                         # Update warning tracker
                         self._warning_tracker[i] += 1
 
-            if self.drift_detector is not None:
+                if not self._drift_detection_disabled:
                     drift_input = (
                         drift_input
                         if drift_input is not None
if drift_input is not None
Expand All @@ -198,7 +195,7 @@ def learn_one(self, x: dict, y: base.typing.Target, **kwargs):
self._drift_detectors[i].update(drift_input)

if self._drift_detectors[i].drift_detected:
if self.warning_detector is not None and self._background[i] is not None:
if not self._warning_detection_disabled and self._background[i] is not None:
self.data[i] = self._background[i]
self._background[i] = None
self._warning_detectors[i] = self.warning_detector.clone()
1 file not shown.
