diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md
index 73a0382824..1c2d3b9ffa 100644
--- a/docs/releases/unreleased.md
+++ b/docs/releases/unreleased.md
@@ -1,6 +1,6 @@
 # Unreleased
 
-Calling `learn_one` in a pipeline will now update each part of the pipeline in turn. Before the unsupervised parts of the pipeline were updated during `predict_one`. This is more intuitive for new users. The old behavior, which yields better results, can be restored by calling `learn_one` with the new `compose.pure_inference_mode` context manager.
+Calling `learn_one` in a pipeline will now update each part of the pipeline in turn. Previously, the unsupervised parts of the pipeline were updated during `predict_one`. This is more intuitive for new users. The old behavior, which yields better results, can be restored by calling `learn_one` within the new `compose.learn_during_predict` context manager.
 
 ## compose
 
@@ -15,7 +15,13 @@ Calling `learn_one` in a pipeline will now update each part of the pipeline in t
 ## forest
 
 - Fixed issue with `forest.ARFClassifier` which couldn't be passed a `CrossEntropy` metric.
+- Fixed a bug in `forest.AMFClassifier` which slightly improves predictive accuracy.
+- Added `forest.AMFRegressor`.
 
 ## preprocessing
 
 - Added `preprocessing.OrdinalEncoder`, to map string features to integers.
+
+## utils
+
+- Added `utils.random.exponential` to sample random values from an exponential distribution.
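For reference, a minimal sketch of the new default and of how the old behavior can be restored; the pipeline and the sample `x`, `y` are illustrative placeholders, not part of this diff:

```python
from river import compose, linear_model, preprocessing

model = preprocessing.StandardScaler() | linear_model.LinearRegression()

x, y = {"feature": 1.0}, 2.0  # hypothetical sample

# New default: learn_one updates every part of the pipeline, supervised or not.
model.learn_one(x, y)

# Old behavior: the unsupervised parts are updated at prediction time instead.
with compose.learn_during_predict():
    y_pred = model.predict_one(x)
    model.learn_one(x, y)
```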
""" @@ -217,3 +218,116 @@ def predict_proba_one(self, x): @property def _multiclass(self): return True + + +class AMFRegressor(AMFLearner, base.Regressor): + """Aggregated Mondrian Forest regressor for online learning. + + This algorithm is truly online, in the sense that a single pass is performed, and that + predictions can be produced anytime. + + Each node in a tree predicts according to the average of the labels it contains. + The prediction for a sample is computed as the aggregated predictions of all the subtrees + along the path leading to the leaf node containing the sample. The aggregation weights are + exponential weights with learning rate `step` using a squared loss when `use_aggregation` + is `True`. + + This computation is performed exactly thanks to a context tree weighting algorithm. + More details can be found in the original paper[^1]. + + The final predictions are the average of the predictions of each of the + ``n_estimators`` trees in the forest. + + Parameters + ---------- + n_estimators + The number of trees in the forest. + step + Step-size for the aggregation weights. + use_aggregation + Controls if aggregation is used in the trees. It is highly recommended to + leave it as `True`. + seed + Random seed for reproducibility. + + Examples + -------- + + >>> from river import datasets + >>> from river import evaluate + >>> from river import forest + >>> from river import metrics + + >>> dataset = datasets.TrumpApproval() + >>> model = forest.AMFRegressor(seed=42) + >>> metric = metrics.MAE() + + >>> evaluate.progressive_val_score(dataset, model, metric) + MAE: 0.268533 + + References + ---------- + [^1]: Mourtada, J., Gaïffas, S., & Scornet, E. (2021). AMF: Aggregated Mondrian forests for online + learning. Journal of the Royal Statistical Society Series B: Statistical Methodology, 83(3), 505-533. + + """ + + def __init__( + self, + n_estimators: int = 10, + step: float = 1.0, + use_aggregation: bool = True, + seed: int = None, + ): + super().__init__( + n_estimators=n_estimators, + step=step, + loss="least-squares", + use_aggregation=use_aggregation, + seed=seed, + ) + + self.iteration = 0 + + def _initialize_trees(self): + """Initialize the forest.""" + + self.data: list[MondrianTreeRegressor] = [] + for _ in range(self.n_estimators): + # We don't want to have the same stochastic scheme for each tree, or it'll break the randomness + # Hence we introduce a new seed for each, that is derived of the given seed by a deterministic process + seed = self._rng.randint(0, 9999999) + + tree = MondrianTreeRegressor( + self.step, + self.use_aggregation, + self.iteration, + seed, + ) + self.data.append(tree) + + def learn_one(self, x, y): + # Checking if the forest has been created + if not self._is_initialized: + self._initialize_trees() + + # we fit all the trees using the new sample + for tree in self: + tree.learn_one(x, y) + + self.iteration += 1 + + return self + + def predict_one(self, x): + # Checking that the model has been trained once at least + if not self._is_initialized: + return None + + prediction = 0 + for tree in self: + tree.use_aggregation = self.use_aggregation + prediction += tree.predict_one(x) + prediction = prediction / self.n_estimators + + return prediction diff --git a/river/tree/mondrian/__init__.py b/river/tree/mondrian/__init__.py index c009792d3b..15b596ba75 100644 --- a/river/tree/mondrian/__init__.py +++ b/river/tree/mondrian/__init__.py @@ -3,12 +3,13 @@ implementations for the Mondrian trees. 
diff --git a/river/tree/mondrian/__init__.py b/river/tree/mondrian/__init__.py
index c009792d3b..15b596ba75 100644
--- a/river/tree/mondrian/__init__.py
+++ b/river/tree/mondrian/__init__.py
@@ -3,12 +3,13 @@ implementations for the Mondrian trees.
 
 Note that this module is not exposed in the tree module, and is instead used by the
-AMFClassifier class in the ensemble module.
+AMFClassifier and AMFRegressor classes in the forest module.
 """
 
 from __future__ import annotations
 
 from .mondrian_tree import MondrianTree
 from .mondrian_tree_classifier import MondrianTreeClassifier
+from .mondrian_tree_regressor import MondrianTreeRegressor
 
-__all__ = ["MondrianTree", "MondrianTreeClassifier"]
+__all__ = ["MondrianTree", "MondrianTreeClassifier", "MondrianTreeRegressor"]
diff --git a/river/tree/mondrian/mondrian_tree.py b/river/tree/mondrian/mondrian_tree.py
index 6894bd5b95..f9b664ba8e 100644
--- a/river/tree/mondrian/mondrian_tree.py
+++ b/river/tree/mondrian/mondrian_tree.py
@@ -16,16 +16,15 @@ class MondrianTree(abc.ABC):
     step
         Step parameter of the tree.
     loss
-        Loss to minimize for each node of the tree
-        Pick between: "log", ...
+        Loss to minimize for each node of the tree. At the moment it is a placeholder.
+        In the future, different optimization metrics might become available.
     use_aggregation
-        Whether or not the tree should it use aggregation.
+        Whether or not the tree should use aggregation.
-    split_pure
-        Whether or not the tree should split pure leaves when training.
     iteration
         Number of iterations to run when training.
     seed
         Random seed for reproducibility.
+
     """
 
     def __init__(
@@ -33,7 +32,6 @@ def __init__(
         self,
         step: float = 0.1,
         loss: str = "log",
         use_aggregation: bool = True,
-        split_pure: bool = False,
         iteration: int = 0,
         seed: int | None = None,
     ):
@@ -41,7 +39,6 @@ def __init__(
         self.step = step
         self.loss = loss
         self.use_aggregation = use_aggregation
-        self.split_pure = split_pure
         self.iteration = iteration
 
         # Controls the randomness in the tree
diff --git a/river/tree/mondrian/mondrian_tree_classifier.py b/river/tree/mondrian/mondrian_tree_classifier.py
index d2302733be..51e96a7dd2 100644
--- a/river/tree/mondrian/mondrian_tree_classifier.py
+++ b/river/tree/mondrian/mondrian_tree_classifier.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import math
-import sys
 
 from river import base, utils
 from river.tree.mondrian.mondrian_tree import MondrianTree
@@ -54,7 +53,7 @@ class MondrianTreeClassifier(MondrianTree, base.Classifier):
     >>> metric = metrics.Accuracy()
 
     >>> evaluate.progressive_val_score(dataset, model, metric)
-    Accuracy: 57.52%
+    Accuracy: 58.52%
 
     References
     ----------
@@ -76,11 +75,12 @@ def __init__(
             step=step,
             loss="log",
             use_aggregation=use_aggregation,
-            split_pure=split_pure,
             iteration=iteration,
             seed=seed,
         )
 
+        self.dirichlet = dirichlet
+        self.split_pure = split_pure
 
         # Training attributes
         # The previously observed classes set
@@ -107,6 +107,7 @@ def _score(self, node: MondrianNodeClassifier) -> float:
         ----------
         node
             Node to evaluate the score.
+
         """
 
         return node.score(self._y, self.dirichlet, len(self._classes))
@@ -118,6 +119,7 @@ def _predict(self, node: MondrianNodeClassifier) -> dict[base.typing.ClfTarget, float]:
         ----------
         node
             Node to make predictions.
+
         """
 
         return node.predict(self.dirichlet, self._classes, len(self._classes))
@@ -129,6 +131,7 @@ def _loss(self, node: MondrianNodeClassifier) -> float:
         ----------
         node
             Node to evaluate the loss.
+
         """
 
         return node.loss(self._y, self.dirichlet, len(self._classes))
@@ -140,6 +143,7 @@ def _update_weight(self, node: MondrianNodeClassifier) -> float:
         ----------
         node
             Node to update the weight.
+
         """
 
         return node.update_weight(
@@ -154,6 +158,7 @@ def _update_count(self, node: MondrianNodeClassifier):
         ----------
         node
             Target node.
+
         """
 
         node.update_count(self._y)
@@ -169,6 +174,7 @@ def _update_downwards(
             Target node.
         do_weight_update
             Whether we should update the weights or not.
+
         """
 
         return node.update_downwards(
@@ -193,6 +199,7 @@ def _compute_split_time(
         ----------
         node
             Target node.
+
         """
 
         # Don't split if the node is pure: all labels are equal to the one of y_t
@@ -202,11 +209,7 @@ def _compute_split_time(
         # If x_t extends the current range of the node
         if extensions_sum > 0:
             # Sample an exponential with intensity = extensions_sum
-            # try catch to handle the Overflow situation in the exponential
-            try:
-                T = math.exp(1 / extensions_sum)
-            except OverflowError:
-                T = sys.float_info.max  # we get the largest possible output instead
+            T = utils.random.exponential(1 / extensions_sum, rng=self._rng)
 
             time = node.time
             # Splitting time of the node (if splitting occurs)
@@ -246,6 +249,7 @@ def _split(
             Feature of the node.
         is_right_extension
             Should we extend the tree in the right or left direction.
+
         """
 
         new_depth = node.depth + 1
@@ -420,6 +424,7 @@ def _go_upwards(self, leaf: MondrianLeafClassifier):
         ----------
         leaf
             Leaf to start from when going upward.
+
         """
 
         current_node = leaf
@@ -460,10 +465,12 @@ def predict_proba_one(self, x):
         ----------
         x
             Feature vector.
+
         """
 
         # If the tree hasn't seen any sample, then it should return
         # the default empty dict
+
         if not self._is_initialized:
             return {}
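The `_compute_split_time` hunk above is the `forest.AMFClassifier` fix mentioned in the release notes: `math.exp(1 / extensions_sum)` is a deterministic quantity rather than a draw from an exponential distribution, so split times were never actually sampled. A tiny illustration of the difference, with arbitrary values, using the helper added at the bottom of this diff:

```python
import math
import random

from river import utils

extensions_sum = 2.0
rng = random.Random(0)

# Old code: the same value every time, no randomness involved.
old = [math.exp(1 / extensions_sum) for _ in range(3)]  # [1.6487..., 1.6487..., 1.6487...]

# New code: genuine exponential draws with intensity extensions_sum.
new = [utils.random.exponential(1 / extensions_sum, rng=rng) for _ in range(3)]
```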
diff --git a/river/tree/mondrian/mondrian_tree_nodes.py b/river/tree/mondrian/mondrian_tree_nodes.py
index f262515175..d1c3a5cd37 100644
--- a/river/tree/mondrian/mondrian_tree_nodes.py
+++ b/river/tree/mondrian/mondrian_tree_nodes.py
@@ -3,7 +3,7 @@
 import collections
 import math
 
-from river import base
+from river import base, stats
 from river.tree.base import Branch, Leaf
 from river.utils.math import log_sum_2_exp
@@ -19,6 +19,7 @@ class MondrianLeaf(Leaf):
         Split time of the node for Mondrian process.
     depth
         Depth of the leaf.
+
     """
 
     def __init__(self, parent, time, depth):
@@ -81,6 +82,7 @@ def update_depth(self, depth):
         ----------
         depth
             Depth of the node.
+
         """
 
         self.depth = depth
@@ -112,6 +114,7 @@ def range(self, feature) -> tuple[float, float]:
         ----------
         feature
             Feature for which you want to know the range.
+
         """
 
         return (
@@ -126,6 +129,7 @@ def range_extension(self, x) -> tuple[float, dict[base.typing.ClfTarget, float]]:
         ----------
         x
             Sample to deal with.
+
         """
 
         extensions: dict[base.typing.ClfTarget, float] = {}
@@ -176,6 +180,7 @@ def score(self, y: base.typing.ClfTarget, dirichlet: float, n_classes: int) -> float:
         Notes
         -----
         This uses Jeffreys prior with Dirichlet parameter for smoothing.
+
         """
 
         count = self.counts[y]
@@ -197,6 +202,7 @@ def predict(
             The set of classes seen so far
         n_classes
             The total number of classes of the problem.
+
         """
 
         scores = {}
@@ -215,6 +221,7 @@ def loss(self, y: base.typing.ClfTarget, dirichlet: float, n_classes: int) -> float:
             Dirichlet parameter of the problem.
         n_classes
             The total number of classes of the problem.
+
         """
 
         sc = self.score(y, dirichlet, n_classes)
@@ -242,6 +249,7 @@ def update_weight(
             Step parameter of the tree.
         n_classes
             The total number of classes of the problem.
+
         """
 
         loss_t = self.loss(y, dirichlet, n_classes)
@@ -257,6 +265,7 @@ def update_count(self, y):
         ----------
         y
             Class of a given sample.
+
         """
 
         self.counts[y] += 1
@@ -269,6 +278,7 @@ def is_dirac(self, y: base.typing.ClfTarget) -> bool:
         ----------
         y
             Class of a given sample.
+
         """
 
         return self.n_samples == self.counts[y]
@@ -288,9 +298,9 @@ def update_downwards(
         Parameters
         ----------
         x
-            Sample to proceed (as a list).
+            Sample to process.
         y
-            Class of the sample x_t.
+            Class of the sample x.
         dirichlet
             Dirichlet parameter of the tree.
         use_aggregation
@@ -301,6 +311,7 @@ def update_downwards(
             Should we update the weights of the node as well.
         n_classes
             The total number of classes of the problem.
+
         """
 
         # Updating the range of the feature values known by the node
@@ -339,6 +350,7 @@ class MondrianLeafClassifier(MondrianNodeClassifier, MondrianLeaf):
         Split time of the node.
     depth
         The depth of the leaf.
+
     """
 
     def __init__(self, parent, time, depth):
@@ -362,6 +374,159 @@ class MondrianBranchClassifier(MondrianNodeClassifier, MondrianBranch):
         Acceptation threshold of the branch.
     *children
         Children nodes of the branch.
+
     """
 
     def __init__(self, parent, time, depth, feature, threshold, *children):
         super().__init__(parent, time, depth, feature, threshold, *children)
+
+
+class MondrianNodeRegressor(MondrianNode):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.n_samples = 0
+        self.mean = stats.Mean()
+
+    def replant(self, leaf: MondrianNodeRegressor, copy_all: bool = False):
+        """Transfer information from a leaf to a new branch."""
+        self.weight = leaf.weight  # type: ignore
+        self.log_weight_tree = leaf.log_weight_tree  # type: ignore
+        self.mean = leaf.mean
+
+        if copy_all:
+            self.memory_range_min = leaf.memory_range_min
+            self.memory_range_max = leaf.memory_range_max
+            self.n_samples = leaf.n_samples
+
+    def predict(self) -> base.typing.RegTarget:
+        """Return the prediction of the node."""
+        return self.mean.get()
+
+    def loss(self, sample_value: base.typing.RegTarget) -> float:
+        """Compute the loss of the node.
+
+        Parameters
+        ----------
+        sample_value
+            Label of a given sample.
+
+        """
+
+        r = self.predict() - sample_value  # type: ignore
+        return r * r / 2
+
+    def update_weight(
+        self,
+        sample_value: base.typing.RegTarget,
+        use_aggregation: bool,
+        step: float,
+    ) -> float:
+        """Update the weight of the node given a label and the method used.
+
+        Parameters
+        ----------
+        sample_value
+            Label of a given sample.
+        use_aggregation
+            Whether to use aggregation or not during computation (given by the tree).
+        step
+            Step parameter of the tree.
+
+        """
+
+        loss_t = self.loss(sample_value)
+        if use_aggregation:
+            self.weight -= step * loss_t
+        return loss_t
+
+    def update_downwards(
+        self,
+        x,
+        sample_value: base.typing.RegTarget,
+        use_aggregation: bool,
+        step: float,
+        do_update_weight: bool,
+    ):
+        """Update the node when running a downward procedure updating the tree.
+
+        Parameters
+        ----------
+        x
+            Sample to process.
+        sample_value
+            Label of the sample x.
+        use_aggregation
+            Whether to use aggregation or not during computation (given by the tree).
+        step
+            Step parameter of the tree.
+        do_update_weight
+            Should we update the weights of the node as well.
+
+        """
+
+        # Update the range of the feature values known by the node
+        # If this is the first sample, we copy the feature vector into the min and max ranges
+        if self.n_samples == 0:
+            for feature in x:
+                x_f = x[feature]
+                self.memory_range_min[feature] = x_f
+                self.memory_range_max[feature] = x_f
+        # Otherwise, we update the range
+        else:
+            for feature in x:
+                x_f = x[feature]
+                if x_f < self.memory_range_min[feature]:
+                    self.memory_range_min[feature] = x_f
+                if x_f > self.memory_range_max[feature]:
+                    self.memory_range_max[feature] = x_f
+
+        # One more sample in the node
+        self.n_samples += 1
+
+        if do_update_weight:
+            self.update_weight(sample_value, use_aggregation, step)
+
+        # Update the running mean of the labels in the node
+        self.mean.update(sample_value)
+
+
+class MondrianLeafRegressor(MondrianNodeRegressor, MondrianLeaf):
+    """Mondrian Tree Regressor leaf node.
+
+    Parameters
+    ----------
+    parent
+        Parent node.
+    time
+        Split time of the node.
+    depth
+        The depth of the leaf.
+
+    """
+
+    def __init__(self, parent, time, depth):
+        super().__init__(parent, time, depth)
+
+
+class MondrianBranchRegressor(MondrianNodeRegressor, MondrianBranch):
+    """Mondrian Tree Regressor branch node.
+
+    Parameters
+    ----------
+    parent
+        Parent node of the branch.
+    time
+        Split time characterizing the branch.
+    depth
+        Depth of the branch in the tree.
+    feature
+        Feature of the branch.
+    threshold
+        Acceptance threshold of the branch.
+    *children
+        Children nodes of the branch.
+
+    """
+
+    def __init__(self, parent, time, depth, feature, threshold, *children):
+        super().__init__(parent, time, depth, feature, threshold, *children)
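A detail worth knowing when reading both node classes: `weight` lives in the log domain. The multiplicative exponential-weights update w ← w · exp(−η ℓ) becomes the subtraction `self.weight -= step * loss_t`, and `log_weight_tree`, maintained with `log_sum_2_exp`, is the log of the aggregated subtree weight. A minimal sketch of why the log domain is preferred, with made-up numbers:

```python
import math

step, loss = 1.0, 800.0

# In the natural domain, a single large loss underflows the weight to zero...
w = 1.0
w *= math.exp(-step * loss)
assert w == 0.0

# ...while in the log domain the same update is an exact subtraction. Weights
# are only exponentiated relative to a normaliser, keeping exp() well-behaved
# (cf. w = math.exp(weight - log_weight_tree) at prediction time).
log_w = 0.0
log_w -= step * loss
assert log_w == -800.0
```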
diff --git a/river/tree/mondrian/mondrian_tree_regressor.py b/river/tree/mondrian/mondrian_tree_regressor.py
new file mode 100644
index 0000000000..c83d30cbbd
--- /dev/null
+++ b/river/tree/mondrian/mondrian_tree_regressor.py
@@ -0,0 +1,428 @@
+from __future__ import annotations
+
+import math
+
+from river import base, utils
+from river.tree.mondrian.mondrian_tree import MondrianTree
+from river.tree.mondrian.mondrian_tree_nodes import (
+    MondrianBranchRegressor,
+    MondrianLeafRegressor,
+    MondrianNodeRegressor,
+)
+
+
+class MondrianTreeRegressor(MondrianTree, base.Regressor):
+    """Mondrian Tree Regressor.
+
+    Parameters
+    ----------
+    step
+        Step parameter of the tree.
+    use_aggregation
+        Whether to use aggregation weighting techniques or not.
+    iteration
+        Number of iterations to run when training.
+    seed
+        Random seed for reproducibility.
+
+    Notes
+    -----
+    The Mondrian Tree Regressor is a type of decision tree whose splitting decisions are
+    driven by a Mondrian process.
+
+    References
+    ----------
+    [^1]: Balaji Lakshminarayanan, Daniel M. Roy, Yee Whye Teh. Mondrian Forests: Efficient Online Random Forests.
+        arXiv:1406.2673, pages 2-4.
+
+    """
+
+    def __init__(
+        self,
+        step: float = 0.1,
+        use_aggregation: bool = True,
+        iteration: int = 0,
+        seed: int | None = None,
+    ):
+        super().__init__(
+            step=step,
+            loss="least-squares",
+            use_aggregation=use_aggregation,
+            iteration=iteration,
+            seed=seed,
+        )
+        # Controls the randomness in the tree
+        self.seed = seed
+
+        # The current sample being processed
+        self._x: dict[base.typing.FeatureName, int | float]
+        # The current label being processed
+        self._y: base.typing.RegTarget
+
+        # Initialization of the root of the tree
+        # It's the root so it doesn't have any parent (hence None)
+        self._root = MondrianLeafRegressor(None, 0.0, 0)
+
+    @property
+    def _is_initialized(self):
+        """Check whether the tree has learnt at least one sample."""
+        return self.iteration != 0
+
+    def _predict(self, node: MondrianNodeRegressor) -> base.typing.RegTarget:
+        """Compute the prediction.
+
+        Parameters
+        ----------
+        node
+            Node to make predictions.
+
+        """
+
+        return node.predict()  # type: ignore
+
+    def _loss(self, node: MondrianNodeRegressor) -> float:
+        """Compute the loss for the given node regarding the current label.
+
+        Parameters
+        ----------
+        node
+            Node to evaluate the loss.
+
+        """
+
+        return node.loss(self._y)
+
+    def _update_weight(self, node: MondrianNodeRegressor) -> float:
+        """Update the weight of the node regarding the current label with the tree parameters.
+
+        Parameters
+        ----------
+        node
+            Node to update the weight.
+
+        """
+
+        return node.update_weight(self._y, self.use_aggregation, self.step)
+
+    def _update_downwards(self, node: MondrianNodeRegressor, do_update_weight):
+        """Update the node when running a downward procedure updating the tree.
+
+        Parameters
+        ----------
+        node
+            Target node.
+        do_update_weight
+            Whether we should update the weights or not.
+
+        """
+
+        return node.update_downwards(
+            self._x, self._y, self.use_aggregation, self.step, do_update_weight
+        )
+
+    def _compute_split_time(
+        self,
+        node: MondrianLeafRegressor | MondrianBranchRegressor,
+        extensions_sum: float,
+    ) -> float:
+        """Compute the split time of the given node.
+
+        Parameters
+        ----------
+        node
+            Target node.
+        extensions_sum
+            Sum of the range extensions induced by the current sample.
+
+        """
+
+        if extensions_sum > 0:
+            # Sample an exponential with intensity = extensions_sum
+            T = utils.random.exponential(1 / extensions_sum, rng=self._rng)
+
+            time = node.time
+            # Splitting time of the node (if splitting occurs)
+            split_time = time + T
+            # If the node is a leaf we must split it
+            if isinstance(node, MondrianLeafRegressor):
+                return split_time
+            # Otherwise we apply Mondrian process dark magic :)
+            # 1. We get the creation time of the children (left and right are the same)
+            left, _ = node.children
+            child_time = left.time
+            # 2. We check if splitting time occurs before child creation time
+            if split_time < child_time:
+                return split_time
+
+        return 0.0
+ + """ + + # Updating the range of the feature values known by the node + # If it is the first sample, we copy the features vector into the min and max range + if self.n_samples == 0: + for feature in x: + x_f = x[feature] + self.memory_range_min[feature] = x_f + self.memory_range_max[feature] = x_f + # Otherwise, we update the range + else: + for feature in x: + x_f = x[feature] + if x_f < self.memory_range_min[feature]: + self.memory_range_min[feature] = x_f + if x_f > self.memory_range_max[feature]: + self.memory_range_max[feature] = x_f + + # One more sample in the node + self.n_samples += 1 + + if do_update_weight: + self.update_weight(sample_value, use_aggregation, step) + + # Update the mean of the labels in the node online + self.mean.update(sample_value) + + +class MondrianLeafRegressor(MondrianNodeRegressor, MondrianLeaf): + """Mondrian Tree Regressor leaf node. + + Parameters + ---------- + parent + Parent node. + time + Split time of the node. + depth + The depth of the leaf. + + """ + + def __init__(self, parent, time, depth): + super().__init__(parent, time, depth) + + +class MondrianBranchRegressor(MondrianNodeRegressor, MondrianBranch): + """Mondrian Tree Regressor branch node. + + Parameters + ---------- + parent + Parent node of the branch. + time + Split time characterizing the branch. + depth + Depth of the branch in the tree. + feature + Feature of the branch. + threshold + Acceptation threshold of the branch. + *children + Children nodes of the branch. + """ def __init__(self, parent, time, depth, feature, threshold, *children): diff --git a/river/tree/mondrian/mondrian_tree_regressor.py b/river/tree/mondrian/mondrian_tree_regressor.py new file mode 100644 index 0000000000..c83d30cbbd --- /dev/null +++ b/river/tree/mondrian/mondrian_tree_regressor.py @@ -0,0 +1,428 @@ +from __future__ import annotations + +import math + +from river import base, utils +from river.tree.mondrian.mondrian_tree import MondrianTree +from river.tree.mondrian.mondrian_tree_nodes import ( + MondrianBranchRegressor, + MondrianLeafRegressor, + MondrianNodeRegressor, +) + + +class MondrianTreeRegressor(MondrianTree, base.Regressor): + """Mondrian Tree Regressor. + + Parameters + ---------- + step + Step of the tree. + use_aggregation + Whether to use aggregation weighting techniques or not. + iteration + Number iterations to do during training. + seed + Random seed for reproducibility. + + Notes + ----- + The Mondrian Tree Regressor is a type of decision tree that bases splitting decisions over a + Mondrian process. + + References + ---------- + [^1]: Balaji Lakshminarayanan, Daniel M. Roy, Yee Whye Teh. Mondrian Forests: Efficient Online Random Forests. + arXiv:1406.2673, pages 2-4. 
+ + """ + + def __init__( + self, + step: float = 0.1, + use_aggregation: bool = True, + iteration: int = 0, + seed: int = None, + ): + super().__init__( + step=step, + loss="least-squares", + use_aggregation=use_aggregation, + iteration=iteration, + seed=seed, + ) + # Controls the randomness in the tree + self.seed = seed + + # The current sample being proceeded + self._x: dict[base.typing.FeatureName, int | float] + # The current label index being proceeded + self._y: base.typing.RegTarget + + # Initialization of the root of the tree + # It's the root so it doesn't have any parent (hence None) + self._root = MondrianLeafRegressor(None, 0.0, 0) + + def _is_initialized(self): + """Check if the tree has learnt at least one sample""" + return self.iteration != 0 + + def _predict(self, node: MondrianNodeRegressor) -> base.typing.RegTarget: + """Compute the prediction. + + Parameters + ---------- + node + Node to make predictions. + + """ + + return node.predict() # type: ignore + + def _loss(self, node: MondrianNodeRegressor) -> float: + """Compute the loss for the given node regarding the current label. + + Parameters + ---------- + node + Node to evaluate the loss. + + """ + + return node.loss(self._y) + + def _update_weight(self, node: MondrianNodeRegressor) -> float: + """Update the weight of the node regarding the current label with the tree parameters. + + Parameters + ---------- + node + Node to update the weight. + + """ + + return node.update_weight(self._y, self.use_aggregation, self.step) + + def _update_downwards(self, node: MondrianNodeRegressor, do_update_weight): + """Update the node when running a downward procedure updating the tree. + + Parameters + ---------- + node + Target node. + do_update_weight + Whether we should update the weights or not. + + """ + + return node.update_downwards( + self._x, self._y, self.use_aggregation, self.step, do_update_weight + ) + + def _compute_split_time( + self, + node: MondrianLeafRegressor | MondrianBranchRegressor, + extensions_sum: float, + ) -> float: + """Computes the split time of the given node. + + Parameters + ---------- + node + Target node. + + """ + + if extensions_sum > 0: + # Sample an exponential with intensity = extensions_sum + T = utils.random.exponential(1 / extensions_sum, rng=self._rng) + + time = node.time + # Splitting time of the node (if splitting occurs) + split_time = time + T + # If the node is a leaf we must split it + if isinstance(node, MondrianLeafRegressor): + return split_time + # Otherwise we apply Mondrian process dark magic :) + # 1. We get the creation time of the childs (left and right is the same) + left, _ = node.children + child_time = left.time + # 2. We check if splitting time occurs before child creation time + if split_time < child_time: + return split_time + + return 0.0 + + def _split( + self, + node: MondrianLeafRegressor | MondrianBranchRegressor, + split_time: float, + threshold: float, + feature: base.typing.FeatureName, + is_right_extension: bool, + ) -> MondrianBranchRegressor: + """Split the given node and attributes the split time, threshold, etc... to the node. + + Parameters + ---------- + node + Target node. + split_time + Split time of the node in the Mondrian process. + threshold + Threshold of acceptance of the node. + feature + Feature index of the node. + is_right_extension + Should we extend the tree in the right or left direction. 
+ + """ + + new_depth = node.depth + 1 + + # To calm down mypy + left: MondrianLeafRegressor | MondrianBranchRegressor + right: MondrianLeafRegressor | MondrianBranchRegressor + + # The node is already a branch: we create a new branch above it and move the existing + # node one level down the tree + if isinstance(node, MondrianBranchRegressor): + old_left, old_right = node.children + if is_right_extension: + left = MondrianBranchRegressor( + node, split_time, new_depth, node.feature, node.threshold + ) + right = MondrianLeafRegressor(node, split_time, new_depth) + left.replant(node) + + old_left.parent = left + old_right.parent = left + + left.children = (old_left, old_right) + else: + right = MondrianBranchRegressor( + node, split_time, new_depth, node.feature, node.threshold + ) + left = MondrianLeafRegressor(node, split_time, new_depth) + right.replant(node) + + old_left.parent = right + old_right.parent = right + + right.children = (old_left, old_right) + + # Update the level of the modified nodes + new_depth += 1 + old_left.update_depth(new_depth) + old_right.update_depth(new_depth) + + # Update split info + node.feature = feature + node.threshold = threshold + node.children = (left, right) + + return node + + # We promote the leaf to a branch + branch = MondrianBranchRegressor(node.parent, node.time, node.depth, feature, threshold) + left = MondrianLeafRegressor(branch, split_time, new_depth) + right = MondrianLeafRegressor(branch, split_time, new_depth) + branch.children = (left, right) + + # Copy properties from the previous leaf + branch.replant(node, True) + + if is_right_extension: + left.replant(node) + else: + right.replant(node) + + # To avoid leaving garbage behind + del node + + return branch + + def _go_downwards(self): + """Update the tree (downward procedure).""" + + # We update the nodes along the path which leads to the leaf containing the current + # sample. For each node on the path, we consider the possibility of splitting it, + # following the Mondrian process definition. + + # We start at the root + current_node = self._root + + if self.iteration == 0: + # If it's the first iteration, we just put the current sample in the range of root + self._update_downwards(current_node, False) + return current_node + else: + # Path from the parent to the current node + branch_no = None + while True: + # Computing the extensions to get the intensities + extensions_sum, extensions = current_node.range_extension(self._x) + + # If it's not the first iteration (otherwise the current node + # is root with no range), we consider the possibility of a split + split_time = self._compute_split_time(current_node, extensions_sum) + + if split_time > 0: + # We split the current node: because the current node is a + # leaf, or because we add a new node along the path + + # We normalize the range extensions to get probabilities + intensities = utils.norm.normalize_values_in_dict(extensions, inplace=False) + + # Sample the feature at random with a probability + # proportional to the range extensions + + candidates = sorted(list(self._x.keys())) + feature = self._rng.choices( + candidates, [intensities[c] for c in candidates], k=1 + )[0] + + x_f = self._x[feature] + + # Is it a right extension of the node ? 
+    def predict_one(self, x):
+        """Predict the label for a sample.
+
+        Parameters
+        ----------
+        x
+            Feature vector.
+
+        """
+
+        # If the tree hasn't seen any sample yet, it returns the default
+        # prediction, i.e. None
+        if not self._is_initialized:
+            return
+
+        leaf = (
+            self._root.traverse(x, until_leaf=True)
+            if isinstance(self._root, MondrianBranchRegressor)
+            else self._root
+        )
+
+        if not self.use_aggregation:
+            return self._predict(leaf)
+
+        current = leaf
+        prediction = 0.0
+
+        while True:
+            # The leaf's own prediction seeds the aggregation; the predictions
+            # of its ancestors are then mixed in on the way up to the root
+            if isinstance(current, MondrianLeafRegressor):
+                prediction = self._predict(current)
+            else:
+                weight = current.weight
+                log_weight_tree = current.log_weight_tree
+                w = math.exp(weight - log_weight_tree)
+                # Get the prediction of the current node
+                pred_new = self._predict(current)
+                prediction = 0.5 * w * pred_new + (1 - 0.5 * w) * prediction
+
+            # The root takes part in the aggregation as well, so we only
+            # stop once it has been processed
+            if current.parent is None:
+                break
+
+            # And now we go up
+            current = current.parent
+
+        return prediction
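The update `prediction = 0.5 * w * pred_new + (1 - 0.5 * w) * prediction` in `predict_one` above is the regression form of the context-tree-weighting recursion: starting from the leaf's plain average, each ancestor v blends its own node prediction with the aggregate of the subtree below it, using w_v = exp(weight − log_weight_tree):

```latex
\hat{y}_{\mathrm{agg}}(v) = \tfrac{1}{2} \, w_v \, \hat{y}_v + \bigl(1 - \tfrac{1}{2} \, w_v\bigr) \, \hat{y}_{\mathrm{agg}}(\mathrm{child}(v)),
```

applied from the leaf up to the root. Reading the ½ factor as the prior weight put on each split, as in the paper's aggregation scheme, is an interpretation; the diff itself only states the recursion.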
diff --git a/river/utils/random.py b/river/utils/random.py
index ce99b8f5c6..346eaa9591 100644
--- a/river/utils/random.py
+++ b/river/utils/random.py
@@ -3,7 +3,7 @@
 import math
 import random
 
-__all__ = ["poisson"]
+__all__ = ["poisson", "exponential"]
 
 
 def poisson(rate: float, rng=random) -> int:
@@ -29,3 +29,24 @@ def poisson(rate: float, rng=random) -> int:
         p *= rng.random()
 
     return k - 1
+
+
+def exponential(rate: float = 1.0, rng=random) -> float:
+    """Sample a random value from an exponential distribution.
+
+    Parameters
+    ----------
+    rate
+        The parameter β of the distribution, which equals its mean (β = 1 / λ).
+    rng
+        Random number generator to draw from.
+
+    References
+    ----------
+    [^1]: [Wikipedia article](https://www.wikiwand.com/en/Exponential_distribution#Random_variate_generation)
+
+    """
+
+    u = rng.random()
+
+    # Retrieve the λ value from the rate (β): β = 1 / λ
+    lmbda = 1.0 / rate
+    return -math.log(1 - u) / lmbda
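A quick numerical check of the helper's parameterisation: as implemented, `rate` is the mean β = 1/λ of the distribution, which is why the Mondrian trees above call `exponential(1 / extensions_sum)` to obtain intensity `extensions_sum`:

```python
import random
import statistics

from river import utils

rng = random.Random(42)
intensity = 4.0  # plays the role of extensions_sum

samples = [utils.random.exponential(1 / intensity, rng=rng) for _ in range(100_000)]

# An Exp(λ) variable has mean 1 / λ
assert abs(statistics.mean(samples) - 1 / intensity) < 0.01
```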