diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index ad231e69..af5f77ce 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -5,6 +5,7 @@ import pickle from abc import ABC, abstractmethod from collections.abc import Iterable +from typing import Any, List import numpy as np import pandas as pd @@ -47,6 +48,24 @@ def __init__(self, data_interface, model_interface=None): # self.cont_precisions = \ # [self.data_interface.get_decimal_precisions()[ix] for ix in self.encoded_continuous_feature_indexes] + def _find_features_having_missing_values( + self, data: Any) -> List[str]: + """Return list of features which have missing values. + + :param data: The dataset to check. + :type data: Any + :return: List of feature names which have missing values. + :rtype: List[str] + """ + if not isinstance(data, pd.DataFrame): + return [] + + list_of_feature_having_missing_values = [] + for feature in data.columns.tolist(): + if np.any(data[feature].isnull()): + list_of_feature_having_missing_values.append(feature) + return list_of_feature_having_missing_values + def _validate_counterfactual_configuration( self, query_instances, total_CFs, desired_class="opposite", desired_range=None, @@ -54,6 +73,12 @@ def _validate_counterfactual_configuration( stopping_threshold=0.5, posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm="linear", verbose=False, **kwargs): + if len(self._find_features_having_missing_values(query_instances)) > 0: + raise UserConfigValidationException( + "The query instance(s) should not have any missing values. " + "Please impute the missing values and try again." + ) + if total_CFs <= 0: raise UserConfigValidationException( "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.") diff --git a/tests/test_dice_interface/test_explainer_base.py b/tests/test_dice_interface/test_explainer_base.py index f898ddde..4325bd16 100644 --- a/tests/test_dice_interface/test_explainer_base.py +++ b/tests/test_dice_interface/test_explainer_base.py @@ -1,3 +1,6 @@ +import re + +import numpy as np import pandas as pd import pytest from rai_test_utils.datasets.tabular import create_housing_data @@ -501,6 +504,27 @@ def test_generate_counterfactuals_user_config_validations( method=method) explainer_function = getattr(exp, explainer_function) + + regex_pattern = re.escape( + 'The query instance(s) should not have any missing values. ' + 'Please impute the missing values and try again.') + + query_instances_missing_values_numerical = pd.DataFrame({'Categorical': ['a'], 'Numerical': [np.nan]}) + with pytest.raises( + UserConfigValidationException, + match=regex_pattern): + explainer_function( + query_instances=query_instances_missing_values_numerical, desired_class='opposite', + total_CFs=10) + + query_instances_missing_values_categorical = pd.DataFrame({'Categorical': [np.nan], 'Numerical': [1]}) + with pytest.raises( + UserConfigValidationException, + match=regex_pattern): + explainer_function( + query_instances=query_instances_missing_values_categorical, desired_class='opposite', + total_CFs=10) + with pytest.raises( UserConfigValidationException, match=r"The number of counterfactuals generated per query instance \(total_CFs\) "