From 03867378dd62c4d529078d8da1e11b2d1dfee5b3 Mon Sep 17 00:00:00 2001 From: kenzabenjelloun <74252706+kenzabenjelloun@users.noreply.github.com> Date: Mon, 3 Jul 2023 22:53:20 +0200 Subject: [PATCH] AMFRegressor (#1166) * AMF Classifier & Mondrian Tree Classifier implementation * [Pull request Update] - Adding a "mondrian" folder in the "tree" folder for better file structure - Using "random.choices" instead of the "sample_discrete" functions in "utils.py", and removing "sample_discrete" from the "utils.py" * [Pull Request] - Removing the "__repr__" method of AMF - Removing the @setter and @getter - Removing the "loss" parameter of the classifiers since only the "log-loss" is being used in the end * Updating docstring * [Pull request] - Making `learn_one` and `predict_proba_one` accepting all kinds of supported labels for `y` as input - `predict_proba_one` outputs a dictionary of scores with matching labels * [Fix] Reability Co-authored-by: Saulo Martiello Mastelini * [Fix] Language Co-authored-by: Saulo Martiello Mastelini * [Fix] Language Co-authored-by: Saulo Martiello Mastelini * [Fix] math package implementation usage Co-authored-by: Saulo Martiello Mastelini * [Pull request] - Leaving `__all__` in alphabetical order for the classifiers - Removing type parameters in the description of `log_2_sum` of math utils - Replacing java-like getters and setters by python-like properties and setter * - Adding support for random state (seed) - Replacing Overflow from infinity to maximum possible float (so it makes computations still possible) * [Ignoring testing environment] * Fixing style & typos Co-authored-by: Saulo Martiello Mastelini * [Pull request] - Fixing import order in __init__ file of ensemble - Using LaTeX formulation in AMFClassifier description - Making all nodes related methods private (it shouldn't be used outside) - Docstring syntax update and fixes - Importing river.base instead of typing module for better readability - Adding a short description to the 
MondrianTreeClassifier - Renaming MondrianTreeLeaf into MondrianLeaf - Reordering functions in MondrianTreeClassifier for better readability * Pre-commit clean up * Pre-commit clean up * [MyPy issue] - Trying to fix the left-right issue uppercast (that shouldn't be a problem normally, but mypy keeps being unhappy) - Fixing assignment issue to the parent during upward procedure - Fixing type assignment to the root branch of the tree - Fixing arg-type for list of intensities - Fixing arg-type issue with current samples proceeding - Fixing dirichlet arg-type issue - Fixing some typing issues - Removing call-overload as int in the memories features range list - Correcting output of predict function * Fixing MyPy issues (detyping) * suggestions and style issues fix * addingnecessary files, classes and methods for regressor * minor import modifications * minor list to typing.List and dict to typing.Dict modifs * minor modifs to pass tests * minor changes * changing names * Fixing predict function to support the "model not trained" situation instead of raising an exception * more style suggestions * testing * regressor fix * fixing docstring * [Pull request Update] - Fixing some TODOs from Mastelini suggestions - Factorizing a bit of code from nodes that should be shared with regressor - Removing branch structure as of now for future changes * Removing all "array-like" structure for full dict support * Pre-commit hookups fixes * regressor fix * Delete tests.py * [Pull request] - Adding suggestions from Mastelini on keys usage - Removing useless initialization of scores in the MondrianTreeClassifier * bug fix * fix conflicts * refactored, but has bugs * remove mypy skip * tests * tests * cleanup * better, but not fixed * minor fix * [Fixes] - Fixing scoring bug (no propagation of counts) - Removing unused parameters in docs - Replacing type union of Python 3.10 in 3.9 annotations - Adding little description for MondrianBranch * Pre-commit hookups fixes * fix some tests * 
class AMFRegressor(AMFLearner, base.Regressor):
    """Aggregated Mondrian Forest regressor for online learning.

    The forest learns in a single pass over the stream and can be queried for a
    prediction at any time.

    Inside a tree, every node predicts the average of the targets it has received.
    When ``use_aggregation`` is ``True``, the prediction of a tree for a sample combines
    the predictions of all the nodes on the path from the root to the leaf that contains
    the sample, using exponential weights with learning rate ``step`` computed on the
    least-squares loss. The combination is computed exactly with a context tree
    weighting scheme; see the reference below for the details.

    The prediction of the forest is the average of the predictions of its
    ``n_estimators`` trees.

    Parameters
    ----------
    n_estimators
        The number of trees in the forest.
    step
        Step-size for the aggregation weights.
    use_aggregation
        Controls if aggregation is used in the trees. It is highly recommended to
        leave it as `True`.
    split_pure
        Forwarded to every tree. Intended to control whether "pure" nodes may be split.
        Default is `False`.
    seed
        Random seed for reproducibility.

    Note
    ----
    The trees are created lazily by ``learn_one`` from the parameter values held by the
    forest at that moment. Only ``use_aggregation`` is pushed again to the trees
    afterwards (at prediction time); treat the other parameters as read-only once
    learning has started.

    References
    ----------
    [^1]: J. Mourtada, S. Gaiffas and E. Scornet, *AMF: Aggregated Mondrian Forests for Online Learning*, arXiv:1906.10529, 2019

    """

    def __init__(
        self,
        n_estimators: int = 10,
        step: float = 1.0,
        use_aggregation: bool = True,
        split_pure: bool = False,
        seed: int = None,
    ):
        # The regressor always aggregates with the least-squares loss
        super().__init__(
            n_estimators=n_estimators,
            step=step,
            loss="least-squares",
            use_aggregation=use_aggregation,
            split_pure=split_pure,
            seed=seed,
        )

        # Number of samples the forest has learned from so far
        self.iteration = 0

    def _initialize_trees(self):
        """Create the ``n_estimators`` Mondrian trees of the forest."""

        self.data: list[MondrianTreeRegressor] = []
        for _ in range(self.n_estimators):
            # Each tree receives its own seed, drawn from the generator of the forest:
            # the trees therefore differ from one another while the whole forest stays
            # reproducible for a given forest seed.
            self.data.append(
                MondrianTreeRegressor(
                    step=self.step,
                    use_aggregation=self.use_aggregation,
                    split_pure=self.split_pure,
                    iteration=self.iteration,
                    seed=self._rng.randint(0, 9999999),
                )
            )

    def learn_one(self, x, y):
        """Update every tree of the forest with the sample ``(x, y)``.

        Parameters
        ----------
        x
            Features of the sample.
        y
            Target value of the sample.

        Returns
        -------
        The forest itself.
        """
        # The trees are only built when the first sample arrives
        if not self.is_trained():
            self._initialize_trees()

        for member in self:
            member.learn_one(x, y)

        self.iteration += 1
        return self

    def predict_one(self, x):
        """Return the average of the tree predictions for ``x``.

        Parameters
        ----------
        x
            Features of the sample.

        Returns
        -------
        The averaged prediction, or ``None`` when the forest has not been trained yet.
        """
        # Nothing to predict with before the first training sample
        if not self.is_trained():
            return None

        total = 0
        for member in self:
            # Keep each tree in sync with the current setting of the forest
            member.use_aggregation = self.use_aggregation
            total += member.predict_one(x)

        return total / self.n_estimators
+AMFClassifier and AMFRegressor classes in the ensemble module. """ from __future__ import annotations from .mondrian_tree import MondrianTree from .mondrian_tree_classifier import MondrianTreeClassifier +from .mondrian_tree_regressor import MondrianTreeRegressor -__all__ = ["MondrianTree", "MondrianTreeClassifier"] +__all__ = ["MondrianTree", "MondrianTreeClassifier", "MondrianTreeRegressor"] diff --git a/river/tree/mondrian/mondrian_tree_classifier.py b/river/tree/mondrian/mondrian_tree_classifier.py index d2302733be..ff850acf1b 100644 --- a/river/tree/mondrian/mondrian_tree_classifier.py +++ b/river/tree/mondrian/mondrian_tree_classifier.py @@ -464,6 +464,7 @@ def predict_proba_one(self, x): # If the tree hasn't seen any sample, then it should return # the default empty dict + if not self._is_initialized: return {} diff --git a/river/tree/mondrian/mondrian_tree_nodes.py b/river/tree/mondrian/mondrian_tree_nodes.py index f262515175..491724be42 100644 --- a/river/tree/mondrian/mondrian_tree_nodes.py +++ b/river/tree/mondrian/mondrian_tree_nodes.py @@ -288,9 +288,9 @@ def update_downwards( Parameters ---------- x - Sample to proceed (as a list). + Sample to proceed. y - Class of the sample x_t. + Class of the sample x. dirichlet Dirichlet parameter of the tree. 
class MondrianNodeRegressor(MondrianNode):
    """State shared by the leaf and branch nodes of a Mondrian tree regressor.

    On top of what ``MondrianNode`` stores, a regressor node keeps the running mean of
    the targets it summarises. That mean is the prediction of the node.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Number of samples that went through `update_downwards` on this node.
        # A value of 0 also means that the feature ranges are not initialised yet.
        self.n_samples = 0
        # Running mean of the targets, i.e. the prediction of the node
        self.mean = 0.0
        # Number of targets summarised by `mean`. It is tracked separately from
        # `n_samples` because `replant` may hand over a mean without handing over the
        # feature ranges (and hence without `n_samples`).
        self._n_targets = 0

    def replant(self, leaf: MondrianNodeRegressor, copy_all: bool = False):
        """Transfer information from a leaf to a new branch.

        Parameters
        ----------
        leaf
            Node to copy the information from.
        copy_all
            When `True`, the feature ranges and the sample counter are copied as well.
            Otherwise only the weights and the target statistics are copied.
        """
        self.weight = leaf.weight  # type: ignore
        self.log_weight_tree = leaf.log_weight_tree  # type: ignore
        # The mean travels with the number of targets behind it, so that later updates
        # give the inherited mean its proper weight.
        self.mean = leaf.mean
        self._n_targets = leaf._n_targets

        if copy_all:
            self.memory_range_min = leaf.memory_range_min
            self.memory_range_max = leaf.memory_range_max
            self.n_samples = leaf.n_samples

    def predict(self) -> base.typing.RegTarget:
        """Return the prediction of the node, i.e. the mean of its targets."""
        return self.mean

    def loss(self, sample_value: base.typing.RegTarget) -> float:
        """Compute the loss of the node.

        The loss is half the squared difference between the prediction of the node and
        the given value.

        Parameters
        ----------
        sample_value
            A given value.
        """

        r = self.predict() - sample_value
        return r * r / 2

    def update_weight(
        self,
        sample_value: base.typing.RegTarget,
        use_aggregation: bool,
        step: float,
    ) -> float:
        """Update the weight of the node given a label and the method used.

        Parameters
        ----------
        sample_value
            Label of a given sample.
        use_aggregation
            Whether to use aggregation or not during computation (given by the tree).
        step
            Step parameter of the tree.

        Returns
        -------
        The loss of the node on `sample_value`, computed before any update.
        """

        loss_t = self.loss(sample_value)
        if use_aggregation:
            self.weight -= step * loss_t
        return loss_t

    def update_downwards(
        self,
        x,
        sample_value: base.typing.RegTarget,
        use_aggregation: bool,
        step: float,
        do_update_weight: bool,
    ):
        """Update the node when running a downward procedure updating the tree.

        Parameters
        ----------
        x
            Sample to proceed (a mapping from feature name to value).
        sample_value
            Label of the sample x.
        use_aggregation
            Should it use the aggregation or not
        step
            Step of the tree.
        do_update_weight
            Should we update the weights of the node as well.
        """

        # Updating the range of the feature values known by the node
        # If it is the first sample, we copy the features vector into the min and max range
        if self.n_samples == 0:
            for feature in x:
                x_f = x[feature]
                self.memory_range_min[feature] = x_f
                self.memory_range_max[feature] = x_f
        # Otherwise, we update the range
        else:
            for feature in x:
                x_f = x[feature]
                if x_f < self.memory_range_min[feature]:
                    self.memory_range_min[feature] = x_f
                if x_f > self.memory_range_max[feature]:
                    self.memory_range_max[feature] = x_f

        # One more sample in the node
        self.n_samples += 1

        # The weight must be updated with the loss of the prediction made *before* the
        # new target is folded into the mean.
        if do_update_weight:
            self.update_weight(sample_value, use_aggregation, step)

        # Update the mean of the labels in the node online. The divisor is the number
        # of targets including the new one, which gives the exact running mean (the
        # first target of a fresh node yields mean == target).
        self._n_targets += 1
        self.mean += (sample_value - self.mean) / self._n_targets


class MondrianLeafRegressor(MondrianNodeRegressor, MondrianLeaf):
    """Mondrian Tree Regressor leaf node.

    Parameters
    ----------
    parent
        Parent node.
    time
        Split time of the node.
    depth
        The depth of the leaf.
    """

    def __init__(self, parent, time, depth):
        super().__init__(parent, time, depth)


class MondrianBranchRegressor(MondrianNodeRegressor, MondrianBranch):
    """Mondrian Tree Regressor branch node.

    Parameters
    ----------
    parent
        Parent node of the branch.
    time
        Split time characterizing the branch.
    depth
        Depth of the branch in the tree.
    feature
        Feature of the branch.
    threshold
        Acceptation threshold of the branch.
    *children
        Children nodes of the branch.
    """

    def __init__(self, parent, time, depth, feature, threshold, *children):
        super().__init__(parent, time, depth, feature, threshold, *children)
class MondrianTreeRegressor(MondrianTree, base.Regressor):
    """Mondrian Tree Regressor.

    Parameters
    ----------
    step
        Step of the tree.
    use_aggregation
        Whether to use aggregation weighting techniques or not.
    split_pure
        Whether the tree should split pure leafs during training or not.
    iteration
        Number of samples the tree has already learned from. `0` means that the tree
        is untrained; the counter is incremented by each call to `learn_one`.
    seed
        Random seed for reproducibility.

    Notes
    -----
    The Mondrian Tree Regressor is a type of decision tree that bases splitting decisions over a
    Mondrian process.

    References
    ----------
    [^1] Balaji Lakshminarayanan, Daniel M. Roy, Yee Whye Teh. Mondrian Forests: Efficient Online Random Forests.
        arXiv:1406.2673, pages 2-4
    """

    def __init__(
        self,
        step: float = 0.1,
        use_aggregation: bool = True,
        split_pure: bool = False,
        iteration: int = 0,
        seed: int | None = None,
    ):
        super().__init__(
            step=step,
            loss="least-squares",
            use_aggregation=use_aggregation,
            split_pure=split_pure,
            iteration=iteration,
            seed=seed,
        )
        # Controls the randomness in the tree
        self.seed = seed

        # The current sample being proceeded
        self._x: dict[base.typing.FeatureName, int | float]
        # The current target value being proceeded
        self._y: base.typing.RegTarget

        # Initialization of the root of the tree
        # It's the root so it doesn't have any parent (hence None)
        self._root = MondrianLeafRegressor(None, 0.0, 0)

    def _is_initialized(self):
        """Check if the tree has learnt at least one sample.

        This is a regular method: it has to be *called*, the bare attribute
        `self._is_initialized` is always truthy.
        """
        return self.iteration != 0

    def _predict(self, node: MondrianNodeRegressor) -> base.typing.RegTarget:
        """Compute the prediction.

        Parameters
        ----------
        node
            Node to make predictions.
        """

        return node.predict()

    def _loss(self, node: MondrianNodeRegressor) -> float:
        """Compute the loss for the given node regarding the current label.

        Parameters
        ----------
        node
            Node to evaluate the loss.
        """

        return node.loss(self._y)

    def _update_weight(self, node: MondrianNodeRegressor) -> float:
        """Update the weight of the node regarding the current label with the tree parameters.

        Parameters
        ----------
        node
            Node to update the weight.
        """

        return node.update_weight(self._y, self.use_aggregation, self.step)

    def _update_downwards(self, node: MondrianNodeRegressor, do_update_weight: bool):
        """Update the node when running a downward procedure updating the tree.

        Parameters
        ----------
        node
            Target node.
        do_update_weight
            Whether we should update the weights or not.
        """

        return node.update_downwards(
            self._x, self._y, self.use_aggregation, self.step, do_update_weight
        )

    def _compute_split_time(
        self,
        node: MondrianLeafRegressor | MondrianBranchRegressor,
        extensions_sum: float,
    ) -> float:
        """Computes the split time of the given node.

        Parameters
        ----------
        node
            Target node.
        extensions_sum
            Sum of the range extensions induced by the current sample on the node.

        Returns
        -------
        The sampled split time when a split has to happen, `0.0` otherwise.
        """

        # TODO: what do we do for pure nodes here ? Zero variance ?
        # Without any range extension the sample lies inside the node: no split
        if extensions_sum <= 0:
            return 0.0

        # Sample an exponential with intensity (rate) = extensions_sum. The draw is
        # capped to the largest finite float so that a vanishing intensity cannot
        # produce an infinite split time.
        T = min(self._rng.expovariate(extensions_sum), sys.float_info.max)

        # Splitting time of the node (if splitting occurs)
        split_time = node.time + T

        # If the node is a leaf we must split it
        if isinstance(node, MondrianLeafRegressor):
            return split_time

        # Otherwise the split only happens if it occurs before the creation time of
        # the children (left and right children share the same time)
        left, _ = node.children
        if split_time < left.time:
            return split_time

        return 0.0

    def _split(
        self,
        node: MondrianLeafRegressor | MondrianBranchRegressor,
        split_time: float,
        threshold: float,
        feature: base.typing.FeatureName,
        is_right_extension: bool,
    ) -> MondrianBranchRegressor:
        """Split the given node and attributes the split time, threshold, etc... to the node.

        Parameters
        ----------
        node
            Target node.
        split_time
            Split time of the node in the Mondrian process.
        threshold
            Threshold of acceptance of the node.
        feature
            Feature index of the node.
        is_right_extension
            Should we extend the tree in the right or left direction.

        Returns
        -------
        The branch that now stands where `node` was: `node` itself when it already was
        a branch, a newly created branch when it was a leaf.
        """

        new_depth = node.depth + 1

        # To calm down mypy
        left: MondrianLeafRegressor | MondrianBranchRegressor
        right: MondrianLeafRegressor | MondrianBranchRegressor

        # The node is already a branch: we create a new branch above it and move the existing
        # node one level down the tree
        if isinstance(node, MondrianBranchRegressor):
            old_left, old_right = node.children
            if is_right_extension:
                left = MondrianBranchRegressor(
                    node, split_time, new_depth, node.feature, node.threshold
                )
                right = MondrianLeafRegressor(node, split_time, new_depth)
                left.replant(node)

                old_left.parent = left
                old_right.parent = left

                left.children = (old_left, old_right)
            else:
                right = MondrianBranchRegressor(
                    node, split_time, new_depth, node.feature, node.threshold
                )
                left = MondrianLeafRegressor(node, split_time, new_depth)
                right.replant(node)

                old_left.parent = right
                old_right.parent = right

                right.children = (old_left, old_right)

            # Update the level of the modified nodes
            new_depth += 1
            old_left.update_depth(new_depth)
            old_right.update_depth(new_depth)

            # Update split info
            node.feature = feature
            node.threshold = threshold
            node.children = (left, right)

            return node

        # We promote the leaf to a branch
        branch = MondrianBranchRegressor(node.parent, node.time, node.depth, feature, threshold)
        left = MondrianLeafRegressor(branch, split_time, new_depth)
        right = MondrianLeafRegressor(branch, split_time, new_depth)
        branch.children = (left, right)

        # Copy properties from the previous leaf
        branch.replant(node, True)

        # The child on the side of the previous data inherits the statistics of the
        # leaf; the other child starts empty and will receive the new sample.
        if is_right_extension:
            left.replant(node)
        else:
            right.replant(node)

        return branch

    def _go_downwards(self):
        """Update the tree (downward procedure).

        Returns
        -------
        The leaf that contains the current sample once the update is done.
        """

        # We update the nodes along the path which leads to the leaf containing the current
        # sample. For each node on the path, we consider the possibility of splitting it,
        # following the Mondrian process definition.

        # We start at the root
        current_node = self._root

        if self.iteration == 0:
            # If it's the first iteration, we just put the current sample in the range of root
            self._update_downwards(current_node, False)
            return current_node
        else:
            # Path from the parent to the current node
            branch_no = None
            while True:
                # Computing the extensions to get the intensities
                extensions_sum, extensions = current_node.range_extension(self._x)

                # If it's not the first iteration (otherwise the current node
                # is root with no range), we consider the possibility of a split
                split_time = self._compute_split_time(current_node, extensions_sum)

                if split_time > 0:
                    # We split the current node: because the current node is a
                    # leaf, or because we add a new node along the path

                    # We normalize the range extensions to get probabilities
                    intensities = utils.norm.normalize_values_in_dict(extensions, inplace=False)

                    # Sample the feature at random with a probability
                    # proportional to the range extensions. The candidates are sorted
                    # so that the draw does not depend on the key order of the sample.
                    candidates = sorted(self._x)
                    feature = self._rng.choices(
                        candidates, [intensities[c] for c in candidates], k=1
                    )[0]

                    x_f = self._x[feature]

                    # Is it a right extension of the node ?
                    range_min, range_max = current_node.range(feature)
                    is_right_extension = x_f > range_max
                    if is_right_extension:
                        threshold = self._rng.uniform(range_max, x_f)
                    else:
                        threshold = self._rng.uniform(x_f, range_min)

                    was_leaf = isinstance(current_node, MondrianLeafRegressor)

                    # We split the current node
                    current_node = self._split(
                        current_node,
                        split_time,
                        threshold,
                        feature,
                        is_right_extension,
                    )

                    # The root node has become a branch
                    if current_node.parent is None:
                        self._root = current_node
                    # Update path from the previous parent to the recently updated node
                    elif was_leaf:
                        parent = current_node.parent
                        if branch_no == 0:
                            parent.children = (current_node, parent.children[1])
                        else:
                            parent.children = (parent.children[0], current_node)

                    # Update the current node
                    self._update_downwards(current_node, True)

                    left, right = current_node.children

                    # Now, get the next node
                    if is_right_extension:
                        current_node = right
                    else:
                        current_node = left

                    # This is the leaf containing the sample point (we've just
                    # splitted the current node with the data point)
                    leaf = current_node
                    self._update_downwards(leaf, False)
                    return leaf
                else:
                    # There is no split, so we just update the node and go to the next one
                    self._update_downwards(current_node, True)
                    if isinstance(current_node, MondrianLeafRegressor):
                        return current_node
                    else:
                        # Save the path direction to keep the tree consistent
                        try:
                            branch_no = current_node.branch_no(self._x)
                            current_node = current_node.children[branch_no]
                        except KeyError:  # Missing split feature
                            branch_no, current_node = current_node.most_common_path()

    def _go_upwards(self, leaf: MondrianLeafRegressor):
        """Update the tree (upwards procedure).

        Parameters
        ----------
        leaf
            Leaf to start from when going upward.

        """

        current_node = leaf

        if self.iteration >= 1:
            while True:
                current_node.update_weight_tree()
                if current_node.parent is None:
                    # We arrived at the root
                    break
                # Note that the root node is updated as well
                # We go up to the root in the tree
                current_node = current_node.parent

    def learn_one(self, x, y):
        """Update the tree with the sample ``(x, y)``.

        Parameters
        ----------
        x
            Feature vector
        y
            Target value of the sample.

        Returns
        -------
        The tree itself.
        """
        # Setting current sample
        self._x = x
        self._y = y

        # Learning step
        leaf = self._go_downwards()
        if self.use_aggregation:
            self._go_upwards(leaf)

        # Incrementing iteration
        self.iteration += 1
        return self

    def predict_one(self, x):
        """Predict the label of the samples.

        Parameters
        ----------
        x
            Feature vector

        Returns
        -------
        The prediction of the tree, or `None` if the tree has not learned any sample.
        """

        # If the tree hasn't seen any sample, there is nothing to predict with.
        # `_is_initialized` is a method and has to be called: the bare bound method is
        # always truthy, which would silently disable this guard.
        if not self._is_initialized():
            return None

        leaf = (
            self._root.traverse(x, until_leaf=True)
            if isinstance(self._root, MondrianBranchRegressor)
            else self._root
        )

        if not self.use_aggregation:
            return self._predict(leaf)

        current = leaf
        prediction = 0.0

        # Walk from the leaf up to the root, mixing at each level the prediction of the
        # node with the aggregated prediction of the subtree below it.
        while True:
            # This test is useless ?
            if isinstance(current, MondrianLeafRegressor):
                prediction = self._predict(current)
            else:
                weight = current.weight
                log_weight_tree = current.log_weight_tree
                w = math.exp(weight - log_weight_tree)
                # Get the predictions of the current node
                pred_new = self._predict(current)
                prediction = 0.5 * w * pred_new + (1 - 0.5 * w) * prediction

            # Root must be updated as well
            if current.parent is None:
                break

            # And now we go up
            current = current.parent

        return prediction