Add multi-output support to honest trees #86
@@ -360,12 +360,6 @@ def fit(self, X, y, sample_weight=None, check_input=True):
         )
         self._inherit_estimator_attributes()

-        if self.n_outputs_ > 1:
-            raise NotImplementedError(
-                "Multi-target honest trees not yet \
-                    implemented"
-            )
-
         # update the number of classes, unsplit
         if y.ndim == 1:
             # reshape is necessary to preserve the data contiguity against vs
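Note: with this guard removed, fit now accepts a 2-D y. A minimal usage sketch, assuming the estimator is importable as HonestTreeClassifier (the import path and constructor arguments here are assumptions, not part of this diff):

    # Hedged sketch: fitting an honest tree on multi-output labels.
    import numpy as np
    from sktree import HonestTreeClassifier  # assumed import path

    rng = np.random.default_rng(12345)
    X = rng.standard_normal((100, 4))
    y = rng.integers(0, 2, size=(100, 3))  # shape (n_samples, n_outputs)

    clf = HonestTreeClassifier(random_state=0)
    clf.fit(X, y)  # before this PR, y.ndim > 1 raised NotImplementedError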
@@ -419,8 +413,8 @@ def _set_leaf_nodes(self, leaf_ids, y):
         classes are ordered by their index in the tree_.value array.
         """
         self.tree_.value[:, :, :] = 0
-        for leaf_id, yval in zip(leaf_ids, y[self.honest_indices_, 0]):
-            self.tree_.value[leaf_id][0, yval] += 1
+        for leaf_id, yval in zip(leaf_ids, y[self.honest_indices_, :]):
+            self.tree_.value[leaf_id][:, yval] += 1

     def _inherit_estimator_attributes(self):
         """Initialize necessary attributes from the provided tree estimator"""
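Note: the _set_leaf_nodes change drops the hard-coded output index 0 and accumulates honest-sample class counts for every output. A minimal sketch of that accumulation under sklearn's (n_nodes, n_outputs, max_n_classes) layout for tree_.value (the explicit np.arange indexing selects one class per output; it illustrates the intent rather than copying the PR's exact line):

    import numpy as np

    n_nodes, n_outputs, max_n_classes = 3, 2, 2
    value = np.zeros((n_nodes, n_outputs, max_n_classes))  # Tree.value layout

    leaf_ids = np.array([0, 2, 2])           # leaf reached by each honest sample
    y = np.array([[0, 1], [1, 1], [1, 0]])   # one label column per output

    for leaf_id, yval in zip(leaf_ids, y):
        # for each output k, bump the count of class yval[k] in this leaf
        value[leaf_id][np.arange(n_outputs), yval] += 1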
@@ -431,29 +425,36 @@ def _inherit_estimator_attributes(self):
         self.n_outputs_ = self.estimator_.n_outputs_
         self.tree_ = self.estimator_.tree_

-    def _empty_leaf_correction(self, proba, normalizer):
-        """Leaves with empty posteriors are assigned values"""
+    def _empty_leaf_correction(self, proba, pos=0):
[Review comment] Can you just add a short docstring to describe what's going on? I'm reading these lines and having trouble figuring out what it's doing.
+        """Leaves with empty posteriors are assigned values.
+
+        The posteriors are corrected according to the honest prior.
+        In multi-output cases, the posterior corrections only correspond
+        to the respective y dimension, indicated by the position param pos.
+        """
         zero_mask = proba.sum(axis=1) == 0.0
-        if self.honest_prior == "empirical":
-            proba[zero_mask] = self.empirical_prior_
-        elif self.honest_prior == "uniform":
-            proba[zero_mask] = 1 / self.n_classes_
-        elif self.honest_prior == "ignore":
-            proba[zero_mask] = np.nan
-        else:
-            raise ValueError(f"honest_prior {self.honest_prior} not a valid input.")
+
+        # For multi-output cases
+        if self.n_outputs_ > 1:
+            if self.honest_prior == "empirical":
+                proba[zero_mask] = self.empirical_prior_[pos]
+            elif self.honest_prior == "uniform":
+                proba[zero_mask] = 1 / self.n_classes_[pos]
+            elif self.honest_prior == "ignore":
+                proba[zero_mask] = np.nan
+            else:
+                raise ValueError(f"honest_prior {self.honest_prior} not a valid input.")
+        else:
+            if self.honest_prior == "empirical":
+                proba[zero_mask] = self.empirical_prior_
+            elif self.honest_prior == "uniform":
+                proba[zero_mask] = 1 / self.n_classes_
+            elif self.honest_prior == "ignore":
+                proba[zero_mask] = np.nan
+            else:
+                raise ValueError(f"honest_prior {self.honest_prior} not a valid input.")
         return proba

     def _impute_missing_classes(self, proba):
         """Due to splitting, provide proba outputs for some classes"""
         new_proba = np.zeros((proba.shape[0], self.n_classes_))
         for i, old_class in enumerate(self._tree_classes_):
             j = np.where(self.classes_ == old_class)[0][0]
             new_proba[:, j] = proba[:, i]

         return new_proba

     def predict_proba(self, X, check_input=True):
         """Predict class probabilities of the input samples X.
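Note: to make _empty_leaf_correction concrete, here is a standalone sketch of the "uniform" branch for a single output; rows of proba whose leaf received no honest samples are filled from the prior (the values are made up for illustration):

    import numpy as np

    proba = np.array([[0.25, 0.75],
                      [0.00, 0.00],   # empty leaf: no honest samples landed here
                      [1.00, 0.00]])
    n_classes = 2

    zero_mask = proba.sum(axis=1) == 0.0  # rows with an all-zero posterior
    proba[zero_mask] = 1 / n_classes      # "uniform" prior -> [0.5, 0.5]
    # "empirical" would instead fill with the honest set's class frequencies,
    # and "ignore" fills with np.nan.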
@@ -487,17 +488,22 @@ class in a leaf.
             normalizer = proba.sum(axis=1)[:, np.newaxis]
             normalizer[normalizer == 0.0] = 1.0
             proba /= normalizer
             if self._tree_n_classes_ != self.n_classes_:
                 proba = self._impute_missing_classes(proba)
-            proba = self._empty_leaf_correction(proba, normalizer)
+            proba = self._empty_leaf_correction(proba)

             return proba

         else:
-            raise NotImplementedError(
-                "Multi-target honest trees not yet \
-                    implemented"
-            )
+            all_proba = []
+
+            for k in range(self.n_outputs_):
+                proba_k = proba[:, k, : self._tree_n_classes_[k]]
+                normalizer = proba_k.sum(axis=1)[:, np.newaxis]
+                normalizer[normalizer == 0.0] = 1.0
+                proba_k /= normalizer
+                proba_k = self._empty_leaf_correction(proba_k, k)
+                all_proba.append(proba_k)
+
+            return all_proba

     def predict(self, X, check_input=True):
         """Predict class for X.
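Note: this keeps scikit-learn's multi-output convention for predict_proba: a single (n_samples, n_classes) array when n_outputs_ == 1, and a list of per-output arrays otherwise. A hedged usage sketch, continuing the assumed clf from the fit example above:

    all_proba = clf.predict_proba(X)  # list with one array per output
    for k, proba_k in enumerate(all_proba):
        # each entry is (n_samples, n_classes_[k]); rows sum to 1 unless
        # honest_prior="ignore" left NaNs in empty leaves
        assert proba_k.shape == (X.shape[0], clf.n_classes_[k])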
[Review comment] Don't set a global random seed. I just realized a lot of tests are doing this. This actually isn't thread safe, which is an issue for Cythonized code.

Instead you can set a global seed = 12345, and then in each place that uses np.random, you run rng = np.random.default_rng(seed) and use rng in place of np.random.

Can you do this everywhere in the file (I think just 8 places use np.random.seed)?

Here's a ref talking about it: https://albertcthomas.github.io/good-practices-random-number-generators/
[Reply] Will do.
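For reference, the pattern the reviewer describes, a module-level seed with a local Generator per call site instead of mutating NumPy's global state, would look like this (a sketch; the sampled arrays are illustrative):

    import numpy as np

    seed = 12345

    # Before (mutates global state; not thread safe):
    #   np.random.seed(seed)
    #   X = np.random.normal(size=(100, 4))

    # After (local, thread-safe Generator):
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((100, 4))
    y = rng.integers(0, 2, size=100)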