Skip to content

Commit

Permalink
[WIP] Unify tests (#214)
Browse files Browse the repository at this point in the history
* Converge test fixtures

Signed-off-by: gaugup <gaugup@microsoft.com>

* Unify tests

Signed-off-by: gaugup <gaugup@microsoft.com>

* Unify more tests

Signed-off-by: gaugup <gaugup@microsoft.com>

* Fix lint

Signed-off-by: gaugup <gaugup@microsoft.com>

* Migrate more tests

Signed-off-by: gaugup <gaugup@microsoft.com>

* Migrate a few more tests

Signed-off-by: gaugup <gaugup@microsoft.com>

* Unify more tests

Signed-off-by: gaugup <gaugup@microsoft.com>
  • Loading branch information
gaugup authored Sep 14, 2021
1 parent ea97c91 commit 74c3f8b
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 254 deletions.
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,34 @@ def regression_exp_object(method="random"):
return exp


@pytest.fixture(scope='session')
def custom_public_data_interface():
    """Session-scoped dice_ml Data interface over the custom regression test dataset."""
    regression_df = helpers.load_custom_testing_dataset_regression()
    return dice_ml.Data(dataframe=regression_df,
                        continuous_features=['Numerical'],
                        outcome_name='Outcome')


@pytest.fixture(scope='session')
def sklearn_binary_classification_model_interface():
    """Session-scoped dice_ml Model wrapping the pre-trained binary sklearn pipeline."""
    model_path = helpers.get_custom_dataset_modelpath_pipeline_binary()
    return dice_ml.Model(model_path=model_path,
                         backend='sklearn',
                         model_type='classifier')


@pytest.fixture(scope='session')
def sklearn_multiclass_classification_model_interface():
    """Session-scoped dice_ml Model wrapping the pre-trained multiclass sklearn pipeline."""
    model_path = helpers.get_custom_dataset_modelpath_pipeline_multiclass()
    return dice_ml.Model(model_path=model_path,
                         backend='sklearn',
                         model_type='classifier')


@pytest.fixture(scope='session')
def sklearn_regression_model_interface():
    """Session-scoped dice_ml Model wrapping the pre-trained regression sklearn pipeline."""
    model_path = helpers.get_custom_dataset_modelpath_pipeline_regression()
    return dice_ml.Model(model_path=model_path,
                         backend='sklearn',
                         model_type='regression')


@pytest.fixture
def public_data_object():
"""
Expand Down
59 changes: 0 additions & 59 deletions tests/test_dice_interface/test_dice_KD.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
import dice_ml
from dice_ml.utils import helpers
from dice_ml.utils.exception import UserConfigValidationException
from dice_ml.diverse_counterfactuals import CounterfactualExamples
from dice_ml.counterfactual_explanations import CounterfactualExplanations

Expand Down Expand Up @@ -46,17 +45,6 @@ def _initiate_exp_object(self, KD_binary_classification_exp_object):
self.exp = KD_binary_classification_exp_object # explainer object
self.data_df_copy = self.exp.data_interface.data_df.copy()

# A desired_class absent from (or unidentifiable in) the training data must raise
@pytest.mark.parametrize("desired_class, total_CFs", [(1, 3), ('a', 3)])
def test_unsupported_binary_class(self, desired_class, sample_custom_query_1, total_CFs):
    with pytest.raises(UserConfigValidationException) as ucve:
        self.exp._generate_counterfactuals(query_instance=sample_custom_query_1,
                                           desired_class=desired_class,
                                           total_CFs=total_CFs)
    expected_message = ("Desired class not present in training data!"
                        if desired_class == 1
                        else "The target class for {0} could not be identified".format(desired_class))
    assert expected_message in str(ucve)

# When a query's feature value is not within the permitted range and the feature is not allowed to vary
@pytest.mark.parametrize("desired_range, desired_class, total_CFs, features_to_vary, permitted_range",
[(None, 0, 4, ['Numerical'], {'Categorical': ['b', 'c']})])
Expand Down Expand Up @@ -119,20 +107,6 @@ def test_permitted_range_categorical(self, desired_class, sample_custom_query_2,
total_CFs=total_CFs, permitted_range=permitted_range)
assert all(i in permitted_range["Categorical"] for i in self.exp.final_cfs_df.Categorical.values)

# A query instance carrying an unknown categorical value must be rejected with ValueError
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 1)])
def test_query_instance_outside_bounds(self, desired_class, sample_custom_query_3, total_CFs):
    with pytest.raises(ValueError):
        self.exp._generate_counterfactuals(query_instance=sample_custom_query_3,
                                           desired_class=desired_class,
                                           total_CFs=total_CFs)

# A query instance containing a column unseen at training time must be rejected with ValueError
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 1)])
def test_query_instance_unknown_column(self, desired_class, sample_custom_query_5, total_CFs):
    with pytest.raises(ValueError):
        self.exp._generate_counterfactuals(query_instance=sample_custom_query_5,
                                           desired_class=desired_class,
                                           total_CFs=total_CFs)

# Ensuring that there are no duplicates in the resulting counterfactuals even if the dataset has duplicates
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 2)])
def test_duplicates(self, desired_class, sample_custom_query_4, total_CFs):
Expand All @@ -147,12 +121,6 @@ def test_duplicates(self, desired_class, sample_custom_query_4, total_CFs):

assert all(self.exp.final_cfs_df == expected_output)

# Requesting zero counterfactuals should be a no-op rather than an error (smoke test)
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 0)])
def test_zero_cfs(self, desired_class, sample_custom_query_4, total_CFs):
    self.exp._generate_counterfactuals(query_instance=sample_custom_query_4,
                                       desired_class=desired_class,
                                       total_CFs=total_CFs)

# Testing for index returned
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 1)])
@pytest.mark.parametrize('posthoc_sparsity_algorithm', ['linear', 'binary', None])
Expand All @@ -179,33 +147,6 @@ def test_KD_tree_output(self, desired_class, sample_custom_query_2, total_CFs,
posthoc_sparsity_algorithm=posthoc_sparsity_algorithm)
assert all(i == desired_class for i in self.exp_multi.cfs_preds)

# Multiclass KD-tree output must be a non-None explanation whose predictions all match desired_class
@pytest.mark.parametrize("desired_class, total_CFs", [(2, 3)])
def test_KD_tree_counterfactual_explanations_output(self, desired_class, sample_custom_query_2, total_CFs):
    explanations = self.exp_multi.generate_counterfactuals(
        query_instances=sample_custom_query_2,
        desired_class=desired_class,
        total_CFs=total_CFs)
    assert explanations is not None
    for prediction in self.exp_multi.cfs_preds:
        assert prediction == desired_class

# Requesting zero counterfactuals on the multiclass explainer should not raise (smoke test)
@pytest.mark.parametrize("desired_class, total_CFs", [(0, 0)])
def test_zero_cfs(self, desired_class, sample_custom_query_4, total_CFs):
    self.exp_multi._generate_counterfactuals(query_instance=sample_custom_query_4,
                                             desired_class=desired_class,
                                             total_CFs=total_CFs)

# A desired_class that is missing from training data, or 'opposite' with >2 classes, must raise
@pytest.mark.parametrize("desired_class, total_CFs", [(100, 3), ('opposite', 3)])
def test_unsupported_multiclass(self, desired_class, sample_custom_query_4, total_CFs):
    with pytest.raises(UserConfigValidationException) as ucve:
        self.exp_multi._generate_counterfactuals(query_instance=sample_custom_query_4,
                                                 desired_class=desired_class,
                                                 total_CFs=total_CFs)
    expected_message = ("Desired class not present in training data!"
                        if desired_class == 100
                        else "Desired class cannot be opposite if the number of classes is more than 2.")
    assert expected_message in str(ucve)


class TestDiceKDRegressionMethods:
@pytest.fixture(autouse=True)
Expand Down
34 changes: 0 additions & 34 deletions tests/test_dice_interface/test_dice_genetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,6 @@ def test_invalid_query_instance(self, sample_custom_query_1, features_to_vary, p
with pytest.raises(ValueError):
self.exp.setup(features_to_vary, permitted_range, sample_custom_query_1, feature_weights)

# Every generated counterfactual must carry the requested outcome label
@pytest.mark.parametrize("desired_class, total_CFs, features_to_vary, initialization",
                         [(1, 2, "all", "kdtree"), (1, 2, "all", "random")])
def test_desired_class(self, desired_class, sample_custom_query_2, total_CFs, features_to_vary, initialization):
    result = self.exp.generate_counterfactuals(query_instances=sample_custom_query_2,
                                               desired_class=desired_class,
                                               total_CFs=total_CFs,
                                               features_to_vary=features_to_vary,
                                               initialization=initialization)
    outcome_column = self.exp.data_interface.outcome_name
    expected_labels = [desired_class] * total_CFs
    for example in result.cf_examples_list:
        assert all(example.final_cfs_df[outcome_column].values == expected_labels)

# Testing that the features_to_vary argument actually varies only the features that you wish to vary
@pytest.mark.parametrize("desired_class, total_CFs, features_to_vary, initialization",
[(1, 2, ["Numerical"], "kdtree"), (1, 2, ["Numerical"], "random")])
Expand Down Expand Up @@ -121,18 +109,6 @@ def test_permitted_range_categorical(self, desired_class, total_CFs, features_to
permitted_range[feature][1] for i
in range(total_CFs))

# setup() must reject a query instance whose categorical value lies outside the known categories
@pytest.mark.parametrize("desired_class, total_CFs, features_to_vary", [(0, 1, "all")])
def test_query_instance_outside_bounds(self, desired_class, sample_custom_query_3, total_CFs, features_to_vary):
    query = sample_custom_query_3
    with pytest.raises(ValueError):
        self.exp.setup(features_to_vary, None, query, "inverse_mad")

# setup() must reject a query instance containing a column the explainer has never seen
@pytest.mark.parametrize("features_to_vary", [("all")])
def test_query_instance_unknown_column(self, sample_custom_query_5, features_to_vary):
    query = sample_custom_query_5
    with pytest.raises(ValueError):
        self.exp.setup(features_to_vary, None, query, "inverse_mad")

# Testing if an error is thrown when the query instance has outcome variable
def test_query_instance_with_target_column(self, sample_custom_query_6):
with pytest.raises(ValueError) as ve:
Expand Down Expand Up @@ -167,16 +143,6 @@ class TestDiceGeneticMultiClassificationMethods:
def _initiate_exp_object(self, genetic_multi_classification_exp_object):
self.exp = genetic_multi_classification_exp_object # explainer object

# Multiclass genetic counterfactuals must all carry the requested outcome label
@pytest.mark.parametrize("desired_class, total_CFs, initialization", [(2, 2, "kdtree"), (2, 2, "random")])
def test_desired_class(self, desired_class, sample_custom_query_2, total_CFs, initialization):
    result = self.exp.generate_counterfactuals(query_instances=sample_custom_query_2,
                                               desired_class=desired_class,
                                               total_CFs=total_CFs,
                                               initialization=initialization)
    outcome_column = self.exp.data_interface.outcome_name
    expected_labels = [desired_class] * total_CFs
    for example in result.cf_examples_list:
        assert all(example.final_cfs_df[outcome_column].values == expected_labels)

# Testing if only valid cfs are found after maxiterations
@pytest.mark.parametrize("desired_class, total_CFs, initialization, maxiterations",
[(2, 7, "kdtree", 0), (2, 7, "random", 0)])
Expand Down
Loading

0 comments on commit 74c3f8b

Please sign in to comment.