Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stratify sampling when split train/test data #143

Merged
merged 28 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e030050
Startify sampling when split tran/test data
YuxinB Oct 12, 2023
5d60959
Stratified_Sample, Let startify = None for Regressor
YuxinB Oct 12, 2023
c3df52e
Merge branch 'main' into Stratified_Sample
PSSF23 Oct 16, 2023
9a15c69
Merge branch 'main' into Stratified_Sample
adam2392 Oct 16, 2023
78837d2
FIX correct changes & black format
PSSF23 Oct 17, 2023
4f88518
DOC modify warning text
PSSF23 Oct 17, 2023
ffb8136
Add unit test for verifying stratified sampling
YuxinB Oct 17, 2023
8fa1277
Merge branch 'Stratified_Sample' of https://github.com/neurodata/scik…
YuxinB Oct 17, 2023
3ff6340
Correct Typo for Stratified
YuxinB Oct 17, 2023
3a67779
Merge branch 'main' into Stratified_Sample
sampan501 Oct 17, 2023
70a14a5
Fixed example and whatsnew
adam2392 Oct 18, 2023
98fbe5f
Merge branch 'main' into Stratified_Sample
adam2392 Oct 18, 2023
f555e2c
ENH correct tests & add coverage
PSSF23 Oct 18, 2023
4595df3
FIX change n_samples for test to be valid
PSSF23 Oct 18, 2023
30b6d3e
DOC update name for MIGHT & black format
PSSF23 Oct 18, 2023
9a7459d
FIX update the test for stratification
PSSF23 Oct 18, 2023
e0cbb60
FIX correct test variables
PSSF23 Oct 18, 2023
e248a7c
FIX correct variable shape
PSSF23 Oct 19, 2023
8ba06ef
FIX correct test method
PSSF23 Oct 19, 2023
5d516a7
FIX disable check_input for correct error
PSSF23 Oct 19, 2023
735a10b
FIX remove duplicate checks
PSSF23 Oct 19, 2023
47857c3
DOC add docstring for stratify
PSSF23 Oct 19, 2023
3ce68e7
Merge branch 'main' into Stratified_Sample
PSSF23 Oct 19, 2023
888cb42
Merge branch 'main' into Stratified_Sample
sampan501 Oct 19, 2023
35eb776
Add contributor
YuxinB Oct 19, 2023
9e2ba9e
Merge branch 'Stratified_Sample' of https://github.com/neurodata/scik…
YuxinB Oct 19, 2023
3332e9a
DOC update reference
PSSF23 Oct 19, 2023
3bc05b5
Merge branch 'Stratified_Sample' of https://github.com/neurodata/scik…
PSSF23 Oct 19, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/whats_new/v0.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Changelog
- |Fix| Fixes a bug in consistency of train/test samples when ``random_state`` is not set in FeatureImportanceForestClassifier and FeatureImportanceForestRegressor, by `Adam Li`_ (:pr:`135`)
- |Fix| Fixes a bug where covariate indices were not shuffled by default when running FeatureImportanceForestClassifier and FeatureImportanceForestRegressor test methods, by `Sambit Panda`_ (:pr:`140`)
- |Enhancement| Add multi-view splitter for axis-aligned decision trees, by `Adam Li`_ (:pr:`129`)
- |Enhancement| Add stratified sampling option to ``FeatureImportance*`` via the ``stratify`` keyword argument, by `Yuxin Bai`_ (:pr:`143`)

Code and Documentation Contributors
-----------------------------------
Expand All @@ -24,4 +25,4 @@ the project since version inception, including:

* `Adam Li`_
* `Sambit Panda`_

* `Yuxin Bai`_
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
===========================================================
Mutual Information for Gigantic Hypothesis Testing (MIGHT)
===========================================================
=========================================================
Mutual Information for Genuine Hypothesis Testing (MIGHT)
=========================================================

An example using :class:`~sktree.stats.FeatureImportanceForestClassifier` for nonparametric
multivariate hypothesis test, on simulated datasets. Here, we present a simulation
Expand Down Expand Up @@ -49,8 +49,8 @@
# We simulate the two feature sets, and the target variable. We then combine them
# into a single dataset to perform hypothesis testing.

n_samples = 1000
n_features_set = 500
n_samples = 2000
n_features_set = 20
mean = 1.0
sigma = 2.0
beta = 5.0
Expand Down Expand Up @@ -91,7 +91,7 @@
# computed as the proportion of samples in the null distribution that are less than the
# observed test statistic.

n_estimators = 200
n_estimators = 100
max_features = "sqrt"
test_size = 0.2
n_repeats = 1000
Expand All @@ -103,12 +103,12 @@
max_features=max_features,
tree_estimator=DecisionTreeClassifier(),
random_state=seed,
honest_fraction=0.7,
honest_fraction=0.25,
n_jobs=n_jobs,
),
random_state=seed,
test_size=test_size,
permute_per_tree=True,
permute_per_tree=False,
sample_dataset_per_tree=False,
)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy>=1.25
scipy
scikit-learn>=1.3.1

40 changes: 26 additions & 14 deletions sktree/stats/forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ def __init__(
test_size=0.2,
permute_per_tree=True,
sample_dataset_per_tree=True,
stratify=True,
):
self.estimator = estimator
self.random_state = random_state
self.verbose = verbose
self.test_size = test_size
self.permute_per_tree = permute_per_tree
self.sample_dataset_per_tree = sample_dataset_per_tree
self.stratify = stratify

self.n_samples_test_ = None
self._n_samples_ = None
Expand Down Expand Up @@ -160,8 +162,9 @@ def reset(self):
self.n_features_in_ = None
self._is_fitted = False
self._seeds = None
self._y = None

