From c8ea67f43179a43ee4918d121caaeb96b75ac66d Mon Sep 17 00:00:00 2001 From: Saulo Martiello Mastelini Date: Fri, 28 Jun 2024 16:55:39 -0300 Subject: [PATCH] Fix #1560 (ARF cornercase) --- docs/releases/unreleased.md | 4 ++++ river/tree/nodes/arf_htc_nodes.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/releases/unreleased.md b/docs/releases/unreleased.md index 2cf21a66c1..30d582211e 100644 --- a/docs/releases/unreleased.md +++ b/docs/releases/unreleased.md @@ -5,3 +5,7 @@ This release makes Polars an optional dependency instead of a required one. ## cluster - Added `ODAC` (Online Divisive-Agglomerative Clustering) for clustering time series. + +## forest + +- Fix error in `forest.ARFClassifer` and `forest.ARFRegressor` where the algorithms would crash in case the number of features available for learning went below the value of the `max_features` parameter (#1560). diff --git a/river/tree/nodes/arf_htc_nodes.py b/river/tree/nodes/arf_htc_nodes.py index f40cc965c9..f25eb50475 100644 --- a/river/tree/nodes/arf_htc_nodes.py +++ b/river/tree/nodes/arf_htc_nodes.py @@ -44,7 +44,9 @@ def _iter_features(self, x) -> typing.Iterable: yield att_id, x[att_id] def _sample_features(self, x, max_features): - return self.rng.sample(sorted(x.keys()), k=max_features) + if len(x) >= max_features: + return self.rng.sample(sorted(x.keys()), k=max_features) + return sorted(x.keys()) class RandomLeafMajorityClass(BaseRandomLeaf, LeafMajorityClass):