Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT honest API with a multi view #231

Merged
merged 15 commits into from
Feb 24, 2024
Merged
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ requires = [
"numpy>=1.25; python_version>='3.9'"
]

[lint.per-file-ignores]
'__init__.py' = ['F401']

[project]
name = "scikit-tree"
version = "0.7.0dev0"
Expand Down Expand Up @@ -266,10 +263,12 @@ extend-exclude = [
'validation'
]
line-length = 88
lint.ignore = ['E731']

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
[tool.ruff.lint]
ignore = ['E731']

[tool.ruff.lint.per-file-ignores]
'__init__.py' = ['F401']

[tool.spin]
package = 'sktree'
Expand Down
37 changes: 15 additions & 22 deletions sktree/ensemble/_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,6 @@ class labels (multi-output problem).
The number of classes (single output problem), or a list containing the
number of classes for each output (multi-output problem).

n_features_ : int
The number of features when ``fit`` is performed.

n_features_in_ : int
Number of features seen during :term:`fit`.

Expand Down Expand Up @@ -508,6 +505,9 @@ def fit(self, X, y, sample_weight=None, classes=None, **fit_params):

super().fit(X, y, sample_weight=sample_weight, classes=classes, **fit_params)

# Inherit attributes from the tree estimator
self._inherit_estimator_attributes()