def _get_estimators_indices(self, sample_separate=False):
def _get_estimators_indices(self, stratifier=None, sample_separate=False):
indices = np.arange(self._n_samples_, dtype=int)

# Get drawn indices along both sample and feature axes
Expand Down Expand Up @@ -191,7 +194,11 @@ def _get_estimators_indices(self, sample_separate=False):
# Operations accessing random_state must be performed identically
# to those in `_parallel_build_trees()`
indices_train, indices_test = train_test_split(
indices, test_size=self.test_size, shuffle=True, random_state=seed
indices,
test_size=self.test_size,
shuffle=True,
stratify=stratifier,
random_state=seed,
)

yield indices_train, indices_test
Expand All @@ -202,12 +209,13 @@ def _get_estimators_indices(self, sample_separate=False):
else:
self._seeds = self.estimator_.random_state

# TODO: make random_state consistent
indices_train, indices_test = train_test_split(
indices,
test_size=self.test_size,
stratify=stratifier,
random_state=self._seeds,
)

for _ in self.estimator_.estimators_:
yield indices_train, indices_test

Expand All @@ -227,9 +235,12 @@ def train_test_samples_(self):
if self._n_samples_ is None:
raise RuntimeError("The estimator must be fitted before accessing this attribute.")

# Stratifier uses a cached _y attribute if available
stratifier = self._y if is_classifier(self.estimator_) and self.stratify else None

return [
(indices_train, indices_test)
for indices_train, indices_test in self._get_estimators_indices()
for indices_train, indices_test in self._get_estimators_indices(stratifier=stratifier)
]

def _statistic(
Expand Down Expand Up @@ -329,6 +340,8 @@ def statistic(

if self._n_samples_ is None:
self._n_samples_, self.n_features_in_ = X.shape

# Infer type of target y
if self._type_of_target_ is None:
self._type_of_target_ = type_of_target(y)

Expand All @@ -339,9 +352,9 @@ def statistic(
self.permuted_estimator_ = self._get_estimator()
estimator = self.permuted_estimator_

# Infer type of target y
if not hasattr(self, "_type_of_target"):
self._type_of_target_ = type_of_target(y)
# Store a cache of the y variable
if is_classifier(self._get_estimator()):
self._y = y.copy()

# XXX: this can be improved as an extra fit can be avoided, by just doing error-checking
# and then setting the internal meta data structures
Expand Down Expand Up @@ -462,10 +475,10 @@ def test(
observe_posteriors = self.observe_posteriors_
observe_stat = self.observe_stat_

# next permute the data
if covariate_index is None:
covariate_index = np.arange(X.shape[1], dtype=int)

adam2392 marked this conversation as resolved.
Show resolved Hide resolved
# next permute the data
permute_stat, permute_posteriors, permute_samples = self.statistic(
X,
y,
Expand Down Expand Up @@ -632,6 +645,7 @@ def __init__(
test_size=test_size,
permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
stratify=False,
)

def _get_estimator(self):
Expand Down Expand Up @@ -724,9 +738,7 @@ def _statistic(
self.permute_per_tree,
self._type_of_target_,
)
for idx, (indices_train, indices_test) in enumerate(
self._get_estimators_indices(sample_separate=True)
)
for idx, (indices_train, indices_test) in enumerate(self.train_test_samples_)
)
else:
# fitting a forest will only get one unique train/test split
Expand Down Expand Up @@ -877,6 +889,7 @@ def __init__(
test_size=0.2,
permute_per_tree=True,
sample_dataset_per_tree=True,
stratify=True,
):
super().__init__(
estimator=estimator,
Expand All @@ -885,6 +898,7 @@ def __init__(
test_size=test_size,
permute_per_tree=permute_per_tree,
sample_dataset_per_tree=sample_dataset_per_tree,
stratify=stratify,
)

def _get_estimator(self):
Expand Down Expand Up @@ -945,9 +959,7 @@ def _statistic(
self.permute_per_tree,
self._type_of_target_,
)
for idx, (indices_train, indices_test) in enumerate(
self._get_estimators_indices(sample_separate=True)
)
for idx, (indices_train, indices_test) in enumerate(self.train_test_samples_)
)
else:
# fitting a forest will only get one unique train/test split
Expand Down
32 changes: 32 additions & 0 deletions sktree/stats/tests/test_forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,38 @@ def test_featureimportance_forest_permute_pertree(sample_dataset_per_tree):
est.statistic(iris_X[:n_samples], iris_y[:n_samples], [0, 1.0], metric="mi")


@pytest.mark.parametrize("sample_dataset_per_tree", [True, False])
def test_featureimportance_forest_stratified(sample_dataset_per_tree):
est = FeatureImportanceForestClassifier(
estimator=RandomForestClassifier(
n_estimators=10,
random_state=seed,
),
permute_per_tree=True,
test_size=0.7,
random_state=seed,
sample_dataset_per_tree=sample_dataset_per_tree,
)
n_samples = 100
est.statistic(iris_X[:n_samples], iris_y[:n_samples], metric="mi")

_, indices_test = est.train_test_samples_[0]
y_test = iris_y[indices_test]

assert len(y_test[y_test == 0]) == len(y_test[y_test == 1]), (
f"{len(y_test[y_test==0])} " f"{len(y_test[y_test==1])}"
)

est.test(iris_X[:n_samples], iris_y[:n_samples], [0, 1], n_repeats=10, metric="mi")

_, indices_test = est.train_test_samples_[0]
y_test = iris_y[indices_test]

assert len(y_test[y_test == 0]) == len(y_test[y_test == 1]), (
f"{len(y_test[y_test==0])} " f"{len(y_test[y_test==1])}"
)


def test_featureimportance_forest_errors():
permute_per_tree = False
sample_dataset_per_tree = True
Expand Down
Loading