Skip to content

Commit

Permalink
MAINT honest API with a multi view (#231)
Browse files Browse the repository at this point in the history
* Ensure HonestForest and HonestTree inherit fitted attributes from any decision tree we can use in scikit-tree

---------

Signed-off-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
adam2392 committed Feb 24, 2024
1 parent 928a855 commit 48ca383
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 49 deletions.
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ requires = [
"numpy>=1.25; python_version>='3.9'"
]

[lint.per-file-ignores]
'__init__.py' = ['F401']

[project]
name = "scikit-tree"
version = "0.7.0dev0"
Expand Down Expand Up @@ -266,10 +263,12 @@ extend-exclude = [
'validation'
]
line-length = 88
lint.ignore = ['E731']

[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
[tool.ruff.lint]
ignore = ['E731']

[tool.ruff.lint.per-file-ignores]
'__init__.py' = ['F401']

[tool.spin]
package = 'sktree'
Expand Down
31 changes: 9 additions & 22 deletions sktree/ensemble/_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,6 @@ class labels (multi-output problem).
The number of classes (single output problem), or a list containing the
number of classes for each output (multi-output problem).
n_features_ : int
The number of features when ``fit`` is performed.
n_features_in_ : int
Number of features seen during :term:`fit`.
Expand Down Expand Up @@ -508,6 +505,9 @@ def fit(self, X, y, sample_weight=None, classes=None, **fit_params):

super().fit(X, y, sample_weight=sample_weight, classes=classes, **fit_params)

# Inherit attributes from the tree estimator
self._inherit_estimator_attributes()

# Compute honest decision function
self.honest_decision_function_ = self._predict_proba(
X, indices=self.honest_indices_, impute_missing=np.nan
Expand Down Expand Up @@ -536,6 +536,12 @@ def _make_estimator(self, append=True, random_state=None):

return estimator

def _inherit_estimator_attributes(self):
    """Copy tree-specific fitted attributes from the first fitted tree onto the forest.

    The attribute names are read from
    ``self.tree_estimator._inheritable_fitted_attribute``. Nothing is copied
    when ``tree_estimator`` does not expose that attribute.
    """
    inheritable_names = getattr(
        self.tree_estimator, "_inheritable_fitted_attribute", ()
    )
    for name in inheritable_names:
        # Values come from the first estimator of the fitted ensemble.
        fitted_value = getattr(self.estimators_[0], name)
        setattr(self, name, fitted_value)

def predict_proba(self, X):
"""
Predict class probabilities for X.
Expand Down Expand Up @@ -638,25 +644,6 @@ def oob_samples_(self):
def _more_tags(self):
    """Return the extra estimator tags: multi-output is flagged as unsupported."""
    extra_tags = dict(multioutput=False)
    return extra_tags

def apply(self, X):
    """
    Apply trees in the forest to X, return leaf indices.

    The call is delegated unchanged to ``self.estimator_.apply``.

    Parameters
    ----------
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        The input samples. Internally, its dtype will be converted to
        ``dtype=np.float32``. If a sparse matrix is provided, it will be
        converted into a sparse ``csr_matrix``.

    Returns
    -------
    X_leaves : ndarray of shape (n_samples, n_estimators)
        For each datapoint x in X and for each tree in the forest,
        return the index of the leaf x ends up in.
        NOTE(review): the actual shape is whatever ``self.estimator_.apply``
        returns -- confirm against that estimator.
    """
    delegate = self.estimator_.apply
    return delegate(X)

def decision_path(self, X):
"""
Return the decision path in the forest.
Expand Down
63 changes: 63 additions & 0 deletions sktree/stats/tests/test_forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,69 @@ def test_build_coleman_forest():
assert forest_result.observe_stat < 0.05, f"{forest_result.observe_stat}"


def test_build_coleman_forest_multiview():
"""Test building a Coleman forest on top of multi-view decision trees.
Both honest forests wrap a ``MultiViewDecisionTreeClassifier`` with two feature
sets (``feature_set_ends=[2, 5]``). The function is tested under the alternative
hypothesis (classes shifted apart) and the null hypothesis (identical features
for both classes) on a very simple dataset.
"""
n_estimators = 40
n_samples = 30
n_features = 5
rng = np.random.default_rng(seed)

# NOTE(review): the first draw is overwritten by the very next line and never used.
# It still advances ``rng``, so deleting it would change the sampled data -- confirm
# the assertions below still hold before removing it.
_X = rng.uniform(size=(n_samples, n_features))
_X = rng.uniform(size=(n_samples // 2, n_features))
# Second class is the first class shifted by 3 in every feature.
X2 = _X + 3
X = np.vstack([_X, X2])
y = np.vstack(
[np.zeros((n_samples // 2, 1)), np.ones((n_samples // 2, 1))]
) # Binary classification

# The regular and the permutation forest are configured identically.
clf = HonestForestClassifier(
n_estimators=n_estimators,
random_state=seed,
n_jobs=-1,
honest_fraction=0.5,
bootstrap=True,
max_samples=1.6,
max_features=[1, 1],
tree_estimator=MultiViewDecisionTreeClassifier(),
feature_set_ends=[2, 5],
)
perm_clf = PermutationHonestForestClassifier(
n_estimators=n_estimators,
random_state=seed,
n_jobs=-1,
honest_fraction=0.5,
bootstrap=True,
max_samples=1.6,
max_features=[1, 1],
tree_estimator=MultiViewDecisionTreeClassifier(),
feature_set_ends=[2, 5],
)
# Passing a non-permutation forest as the second forest must be rejected.
with pytest.raises(
RuntimeError, match="Permutation forest must be a PermutationHonestForestClassifier"
):
build_coleman_forest(clf, clf, X, y)

# Alternative hypothesis: the shifted classes should give a significant result.
forest_result, orig_forest_proba, perm_forest_proba, clf_fitted, perm_clf_fitted = (
build_coleman_forest(clf, perm_clf, X, y, metric="s@98", n_repeats=1000, seed=seed)
)
# ``max_samples=1.6`` is expected to yield round(n_samples * 1.6) bootstrap samples.
assert clf_fitted._n_samples_bootstrap == round(n_samples * 1.6)
assert perm_clf_fitted._n_samples_bootstrap == round(n_samples * 1.6)
assert_array_equal(perm_clf_fitted.permutation_indices_.shape, (n_samples, 1))

assert forest_result.pvalue <= 0.05, f"{forest_result.pvalue}"
assert forest_result.observe_stat > 0.1, f"{forest_result.observe_stat}"
assert_array_equal(orig_forest_proba.shape, perm_forest_proba.shape)

# Null hypothesis: both classes get identical features, so no significant result is expected.
X = np.vstack([_X, _X])
forest_result, _, _, clf_fitted, perm_clf_fitted = build_coleman_forest(
clf, perm_clf, X, y, metric="s@98"
)
assert forest_result.pvalue > 0.05, f"{forest_result.pvalue}"


def test_build_permutation_forest():
"""Simple test for building a permutation forest."""
n_estimators = 30
Expand Down
38 changes: 28 additions & 10 deletions sktree/tests/test_honest_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,33 +437,51 @@ def test_honest_forest_with_sklearn_trees_with_mi():
assert_allclose(np.mean(sk_scores), np.mean(scores), atol=0.005)


def test_honest_forest_with_tree_estimator_params():
@pytest.mark.parametrize(
"tree, tree_kwargs",
[
(MultiViewDecisionTreeClassifier(), {"feature_set_ends": [10, 20]}),
(ObliqueDecisionTreeClassifier(), {"feature_combinations": 2}),
(PatchObliqueDecisionTreeClassifier(), {"max_patch_dims": 5}),
],
)
def test_honest_forest_with_tree_estimator_params(tree, tree_kwargs):
"""Test that honest forest inherits all the fitted parameters of the tree estimator."""
X = np.ones((20, 4))
X[10:] *= -1
y = [0] * 10 + [1] * 10

# test with a parameter that is a repeat of an init parameter
clf = HonestForestClassifier(
tree_estimator=DecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
tree_estimator=DecisionTreeClassifier(), random_state=0, **tree_kwargs
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# test with a parameter that is not in any init signature
clf = HonestForestClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
tree_estimator=tree,
random_state=0,
blah=0,
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# passing in a valid argument to the tree_estimator should work
clf = HonestForestClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
)
clf = HonestForestClassifier(tree_estimator=tree, random_state=0, **tree_kwargs)
clf.fit(X, y)
# Fitted attributes every forest must share with its first tree ...
checked_attrs = [
    "classes_",
    "n_classes_",
    "n_features_in_",
    "n_outputs_",
]
# ... plus whatever the wrapped tree declares as inheritable. The previous
# ``checked_attrs + getattr(...)`` built a new list and threw it away, so the
# tree-specific attributes were never actually compared.
checked_attrs += getattr(tree, "_inheritable_fitted_attribute", [])
for attr_name in checked_attrs:
    # Only public fitted attributes (trailing underscore, no leading one).
    if not attr_name.startswith("_") and attr_name.endswith("_"):
        forest_value = getattr(clf, attr_name)
        tree_value = getattr(clf.estimators_[0], attr_name)
        if isinstance(forest_value, np.ndarray):
            # ``==`` on arrays is element-wise, so use the array-aware assertion.
            np.testing.assert_array_equal(forest_value, tree_value)
        else:
            assert forest_value == tree_value
22 changes: 22 additions & 0 deletions sktree/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,12 @@ def _build_tree(

return self

@property
def _inheritable_fitted_attribute(self):
    """Names of fitted attributes a wrapping meta-estimator may copy from this tree."""
    attribute_names = ("feature_combinations_",)
    # Hand back a fresh list on every access so callers cannot mutate shared state.
    return list(attribute_names)


class ObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""An oblique decision tree Regressor.
Expand Down Expand Up @@ -1852,6 +1858,16 @@ def _more_tags(self):
allow_nan = False
return {"multilabel": True, "allow_nan": allow_nan}

@property
def _inheritable_fitted_attribute(self):
    """Names of fitted attributes a wrapping meta-estimator may copy from this tree."""
    attribute_names = (
        "feature_combinations_",
        "min_patch_dims_",
        "max_patch_dims_",
        "dim_contiguous_",
        "data_dims_",
    )
    # Hand back a fresh list on every access so callers cannot mutate shared state.
    return list(attribute_names)


class PatchObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""A oblique decision tree regressor that operates over patches of data.
Expand Down Expand Up @@ -2747,6 +2763,12 @@ def _build_tree(
self.classes_ = self.classes_[0]
return self

@property
def _inheritable_fitted_attribute(self):
    """Names of fitted attributes a wrapping meta-estimator may copy from this tree."""
    attribute_names = ("feature_combinations_",)
    # Hand back a fresh list on every access so callers cannot mutate shared state.
    return list(attribute_names)


class ExtraObliqueDecisionTreeRegressor(SimMatrixMixin, DecisionTreeRegressor):
"""An oblique decision tree Regressor.
Expand Down
8 changes: 8 additions & 0 deletions sktree/tree/_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,13 +730,21 @@ def _set_leaf_nodes(self, leaf_ids, y):

def _inherit_estimator_attributes(self):
    """Mirror the fitted state of ``self.estimator_`` onto this estimator.

    Three groups are copied, in this order: the names listed by the wrapped
    estimator's ``_inheritable_fitted_attribute`` (if it has one), a fixed set
    of required fitted attributes, and two optional attributes that fall back
    to ``None`` when the wrapped estimator does not provide them.
    """
    fitted = self.estimator_

    # Tree-specific extras, when the wrapped estimator declares any.
    for name in getattr(fitted, "_inheritable_fitted_attribute", ()):
        setattr(self, name, getattr(fitted, name))

    # Required attributes: a missing one raises AttributeError.
    required_names = (
        "classes_",
        "max_features_",
        "n_classes_",
        "n_features_in_",
        "n_outputs_",
        "tree_",
    )
    for name in required_names:
        setattr(self, name, getattr(fitted, name))

    # XXX: scikit-learn trees do not store their builder, or min_samples_split_
    for name in ("builder_", "min_samples_split_"):
        setattr(self, name, getattr(fitted, name, None))

def _empty_leaf_correction(self, proba, pos=0):
"""Leaves with empty posteriors are assigned values.
Expand Down
16 changes: 15 additions & 1 deletion sktree/tree/_multiview.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,6 @@ def _build_tree(
max_features = 0

self.max_features_ = max_features
print(self.max_features_, self.max_features_per_set_)

if not isinstance(self.splitter, ObliqueSplitter):
splitter = SPLITTERS[self.splitter](
Expand Down Expand Up @@ -576,3 +575,18 @@ def _fit(
super()._fit(X, y, sample_weight, check_input, missing_values_in_feature_mask, classes)
self.max_features = self._max_features_arr
return self

@property
def _inheritable_fitted_attribute(self):
    """List the extra fitted attributes a parent meta tree-estimator should copy.

    Used for passing parameters to HonestTreeClassifier.
    """
    attribute_names = (
        "max_features_",
        "feature_combinations_",
        "feature_set_ends_",
        "n_feature_sets_",
        "n_features_in_set_",
        "max_features_per_set_",
    )
    # Hand back a fresh list on every access so callers cannot mutate shared state.
    return list(attribute_names)
31 changes: 21 additions & 10 deletions sktree/tree/tests/test_honest_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,36 +63,47 @@ def test_toy_accuracy():
np.testing.assert_array_equal(clf.predict(X), y)


def test_honest_tree_with_tree_estimator_params():
@pytest.mark.parametrize(
"tree, tree_kwargs",
[
(MultiViewDecisionTreeClassifier(), {"feature_set_ends": [10, 20]}),
(ObliqueDecisionTreeClassifier(), {"feature_combinations": 2}),
(PatchObliqueDecisionTreeClassifier(), {"max_patch_dims": 5}),
],
)
def test_honest_tree_with_tree_estimator_params(tree, tree_kwargs):
"""Test that honest tree inherits all the fitted parameters of the tree estimator."""
X = np.ones((20, 4))
X[10:] *= -1
y = [0] * 10 + [1] * 10

# test with a parameter that is a repeat of an init parameter
clf = HonestTreeClassifier(
tree_estimator=DecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
tree_estimator=DecisionTreeClassifier(), random_state=0, **tree_kwargs
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# test with a parameter that is not in any init signature
clf = HonestTreeClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
tree_estimator=tree,
random_state=0,
blah=0,
)
with pytest.raises(ValueError, match=r"Invalid parameter\(s\)"):
clf.fit(X, y)

# passing in a valid argument to the tree_estimator should work
clf = HonestTreeClassifier(
tree_estimator=MultiViewDecisionTreeClassifier(),
random_state=0,
feature_set_ends=[10, 20],
)
clf = HonestTreeClassifier(tree_estimator=tree, random_state=0, **tree_kwargs)
clf.fit(X, y)
for attr_name in dir(clf.estimator_):
if not attr_name.startswith("_") and attr_name.endswith("_"):
if isinstance(getattr(clf, attr_name), np.ndarray):
np.testing.assert_array_equal(
getattr(clf, attr_name), getattr(clf.estimator_, attr_name)
)
else:
assert getattr(clf, attr_name) == getattr(clf.estimator_, attr_name)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 48ca383

Please sign in to comment.