Skip to content

Commit

Permalink
Fix multi-output
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Jun 15, 2023
1 parent f831bc4 commit f1ab592
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion sktree/_lib/sklearn_fork
Submodule sklearn_fork updated 156 files
5 changes: 5 additions & 0 deletions sktree/ensemble/_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ def _predict_proba(self, X, indices=None, impute_missing=None):
delayed(_accumulate_prediction)(tree, X, posteriors, lock, idx)
for tree, idx in zip(self.estimators_, indices)
)

# Normalize to unit length, due to prior weighting
posteriors = np.array(posteriors)
zero_mask = posteriors.sum(2) == 0
Expand All @@ -461,6 +462,10 @@ def _predict_proba(self, X, indices=None, impute_missing=None):
else:
posteriors[zero_mask] = impute_missing

# preserve shape of multi-outputs
if self.n_outputs_ > 1:
posteriors = [post for post in posteriors]

if len(posteriors) == 1:
return posteriors[0]
else:
Expand Down
4 changes: 1 addition & 3 deletions sktree/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,12 @@ def test_honest_decision_function(honest_fraction, val):
[HonestForestClassifier(n_estimators=10, honest_fraction=0.5, random_state=0)]
)
def test_sklearn_compatible_estimator(estimator, check):
# 1. multi-output is not fully supported
# 2. check_class_weight_classifiers is not supported since it requires sample weight
# 1. check_class_weight_classifiers is not supported since it requires sample weight
# XXX: can include this "generalization" in the future if it's useful
# zero sample weight is not "really supported" in honest subsample trees since sample weight
# for fitting the tree's splits
if check.func.__name__ in [
"check_class_weight_classifiers",
"check_classifiers_multilabel_output_format_predict_proba",
]:
pytest.skip()
check(estimator)
3 changes: 0 additions & 3 deletions sktree/tree/tests/test_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,4 @@ def test_impute_classes():

@parametrize_with_checks([HonestTreeClassifier(random_state=0)])
def test_sklearn_compatible_estimator(estimator, check):
# TODO: remove when we implement Regressor classes
# if TREE_ESTIMATORS[estimator].__name__ in TREE_CLASSIFIERS:
# pytest.skip()
check(estimator)

0 comments on commit f1ab592

Please sign in to comment.