From f1ab592c9d01e0e8a6f5a514f69d5413754b98fc Mon Sep 17 00:00:00 2001 From: Adam Li Date: Thu, 15 Jun 2023 14:43:51 -0400 Subject: [PATCH] Fix multi-output Signed-off-by: Adam Li --- sktree/_lib/sklearn_fork | 2 +- sktree/ensemble/_honest_forest.py | 5 +++++ sktree/tests/test_honest_forest.py | 4 +--- sktree/tree/tests/test_honest_tree.py | 3 --- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sktree/_lib/sklearn_fork b/sktree/_lib/sklearn_fork index 45320b4d3..545e2a298 160000 --- a/sktree/_lib/sklearn_fork +++ b/sktree/_lib/sklearn_fork @@ -1 +1 @@ -Subproject commit 45320b4d3ef05b4ccbe81e8c13676b1c755d1973 +Subproject commit 545e2a298ab403262e00a16f4d85ccde1c2a250b diff --git a/sktree/ensemble/_honest_forest.py b/sktree/ensemble/_honest_forest.py index f843051a7..c38a09c65 100644 --- a/sktree/ensemble/_honest_forest.py +++ b/sktree/ensemble/_honest_forest.py @@ -451,6 +451,7 @@ def _predict_proba(self, X, indices=None, impute_missing=None): delayed(_accumulate_prediction)(tree, X, posteriors, lock, idx) for tree, idx in zip(self.estimators_, indices) ) + # Normalize to unit length, due to prior weighting posteriors = np.array(posteriors) zero_mask = posteriors.sum(2) == 0 @@ -461,6 +462,10 @@ def _predict_proba(self, X, indices=None, impute_missing=None): else: posteriors[zero_mask] = impute_missing + # preserve shape of multi-outputs + if self.n_outputs_ > 1: + posteriors = [post for post in posteriors] + if len(posteriors) == 1: return posteriors[0] else: diff --git a/sktree/tests/test_honest_forest.py b/sktree/tests/test_honest_forest.py index 6c339d124..d803c0b44 100644 --- a/sktree/tests/test_honest_forest.py +++ b/sktree/tests/test_honest_forest.py @@ -138,14 +138,12 @@ def test_honest_decision_function(honest_fraction, val): [HonestForestClassifier(n_estimators=10, honest_fraction=0.5, random_state=0)] ) def test_sklearn_compatible_estimator(estimator, check): - # 1. multi-output is not fully supported - # 2. check_class_weight_classifiers is not supported since it requires sample weight + # 1. check_class_weight_classifiers is not supported since it requires sample weight # XXX: can include this "generalization" in the future if it's useful # zero sample weight is not "really supported" in honest subsample trees since sample weight # for fitting the tree's splits if check.func.__name__ in [ "check_class_weight_classifiers", - "check_classifiers_multilabel_output_format_predict_proba", ]: pytest.skip() check(estimator) diff --git a/sktree/tree/tests/test_honest_tree.py b/sktree/tree/tests/test_honest_tree.py index f8068e758..ace165206 100644 --- a/sktree/tree/tests/test_honest_tree.py +++ b/sktree/tree/tests/test_honest_tree.py @@ -106,7 +106,4 @@ def test_impute_classes(): @parametrize_with_checks([HonestTreeClassifier(random_state=0)]) def test_sklearn_compatible_estimator(estimator, check): - # TODO: remove when we implement Regressor classes - # if TREE_ESTIMATORS[estimator].__name__ in TREE_CLASSIFIERS: - # pytest.skip() check(estimator)