diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index bd17f0a..4006af3 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -120,3 +120,13 @@ def test_fit_line(): assert math.isclose( fit_disc.loc[line, 'CI'].values[0], 0.0000193, rel_tol=1e-2 ) + +all_pat = pipeline.combine_discovery_validation( + disc_data_shap, valid_data_shap, fit_disc, fit_valid +) +def test_combine_discovery_validation(): + assert 'Feature' in all_pat.columns + assert 'Fraction' in all_pat.columns + assert 'SHAP value' in all_pat.columns + if local: + all_pat.shape == (15888, 3)