# Compute honest decision function
self.honest_decision_function_ = self._predict_proba(
X, indices=self.honest_indices_, impute_missing=np.nan
Expand Down Expand Up @@ -536,6 +536,18 @@ def _make_estimator(self, append=True, random_state=None):

return estimator

def _inherit_estimator_attributes(self):
"""Initialize necessary attributes from the provided tree estimator"""
if hasattr(self.estimators_[0], "_inheritable_fitted_attribute"):
for attr in self.estimators_[0]._inheritable_fitted_attribute:
setattr(self, attr, getattr(self.estimators_[0], attr))

self.classes_ = self.estimators_[0].classes_
self.max_features_ = self.estimators_[0].max_features_
self.n_classes_ = self.estimators_[0].n_classes_
self.n_features_in_ = self.estimators_[0].n_features_in_
self.n_outputs_ = self.estimators_[0].n_outputs_

def predict_proba(self, X):
"""
Predict class probabilities for X.
Expand Down Expand Up @@ -638,25 +650,6 @@ def oob_samples_(self):
def _more_tags(self):
return {"multioutput": False}

def apply(self, X):
"""
Apply trees in the forest to X, return leaf indices.

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
The input samples. Internally, its dtype will be converted to
``dtype=np.float32``. If a sparse matrix is provided, it will be
converted into a sparse ``csr_matrix``.

Returns
-------
X_leaves : ndarray of shape (n_samples, n_estimators)
For each datapoint x in X and for each tree in the forest,
return the index of the leaf x ends up in.
"""
return self.estimator_.apply(X)

def decision_path(self, X):
"""
Return the decision path in the forest.
Expand Down
64 changes: 64 additions & 0 deletions sktree/stats/tests/test_forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,70 @@ def test_build_coleman_forest():
assert forest_result.observe_stat < 0.05, f"{forest_result.observe_stat}"


def test_build_coleman_forest_multiview():
"""Simple test for building a Coleman forest.

Test the function under alternative and null hypothesis for a very simple dataset.
"""
n_estimators = 40
n_samples = 30
n_features = 5
rng = np.random.default_rng(seed)

_X = rng.uniform(size=(n_samples, n_features))
_X = rng.uniform(size=(n_samples // 2, n_features))
X2 = _X + 3
X = np.vstack([_X, X2])
y = np.vstack(
[np.zeros((n_samples // 2, 1)), np.ones((n_samples // 2, 1))]
) # Binary classification

clf = HonestForestClassifier(
n_estimators=n_estimators,
random_state=seed,
n_jobs=-1,
honest_fraction=0.5,
bootstrap=True,
max_samples=1.6,
max_features=[1, 2],
tree_estimator=MultiViewDecisionTreeClassifier(),
feature_set_ends=[2, 5],
)
perm_clf = PermutationHonestForestClassifier(
n_estimators=n_estimators,
random_state=seed,
n_jobs=-1,
honest_fraction=0.5,
bootstrap=True,
max_samples=1.6,
max_features=[1, 2],
tree_estimator=MultiViewDecisionTreeClassifier(),
feature_set_ends=[2, 5],
)
with pytest.raises(
RuntimeError, match="Permutation forest must be a PermutationHonestForestClassifier"
):
build_coleman_forest(clf, clf, X, y)

forest_result, orig_forest_proba, perm_forest_proba, clf_fitted, perm_clf_fitted = (
build_coleman_forest(clf, perm_clf, X, y, metric="s@98", n_repeats=1000, seed=seed)
)
assert clf_fitted._n_samples_bootstrap == round(n_samples * 1.6)
assert perm_clf_fitted._n_samples_bootstrap == round(n_samples * 1.6)
assert_array_equal(perm_clf_fitted.permutation_indices_.shape, (n_samples, 1))

assert forest_result.pvalue <= 0.05, f"{forest_result.pvalue}"
assert forest_result.observe_stat > 0.1, f"{forest_result.observe_stat}"
assert_array_equal(orig_forest_proba.shape, perm_forest_proba.shape)

X = np.vstack([_X, _X])
forest_result, _, _, clf_fitted, perm_clf_fitted = build_coleman_forest(
clf, perm_clf, X, y, metric="s@98"
)
assert forest_result.pvalue > 0.05, f"{forest_result.pvalue}"
assert forest_result.observe_stat < 0.05, f"{forest_result.observe_stat}"


def test_build_permutation_forest():
"""Simple test for building a permutation forest."""
n_estimators = 30
Expand Down
38 changes: 28 additions & 10 deletions sktree/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,33 +437,51 @@ def test_honest_forest_with_sklearn_trees_with_mi():
assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.005)


def test_honest_forest_with_tree_estimator_params():
@pytest.mark.parametrize(
"tree, tree_kwargs",
[
(MultiViewDecisionTreeClassifier(), {"feature_set_ends": [10, 20]}),
(ObliqueDecisionTreeClassifier(), {"feature_combinations": 2}),
(PatchObliqueDecisionTreeClassifier(), {"max_patch_dims": 5}),
],
)
def test_honest_forest_with_tree_estimator_params(tree, tree_kwargs):
"""Test that honest forest inherits all the fitted parameters of the tree estimator."""
X = np.ones((20, 4))
X[10:] *= -1
y = [0] * 10 + [1] * 10

# test with a parameter that is a repeat of an init parameter
clf = HonestForestClassifier(
tree_estimator=DecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
tree_estimator=DecisionTreeClassifier(), random_state=0, **tree_kwargs
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# test with a parameter that is not in any init signature
clf = HonestForestClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
tree_estimator=tree,
random_state=0,
blah=0,
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# passing in a valid argument to the tree_estimator should work
clf = HonestForestClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
)
clf = HonestForestClassifier(tree_estimator=tree, random_state=0, **tree_kwargs)
clf.fit(X, y)
checked_attrs = [
"classes_",
"n_classes_",
"n_features_in_",
"n_outputs_",
]
checked_attrs + getattr(clf.estimator_, "_inheritable_fitted_attribute", [])
for attr_name in checked_attrs:
if not attr_name.startswith("_") and attr_name.endswith("_"):
if isinstance(getattr(clf, attr_name), np.ndarray):
np.testing.assert_array_equal(
getattr(clf, attr_name), getattr(clf.estimators_[0], attr_name)
)
else:
assert getattr(clf, attr_name) == getattr(clf.estimators_[0], attr_name)
22 changes: 22 additions & 0 deletions sktree/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,12 @@ def _build_tree(

return self

@property
def _inheritable_fitted_attribute(self):
return {
"feature_combinations_",
}


class ObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""An oblique decision tree Regressor.
Expand Down Expand Up @@ -1852,6 +1858,16 @@ def _more_tags(self):
allow_nan = False
return {"multilabel": True, "allow_nan": allow_nan}

@property
def _inheritable_fitted_attribute(self):
return {
"feature_combinations_",
"min_patch_dims_",
"max_patch_dims_",
"dim_contiguous_",
"data_dims_",
}


class PatchObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""A oblique decision tree regressor that operates over patches of data.
Expand Down Expand Up @@ -2747,6 +2763,12 @@ def _build_tree(
self.classes_ = self.classes_[0]
return self

@property
def _inheritable_fitted_attribute(self):
return {
"feature_combinations_",
}


class ExtraObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""An oblique decision tree Regressor.
Expand Down
6 changes: 6 additions & 0 deletions sktree/tree/_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,12 +730,18 @@ def _set_leaf_nodes(self, leaf_ids, y):

def _inherit_estimator_attributes(self):
"""Initialize necessary attributes from the provided tree estimator"""
if hasattr(self.estimator_, "_inheritable_fitted_attribute"):
for attr in self.estimator_._inheritable_fitted_attribute:
setattr(self, attr, getattr(self.estimator_, attr))

self.classes_ = self.estimator_.classes_
self.max_features_ = self.estimator_.max_features_
self.n_classes_ = self.estimator_.n_classes_
self.n_features_in_ = self.estimator_.n_features_in_
self.n_outputs_ = self.estimator_.n_outputs_
self.tree_ = self.estimator_.tree_
self.builder_ = self.estimator_.builder_
self.min_samples_split_ = self.estimator_.min_samples_split_

def _empty_leaf_correction(self, proba, pos=0):
"""Leaves with empty posteriors are assigned values.
Expand Down
16 changes: 15 additions & 1 deletion sktree/tree/_multiview.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,6 @@ def _build_tree(
max_features = 0

self.max_features_ = max_features
print(self.max_features_, self.max_features_per_set_)

if not isinstance(self.splitter, ObliqueSplitter):
splitter = SPLITTERS[self.splitter](
Expand Down Expand Up @@ -576,3 +575,18 @@ def _fit(
super()._fit(X, y, sample_weight, check_input, missing_values_in_feature_mask, classes)
self.max_features = self._max_features_arr
return self

@property
def _inheritable_fitted_attribute(self):
"""Define additional attributes to pass onto a parent meta tree-estimator.

Used for passing parameters to HonestTreeClassifier.
"""
return {
"max_features_",
"feature_combinations_",
"feature_set_ends_",
"n_feature_sets_",
"n_features_in_set_",
"max_features_per_set_",
}
31 changes: 21 additions & 10 deletions sktree/tree/tests/test_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,36 +63,47 @@ def test_toy_accuracy():
np.testing.assert_array_equal(clf.predict(X), y)


def test_honest_tree_with_tree_estimator_params():
@pytest.mark.parametrize(
"tree, tree_kwargs",
[
(MultiViewDecisionTreeClassifier(), {"feature_set_ends": [10, 20]}),
(ObliqueDecisionTreeClassifier(), {"feature_combinations": 2}),
(PatchObliqueDecisionTreeClassifier(), {"max_patch_dims": 5}),
],
)
def test_honest_tree_with_tree_estimator_params(tree, tree_kwargs):
"""Test that honest tree inherits all the fitted parameters of the tree estimator."""
X = np.ones((20, 4))
X[10:] *= -1
y = [0] * 10 + [1] * 10

# test with a parameter that is a repeat of an init parameter
clf = HonestTreeClassifier(
tree_estimator=DecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
tree_estimator=DecisionTreeClassifier(), random_state=0, **tree_kwargs
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# test with a parameter that is not in any init signature
clf = HonestTreeClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
tree_estimator=tree,
random_state=0,
blah=0,
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# passing in a valid argument to the tree_estimator should work
clf = HonestTreeClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
)
clf = HonestTreeClassifier(tree_estimator=tree, random_state=0, **tree_kwargs)
clf.fit(X, y)
for attr_name in dir(clf.estimator_):
if not attr_name.startswith("_") and attr_name.endswith("_"):
if isinstance(getattr(clf, attr_name), np.ndarray):
np.testing.assert_array_equal(
getattr(clf, attr_name), getattr(clf.estimator_, attr_name)
)
else:
assert getattr(clf, attr_name) == getattr(clf.estimator_, attr_name)


@pytest.mark.parametrize(
Expand Down
Loading