Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stratify sampling when split train/test data #143

Merged
merged 28 commits into from
Oct 19, 2023
Merged
Changes from 5 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e030050
Startify sampling when split tran/test data
YuxinB Oct 12, 2023
5d60959
Stratified_Sample, Let startify = None for Regressor
YuxinB Oct 12, 2023
c3df52e
Merge branch 'main' into Stratified_Sample
PSSF23 Oct 16, 2023
9a15c69
Merge branch 'main' into Stratified_Sample
adam2392 Oct 16, 2023
78837d2
FIX correct changes & black format
PSSF23 Oct 17, 2023
4f88518
DOC modify warning text
PSSF23 Oct 17, 2023
ffb8136
Add unit test for verifying stratified sampling
YuxinB Oct 17, 2023
8fa1277
Merge branch 'Stratified_Sample' of https://github.com/neurodata/scik…
YuxinB Oct 17, 2023
3ff6340
Correct Typo for Stratified
YuxinB Oct 17, 2023
3a67779
Merge branch 'main' into Stratified_Sample
sampan501 Oct 17, 2023
70a14a5
Fixed example and whatsnew
adam2392 Oct 18, 2023
98fbe5f
Merge branch 'main' into Stratified_Sample
adam2392 Oct 18, 2023
f555e2c
ENH correct tests & add coverage
PSSF23 Oct 18, 2023
4595df3
FIX change n_samples for test to be valid
PSSF23 Oct 18, 2023
30b6d3e
DOC update name for MIGHT & black format
PSSF23 Oct 18, 2023
9a7459d
FIX update the test for stratification
PSSF23 Oct 18, 2023
e0cbb60
FIX correct test variables
PSSF23 Oct 18, 2023
e248a7c
FIX correct variable shape
PSSF23 Oct 19, 2023
8ba06ef
FIX correct test method
PSSF23 Oct 19, 2023
5d516a7
FIX disable check_input for correct error
PSSF23 Oct 19, 2023
735a10b
FIX remove duplicate checks
PSSF23 Oct 19, 2023
47857c3
DOC add docstring for stratify
PSSF23 Oct 19, 2023
3ce68e7
Merge branch 'main' into Stratified_Sample
PSSF23 Oct 19, 2023
888cb42
Merge branch 'main' into Stratified_Sample
sampan501 Oct 19, 2023
35eb776
Add contributor
YuxinB Oct 19, 2023
9e2ba9e
Merge branch 'Stratified_Sample' of https://github.com/neurodata/scik…
YuxinB Oct 19, 2023
3332e9a
DOC update reference
PSSF23 Oct 19, 2023
3bc05b5
Merge branch 'Stratified_Sample' of https://github.com/neurodata/scik…
PSSF23 Oct 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions sktree/stats/forestht.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,26 @@ def reset(self):
self._is_fitted = False
self._seeds = None

def _get_estimators_indices(self, sample_separate=False):
def _get_estimators_indices(self, stratifier, sample_separate=False):

# Check stratifier
# if stratifier is None, stratifier is regressor
if stratifier is not None:
if self._n_samples_ is not None and stratifier.shape[0] != self._n_samples_:
raise RuntimeError(
PSSF23 marked this conversation as resolved.
Show resolved Hide resolved
f"stratifier must have {self._n_samples_} samples, got {stratifier.shape[0]}. "
f"If running on a new dataset, call the 'reset' method."
)

if (
self._type_of_target_ is not None
and type_of_target(stratifier) != self._type_of_target_
):
raise RuntimeError(
PSSF23 marked this conversation as resolved.
Show resolved Hide resolved
f"stratifier must have type {self._type_of_target_}, got {type_of_target(stratifier)}. "
f"If running on a new dataset, call the 'reset' method."
)

indices = np.arange(self._n_samples_, dtype=int)

# Get drawn indices along both sample and feature axes
Expand Down Expand Up @@ -191,7 +210,11 @@ def _get_estimators_indices(self, sample_separate=False):
# Operations accessing random_state must be performed identically
# to those in `_parallel_build_trees()`
indices_train, indices_test = train_test_split(
indices, test_size=self.test_size, shuffle=True, random_state=seed
indices,
test_size=self.test_size,
shuffle=True,
stratify=stratifier,
random_state=seed,
)

yield indices_train, indices_test
Expand All @@ -206,13 +229,14 @@ def _get_estimators_indices(self, sample_separate=False):
indices_train, indices_test = train_test_split(
indices,
test_size=self.test_size,
stratify=stratifier,
random_state=self._seeds,
)

for _ in self.estimator_.estimators_:
yield indices_train, indices_test

@property
def train_test_samples_(self):
def train_test_samples_(self, stratifier):
"""
The subset of drawn samples for each base estimator.

Expand All @@ -229,7 +253,7 @@ def train_test_samples_(self):

return [
(indices_train, indices_test)
for indices_train, indices_test in self._get_estimators_indices()
for indices_train, indices_test in self._get_estimators_indices(stratifier=stratifier)
]

def _statistic(
Expand Down Expand Up @@ -462,10 +486,10 @@ def test(
observe_posteriors = self.observe_posteriors_
observe_stat = self.observe_stat_

# next permute the data
if covariate_index is None:
covariate_index = np.arange(X.shape[1], dtype=int)

adam2392 marked this conversation as resolved.
Show resolved Hide resolved
# next permute the data
permute_stat, permute_posteriors, permute_samples = self.statistic(
X,
y,
Expand Down Expand Up @@ -493,7 +517,7 @@ def test(
# If not sampling a new dataset per tree, then we may either be
# permuting the covariate index per tree or per forest. If not permuting
# there is only one train and test split, so we can just use that
_, indices_test = self.train_test_samples_[0]
_, indices_test = self.train_test_samples_(stratifier=y)[0]
indices_test = observe_samples
y_test = y[indices_test, :]
y_pred_proba_normal = observe_posteriors[:, indices_test, :]
Expand Down Expand Up @@ -725,12 +749,12 @@ def _statistic(
self._type_of_target_,
)
for idx, (indices_train, indices_test) in enumerate(
self._get_estimators_indices(sample_separate=True)
self._get_estimators_indices(y, sample_separate=True)
)
)
else:
# fitting a forest will only get one unique train/test split
indices_train, indices_test = self.train_test_samples_[0]
indices_train, indices_test = self.train_test_samples_(stratifier=None)[0]

X_train, X_test = X[indices_train, :], X[indices_test, :]
y_train, y_test = y[indices_train, :], y[indices_test, :]
Expand Down Expand Up @@ -946,12 +970,12 @@ def _statistic(
self._type_of_target_,
)
for idx, (indices_train, indices_test) in enumerate(
self._get_estimators_indices(sample_separate=True)
self._get_estimators_indices(y, sample_separate=True)
)
)
else:
# fitting a forest will only get one unique train/test split
indices_train, indices_test = self.train_test_samples_[0]
indices_train, indices_test = self.train_test_samples_(stratifier=y)[0]

X_train, X_test = X[indices_train, :], X[indices_test, :]
y_train = y[indices_train, :]
Expand Down
Loading