Skip to content

Commit

Permalink
tests + current_merit method
Browse files Browse the repository at this point in the history
  • Loading branch information
danielnowakassis committed Sep 6, 2024
1 parent 313b742 commit fca6254
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 64 deletions.
4 changes: 2 additions & 2 deletions river/stream/iter_arff.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def iter_arff(
x = {
name: cast(val) if cast else val
for name, cast, val in zip(names, casts, r.rstrip().split(","))
if val != "?" and val != ''
if val != "?" and val != ""
}

# Handle target
Expand All @@ -189,7 +189,7 @@ def iter_arff(
y = x.pop(target) if target else None
except KeyError:
y = None

yield x, y

# Close the file if we opened it
Expand Down
2 changes: 1 addition & 1 deletion river/tree/hoeffding_adaptive_tree_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class HoeffdingAdaptiveTreeClassifier(HoeffdingTreeClassifier):
>>> metric = metrics.Accuracy()
>>> evaluate.progressive_val_score(dataset, model, metric)
Accuracy: 91.49%
"""

Expand Down
80 changes: 54 additions & 26 deletions river/tree/last_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@

from .hoeffding_tree_classifier import HoeffdingTreeClassifier
from .nodes.branch import DTBranch
from .nodes.last_nodes import LeafMajorityClassWithDetector, LeafNaiveBayesWithDetector, LeafNaiveBayesAdaptiveWithDetector
from .nodes.last_nodes import (
LeafMajorityClassWithDetector,
LeafNaiveBayesAdaptiveWithDetector,
LeafNaiveBayesWithDetector,
)
from .nodes.leaf import HTLeaf
from .split_criterion import GiniSplitCriterion, HellingerDistanceCriterion, InfoGainSplitCriterion
from .splitter import GaussianSplitter, Splitter
from .splitter import Splitter


class LASTClassifier(HoeffdingTreeClassifier, base.Classifier):
Expand All @@ -28,9 +32,9 @@ class LASTClassifier(HoeffdingTreeClassifier, base.Classifier):
- 'nb' - Naive Bayes</br>
- 'nba' - Naive Bayes Adaptive</br>
change_detector
Change detector that will be created at each leaf of the tree.
Change detector that will be created at each leaf of the tree.
track_error
If True, the change detector will have binary inputs for error predictions,
If True, the change detector will have binary inputs for error predictions,
otherwise the input will be the split criteria.
nb_threshold
Number of instances a leaf should observe before allowing Naive Bayes.
Expand Down Expand Up @@ -97,6 +101,7 @@ class LASTClassifier(HoeffdingTreeClassifier, base.Classifier):
>>> metric = metrics.Accuracy()
>>> evaluate.progressive_val_score(dataset, model, metric)
Accuracy: 92.50%
"""

Expand All @@ -105,8 +110,8 @@ def __init__(
max_depth: int | None = None,
split_criterion: str = "info_gain",
leaf_prediction: str = "nba",
change_detector:base.DriftDetector| None = None,
track_error : bool = True,
change_detector: base.DriftDetector | None = None,
track_error: bool = True,
nb_threshold: int = 0,
nominal_attributes: list | None = None,
splitter: Splitter | None = None,
Expand All @@ -120,23 +125,23 @@ def __init__(
merit_preprune: bool = True,
):
super().__init__(
grace_period=None,
grace_period=1, #no usage
max_depth=max_depth,
split_criterion=split_criterion,
delta=None,
tau=None,
delta=1., #no usage
tau=1, #no usage
leaf_prediction=leaf_prediction,
nb_threshold = nb_threshold,
nb_threshold=nb_threshold,
binary_split=binary_split,
max_size=max_size,
memory_estimate_period=memory_estimate_period,
stop_mem_management=stop_mem_management,
remove_poor_attrs=remove_poor_attrs,
merit_preprune=merit_preprune,
nominal_attributes = nominal_attributes,
splitter = splitter,
min_branch_fraction = min_branch_fraction,
max_share_to_split = max_share_to_split,
nominal_attributes=nominal_attributes,
splitter=splitter,
min_branch_fraction=min_branch_fraction,
max_share_to_split=max_share_to_split,
)
self.change_detector = change_detector if change_detector is not None else drift.ADWIN()
self.track_error = track_error
Expand All @@ -148,7 +153,6 @@ def __init__(
def _mutable_attributes(self):
return {}


def _new_leaf(self, initial_stats=None, parent=None):
if initial_stats is None:
initial_stats = {}
Expand All @@ -159,20 +163,43 @@ def _new_leaf(self, initial_stats=None, parent=None):

if not self.track_error:
if self._leaf_prediction == self._MAJORITY_CLASS:
return LeafMajorityClassWithDetector(initial_stats, depth, self.splitter, self.change_detector.clone())
return LeafMajorityClassWithDetector(
initial_stats, depth, self.splitter, self.change_detector.clone()
)
elif self._leaf_prediction == self._NAIVE_BAYES:
return LeafNaiveBayesWithDetector(initial_stats, depth, self.splitter, self.change_detector.clone())
return LeafNaiveBayesWithDetector(
initial_stats, depth, self.splitter, self.change_detector.clone()
)
else: # Naives Bayes Adaptive (default)
return LeafNaiveBayesAdaptiveWithDetector(initial_stats, depth, self.splitter, self.change_detector.clone())
return LeafNaiveBayesAdaptiveWithDetector(
initial_stats, depth, self.splitter, self.change_detector.clone()
)
else:
split_criterion = self._new_split_criterion()
if self._leaf_prediction == self._MAJORITY_CLASS:
return LeafMajorityClassWithDetector(initial_stats, depth, self.splitter, self.change_detector.clone(), split_criterion)
return LeafMajorityClassWithDetector(
initial_stats,
depth,
self.splitter,
self.change_detector.clone(),
split_criterion,
)
elif self._leaf_prediction == self._NAIVE_BAYES:
return LeafNaiveBayesWithDetector(initial_stats, depth, self.splitter, self.change_detector.clone(), split_criterion)
return LeafNaiveBayesWithDetector(
initial_stats,
depth,
self.splitter,
self.change_detector.clone(),
split_criterion,
)
else: # Naives Bayes Adaptive (default)
return LeafNaiveBayesAdaptiveWithDetector(initial_stats, depth, self.splitter, self.change_detector.clone(), split_criterion)

return LeafNaiveBayesAdaptiveWithDetector(
initial_stats,
depth,
self.splitter,
self.change_detector.clone(),
split_criterion,
)

def _new_split_criterion(self):
if self._split_criterion == self._GINI_SPLIT:
Expand All @@ -181,8 +208,10 @@ def _new_split_criterion(self):
split_criterion = InfoGainSplitCriterion(self.min_branch_fraction)
elif self._split_criterion == self._HELLINGER_SPLIT:
if not self.track_error:
raise ValueError("The Heillinger distance cannot estimate the purity of a single distribution.\
Use another split criterion or set track_error to True")
raise ValueError(
"The Heillinger distance cannot estimate the purity of a single distribution.\
Use another split criterion or set track_error to True"
)
split_criterion = HellingerDistanceCriterion(self.min_branch_fraction)
else:
split_criterion = InfoGainSplitCriterion(self.min_branch_fraction)
Expand Down Expand Up @@ -315,7 +344,7 @@ def learn_one(self, x, y, *, w=1.0):
self._n_inactive_leaves += 1
else:
weight_seen = node.total_weight
#check if the change detector triggered a change
# check if the change detector triggered a change
if node.change_detector.drift_detected:
p_branch = p_node.branch_no(x) if isinstance(p_node, DTBranch) else None
self._attempt_to_split(node, p_node, p_branch)
Expand Down Expand Up @@ -345,4 +374,3 @@ def learn_one(self, x, y, *, w=1.0):

if self._train_weight_seen_by_model % self.memory_estimate_period == 0:
self._estimate_model_size()

40 changes: 18 additions & 22 deletions river/tree/nodes/last_nodes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from __future__ import annotations

from river.tree.utils import BranchFactory
from river.utils.norm import normalize_values_in_dict

from ..splitter.nominal_splitter_classif import NominalSplitterClassif
from ..utils import do_naive_bayes_prediction, round_sig_fig
from ..utils import do_naive_bayes_prediction
from .htc_nodes import LeafMajorityClass


Expand All @@ -24,23 +20,26 @@ class LeafMajorityClassWithDetector(LeafMajorityClass):
Other parameters passed to the learning node.
"""

def __init__(self, stats, depth, splitter,change_detector, split_criterion = None, **kwargs):
def __init__(self, stats, depth, splitter, change_detector, split_criterion=None, **kwargs):
super().__init__(stats, depth, splitter, **kwargs)
self.change_detector = change_detector
self.split_criterion = split_criterion #if None, the change detector will have binary inputs

self.split_criterion = (
split_criterion # if None, the change detector will have binary inputs
)

def learn_one(self, x, y, *, w=1, tree=None):
self.update_stats(y, w)
if self.is_active():
if self.split_criterion is None:
mc_pred = self.prediction(x)
detector_input = (max(mc_pred, key=mc_pred.get) != y)
detector_input = max(mc_pred, key=mc_pred.get) != y
self.change_detector.update(detector_input)
else:
detector_input = self.split_criterion.purity(self.stats)
detector_input = self.split_criterion.current_merit(self.stats)
self.change_detector.update(detector_input)
self.update_splitters(x, y, w, tree.nominal_attributes)


class LeafNaiveBayesWithDetector(LeafMajorityClassWithDetector):
"""Leaf that uses Naive Bayes models.
Expand All @@ -57,18 +56,18 @@ class LeafNaiveBayesWithDetector(LeafMajorityClassWithDetector):
Other parameters passed to the learning node.
"""

def __init__(self, stats, depth, splitter,change_detector, split_criterion = None, **kwargs):
super().__init__(stats, depth, splitter,change_detector,split_criterion,**kwargs)
def __init__(self, stats, depth, splitter, change_detector, split_criterion=None, **kwargs):
super().__init__(stats, depth, splitter, change_detector, split_criterion, **kwargs)

def learn_one(self, x, y, *, w=1, tree=None):
self.update_stats(y, w)
if self.is_active():
if self.split_criterion is None:
nb_pred = self.prediction(x)
detector_input = (max(nb_pred, key=nb_pred.get) == y)
detector_input = max(nb_pred, key=nb_pred.get) == y
self.change_detector.update(detector_input)
else:
detector_input = self.split_criterion.purity(self.stats)
detector_input = self.split_criterion.current_merit(self.stats)
self.change_detector.update(detector_input)
self.update_splitters(x, y, w, tree.nominal_attributes)

Expand Down Expand Up @@ -108,8 +107,8 @@ class LeafNaiveBayesAdaptiveWithDetector(LeafMajorityClassWithDetector):
Other parameters passed to the learning node.
"""

def __init__(self, stats, depth, splitter, change_detector,split_criterion = None, **kwargs):
super().__init__(stats, depth, splitter, change_detector, split_criterion,**kwargs)
def __init__(self, stats, depth, splitter, change_detector, split_criterion=None, **kwargs):
super().__init__(stats, depth, splitter, change_detector, split_criterion, **kwargs)
self._mc_correct_weight = 0.0
self._nb_correct_weight = 0.0

Expand Down Expand Up @@ -150,13 +149,10 @@ def learn_one(self, x, y, *, w=1.0, tree=None):
else:
self.change_detector.update(detector_input_mc)
else:
detector_input = self.split_criterion.purity(self.stats)
detector_input = self.split_criterion.current_merit(self.stats)
self.change_detector.update(detector_input)
self.update_splitters(x, y, w, tree.nominal_attributes)




def prediction(self, x, *, tree=None):
"""Get the probabilities per class for a given instance.
Expand Down Expand Up @@ -188,4 +184,4 @@ def disable_attribute(self, att_index):
att_index
Attribute index.
"""
pass
pass
9 changes: 4 additions & 5 deletions river/tree/split_criterion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ def merit_of_split(self, pre_split_dist, post_split_dist):
-------
Value of the merit of splitting
"""

@abc.abstractmethod
def purity(self, dist):
"""Compute how pure (how close the distribution is to have only a single class)
the distribution is.
def current_merit(self, dist):
"""Compute the merit of the distribution.
Parameters
----------
Expand All @@ -44,7 +43,7 @@ def purity(self, dist):
Returns
-------
Value of purity of the distribution according to the splitting merit
Value of merit of the distribution according to the splitting criterion
"""

@staticmethod
Expand Down
5 changes: 2 additions & 3 deletions river/tree/split_criterion/gini_split_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ def merit_of_split(self, pre_split_dist, post_split_dist):
post_split_dist[i], dist_weights[i]
)
return 1.0 - gini


def purity(self, dist):

def current_merit(self, dist):
return self.compute_gini(dist, sum(dist.values()))

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion river/tree/split_criterion/hellinger_distance_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def merit_of_split(self, pre_split_dist, post_split_dist):
return -math.inf
return self.compute_hellinger(post_split_dist)

def purity(self, dist):
def current_merit(self, dist):
raise ValueError("The Heillinger distance is for 2 or more sets of data.")

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions river/tree/split_criterion/info_gain_split_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def compute_entropy(self, dist):
return self._compute_entropy_dict(dist)
elif isinstance(dist, list):
return self._compute_entropy_list(dist)
def purity(self, dist):

def current_merit(self, dist):
return self.compute_entropy(dist)

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions river/tree/split_criterion/variance_ratio_split_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def merit_of_split(self, pre_split_dist, post_split_dist):
vr -= (n_i / n) * (self.compute_var(post_split_dist[i]) / var)
return vr

def purity(self, dist):
def current_merit(self, dist):
return self.compute_var(dist)

@staticmethod
def compute_var(dist):
return dist.get()
Expand Down

0 comments on commit fca6254

Please sign in to comment.