
[ENH] add PermutationForests and FeatureImportanceForests to sktree #125

Merged: 79 commits merged into main from the might branch on Oct 5, 2023

Conversation


@PSSF23 PSSF23 commented Sep 11, 2023

Changes proposed in this pull request:

#112, #120 to be addressed in a future PR

Before submitting

  • I've read and followed all steps in the Making a pull request
    section of the CONTRIBUTING docs.
  • I've updated or added any relevant docstrings following the syntax described in the
    Writing docstrings section of the CONTRIBUTING docs.
  • If this PR fixes a bug, I've added a test that will fail without my fix.
  • If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.

After submitting

  • All GitHub Actions jobs for my pull request have passed.

Co-Authored-By: Sambit Panda <36676569+sampan501@users.noreply.github.com>
Co-Authored-By: Yuxin <99897042+YuxinB@users.noreply.github.com>
Co-Authored-By: Adam Li <3460267+adam2392@users.noreply.github.com>
@sampan501

Will this also have MIRF with Mutual Info as a stat?

@codecov

codecov bot commented Sep 11, 2023

Codecov Report

Attention: 225 lines in your changes are missing coverage. Please review.

Comparison is base (9b486bc) 87.71% compared to head (be16e5a) 44.51%.
Report is 1 commit behind head on main.

❗ Current head be16e5a differs from pull request most recent head 60d9c85. Consider uploading reports for the commit 60d9c85 to get more accurate results

Additional details and impacted files
@@             Coverage Diff             @@
##             main     #125       +/-   ##
===========================================
- Coverage   87.71%   44.51%   -43.21%     
===========================================
  Files          30       36        +6     
  Lines        2426     3116      +690     
===========================================
- Hits         2128     1387      -741     
- Misses        298     1729     +1431     
Files Coverage Δ
sktree/__init__.py 80.76% <100.00%> (ø)
sktree/conftest.py 100.00% <100.00%> (ø)
sktree/ensemble/_eiforest.py 57.14% <ø> (-42.86%) ⬇️
sktree/stats/__init__.py 100.00% <100.00%> (ø)
sktree/tree/__init__.py 100.00% <100.00%> (ø)
sktree/tests/test_honest_forest.py 35.95% <0.00%> (-64.05%) ⬇️
sktree/ensemble/_honest_forest.py 51.19% <60.00%> (-40.28%) ⬇️
sktree/tree/_classes.py 48.72% <37.50%> (-26.28%) ⬇️
sktree/tree/tests/test_honest_tree.py 33.78% <25.00%> (-66.22%) ⬇️
sktree/tree/_honest_tree.py 20.80% <30.00%> (-78.60%) ⬇️
... and 4 more

... and 18 files with indirect coverage changes

☔ View full report in Codecov by Sentry.


@PSSF23 PSSF23 left a comment


Referring to this file in hyppo. I removed everything related to hyppo for now, including the mutual information function, to avoid confusion. The _might.py file includes all other single-feature and multi-view methods we developed.

@adam2392

> Referring to this file in hyppo. I removed everything related to hyppo for now, including the mutual information function, to avoid confusion. The _might.py file includes all other single-feature and multi-view methods we developed.

For now I will probably remove the multi-view stuff and then also look at how we can introduce arbitrary metrics here, e.g. MI, ROC AUC, etc.

I'll take a look tonight


@PSSF23 PSSF23 left a comment


Per @sampan501, added stat="MI" and stat="AUC" as parameters for statistic().
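For illustration, the call pattern described here would look roughly like the following (est, X, and y are placeholders, and the keyword was later reworked into test(..., metric="mi"/"auc") as discussed further down in the thread):

# Hypothetical usage sketch, not the PR's exact API at this point.
observed_mi = est.statistic(X, y, stat="MI")    # mutual information as the test statistic
observed_auc = est.statistic(X, y, stat="AUC")  # area under the ROC curve as the test statistic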

@sampan501

sampan501 commented Sep 12, 2023

@PSSF23 Can you also rename it to MIGHT? Might make all this name changing easier to follow - no pun intended


@PSSF23 PSSF23 left a comment


  • Renamed to MIGHT and MIGHT_MV
  • Added y-label permutation test to MIGHT (previously removed due to connection to hyppo)

@PSSF23

PSSF23 commented Sep 12, 2023

@sampan501 I might be misunderstanding the MI calculation, but why did the original method have the axis=1 param? If it's not needed, I'll correct MIGHT_MV as well.

@sampan501

predict_proba returns a (n_samples, n_classes) array as an output. So, the previous MI calculation was taking averages over the classes. We don't need to do that since you do that already in forest_pos. We can probably remove the mean for both calculations too.
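For context, a minimal sketch of the MI calculation being discussed, assuming the posteriors are already averaged over trees (the function name and exact form here are illustrative, not the PR's code):

import numpy as np

def mutual_info_from_posteriors(y_true, posteriors, eps=1e-12):
    # posteriors: (n_samples, n_classes) array already averaged over trees
    # (e.g. the output of forest_pos / predict_proba), so no extra axis=1 mean is needed.
    y_true = np.asarray(y_true).ravel()
    _, counts = np.unique(y_true, return_counts=True)
    priors = counts / counts.sum()
    h_y = -np.sum(priors * np.log(priors))                                 # H(Y)
    h_y_given_x = -np.sum(posteriors * np.log(posteriors + eps), axis=1)   # H(Y | x_i) per sample
    return h_y - h_y_given_x.mean()                                        # MI ≈ H(Y) - mean_i H(Y | x_i)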

@sampan501

Before we make any serious changes like the one above, we really need unit tests for this method


@PSSF23 PSSF23 left a comment


I made an empty test file. Let's edit and add to this message about what we need to test:

  • test on iris for mutual info for accuracy
  • partial AUROC
  • multiview splitter (not this time)

@adam2392

> I made an empty test file. Let's edit and add to this message about what we need to test:
>
>   • test on iris for mutual info for accuracy
>   • partial AUROC
>   • multiview splitter

Let's do multi-view in a separate PR.


@PSSF23 PSSF23 left a comment


Based on the unit test results, MIGHT seems to perform the worst when coupled with PatchObliqueDecisionTreeClassifier(). I remember similar situations back with the honest tree tests. Should I lower the passing threshold or remove the estimator option?

@adam2392

> Changes proposed in this pull request:

I would remove the second two bullets to prevent those issues from getting closed without actually resolving the issue raised.


@PSSF23 PSSF23 left a comment


In FI, I cleared all the duplicate variables like posteriors_final_ and observe_posteriors_, and saved the permuted statistic as the class variable permute_stat_ (all other permutation results are saved, so this one should be there as well). I also made MI the default statistic everywhere, but we might change it to pAUC later.
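As a rough sketch of the access pattern after this cleanup (the attribute, parameter, and class names are taken from elsewhere in this thread; the exact signatures are assumptions):

# Hypothetical usage after the cleanup described above.
est = FeatureImportanceForestClassifier(
    estimator=HonestForestClassifier(n_estimators=10, random_state=0),
)
stat, pvalue = est.test(X, y, covariate_index=[0], metric="mi")  # MI is now the default statistic
print(est.permute_stat_)  # permuted statistic, stored alongside the other permutation results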

Comment on lines 392 to 414
@pytest.mark.parametrize("backend", ["loky", "threading"])
@pytest.mark.parametrize("n_jobs", [1, -1])
def test_parallelization(backend, n_jobs):
"""Test parallelization of training forests."""
n_samples = 100
n_features = 5
X = rng.uniform(size=(n_samples, n_features))
y = rng.integers(0, 2, size=n_samples) # Binary classification

def run_forest(covariate_index=None):
clf = FeatureImportanceForestClassifier(
estimator=HonestForestClassifier(
n_estimators=10, random_state=seed, n_jobs=n_jobs, honest_fraction=0.2
),
test_size=0.5,
)
pvalue = clf.test(X, y, covariate_index=[covariate_index], metric="mi")
return pvalue

out = Parallel(n_jobs=1, backend=backend)(
delayed(run_forest)(covariate_index) for covariate_index in range(n_features)
)
assert len(out) == n_features
Collaborator


@sampan501 to my knowledge, any issues with joblib should be fixed, or are they perhaps the result of some other issue?

Let me know if this unit test sufficiently captures what the usage looks like in the power simulations.

Member


n_jobs = 1 should be n_jobs = n_jobs in line 411, but otherwise good
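In other words, the suggested fix to the Parallel call in the snippet above would be (a sketch of the reviewer's suggestion, not the merged diff):

# Forward the parametrized n_jobs to joblib instead of hard-coding 1.
out = Parallel(n_jobs=n_jobs, backend=backend)(
    delayed(run_forest)(covariate_index) for covariate_index in range(n_features)
)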

@PSSF23

PSSF23 commented Oct 3, 2023

@adam2392 can the print statements in FIClf be removed?

@adam2392

adam2392 commented Oct 3, 2023

Yeah, feel free to push that. I'm making some other changes though, so for anything larger feel free to open a PR to this PR.

PSSF23 and others added 5 commits October 3, 2023 20:07
* Update
* Fix submodule
* Possible change to might code
* Add fixes
* Fix style
@@ -205,12 +209,12 @@ def test_linear_model(hypotester, model_kwargs, n_samples, n_repeats, test_size)
n_jobs=-1,
),
"random_state": seed,
"permute_per_tree": True,
"sample_dataset_per_tree": True,
"permute_per_tree": False,
Member


This is testing MI Sep and not MI/Tree, right? Where are we testing MI/Tree?

Member Author


Agreed. This is MI Sep. We should test MI / Tree as well.

Collaborator


Yeah, I will add a pytest.mark.parametrize in a bit. I think last night the p-value behavior was not converging as well as MI Sep.
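A sketch of what that parametrization could look like, assuming permute_per_tree is the flag that distinguishes the two variants; the import paths, data, and assertion here are illustrative only:

import numpy as np
import pytest

from sktree import HonestForestClassifier
from sktree.stats import FeatureImportanceForestClassifier

seed = 12345
rng = np.random.default_rng(seed)


# False ~ "MI Sep", True ~ "MI / Tree"; mirrors the flag flipped in the diff above.
@pytest.mark.parametrize("permute_per_tree", [False, True])
def test_mi_both_permutation_variants(permute_per_tree):
    n_samples, n_features = 100, 5
    X = rng.uniform(size=(n_samples, n_features))
    y = rng.integers(0, 2, size=n_samples)

    clf = FeatureImportanceForestClassifier(
        estimator=HonestForestClassifier(n_estimators=10, random_state=seed, honest_fraction=0.5),
        test_size=0.2,
        permute_per_tree=permute_per_tree,
        sample_dataset_per_tree=permute_per_tree,
    )
    stat, pvalue = clf.test(X, y, covariate_index=[0, 1], metric="mi")
    assert 0.0 <= pvalue <= 1.0  # smoke check only; the real test asserts power on a linear model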

@sampan501

sampan501 commented Oct 5, 2023

@adam2392 @PSSF23 I'm still having an error with the sample size when using the parameters from our meeting today (n = 32, test_size = 0.2). Here's the trace:

"""
Traceback (most recent call last):
  File "/data/sambit/miniconda3/envs/cancer/lib/python3.11/site-packages/joblib/externals/loky/process_executor.py", line 463, in _process_worker
    r = call_item()
        ^^^^^^^^^^^
  File "/data/sambit/miniconda3/envs/cancer/lib/python3.11/site-packages/joblib/externals/loky/process_executor.py", line 291, in __call__
    return self.fn(*self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sambit/miniconda3/envs/cancer/lib/python3.11/site-packages/joblib/parallel.py", line 589, in __call__
    return [func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sambit/miniconda3/envs/cancer/lib/python3.11/site-packages/joblib/parallel.py", line 589, in <listcomp>
    return [func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/data/sambit/mendseqs/high-d/high-d-sims.py", line 1342, in compute_null
    pval = _nonperm_pval(test, sim, n, p, noise=noise, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sambit/mendseqs/high-d/high-d-sims.py", line 1277, in _nonperm_pval
    pvalue = test[0](**test[1]).test(u, v, **kwargs)[1]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sambit/scikit-tree/sktree/stats/forestht.py", line 412, in test
    metric_star, metric_star_pi = _compute_null_distribution_coleman(
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sambit/scikit-tree/sktree/stats/utils.py", line 291, in _compute_null_distribution_coleman
    first_half_metric = metric_func(y_test[non_nan_samples, :], y_pred_first_half)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/sambit/scikit-tree/sktree/stats/utils.py", line 34, in _mutual_information
    raise ValueError(f"y_true must be 1d, not {y_true.shape}")
ValueError: y_true must be 1d, not (1, 1)
"""

@adam2392

adam2392 commented Oct 5, 2023

> @adam2392 @PSSF23 I'm still having an error with the sample size when using the parameters from our meeting today (n = 32, test_size = 0.2).

Any chance you can reproduce the error with a small code snippet?

The following code works for me:

def test_small_dataset():
    n_samples = 32
    n_features = 5
    X = rng.uniform(size=(n_samples, n_features))
    y = rng.integers(0, 2, size=n_samples)  # Binary classification

    clf = FeatureImportanceForestClassifier(
        estimator=HonestForestClassifier(
            n_estimators=10, random_state=seed, n_jobs=1, honest_fraction=0.5
        ),
        test_size=0.2,
        permute_per_tree=False,
        sample_dataset_per_tree=False,
    )
    stat, pvalue = clf.test(X, y, covariate_index=[1,2], metric='mi')

@adam2392

adam2392 commented Oct 5, 2023

FYI, I added a short unit-test to test small sample-sizes.

@sampan501

I will, once I find the simulation and sample size that's causing the issue.

@adam2392 adam2392 merged commit f8a2ff7 into main Oct 5, 2023
21 of 22 checks passed
@adam2392 adam2392 deleted the might branch October 5, 2023 14:31