Skip to content

Commit

Permalink
Raise user error when any of the query instances has missing values (#…
Browse files Browse the repository at this point in the history
…403)

Looks like missing values in query instances cause weird failures when
attempting to do predict()/predict_proba(). Hence, asking the user to
impute the values for missing values.
  • Loading branch information
gaugup authored Sep 18, 2023
2 parents e5d2e27 + 623ac9c commit 8277afe
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
25 changes: 25 additions & 0 deletions dice_ml/explainer_interfaces/explainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pickle
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -47,13 +48,37 @@ def __init__(self, data_interface, model_interface=None):
# self.cont_precisions = \
# [self.data_interface.get_decimal_precisions()[ix] for ix in self.encoded_continuous_feature_indexes]

def _find_features_having_missing_values(
self, data: Any) -> List[str]:
"""Return list of features which have missing values.
:param data: The dataset to check.
:type data: Any
:return: List of feature names which have missing values.
:rtype: List[str]
"""
if not isinstance(data, pd.DataFrame):
return []

list_of_feature_having_missing_values = []
for feature in data.columns.tolist():
if np.any(data[feature].isnull()):
list_of_feature_having_missing_values.append(feature)
return list_of_feature_having_missing_values

def _validate_counterfactual_configuration(
self, query_instances, total_CFs,
desired_class="opposite", desired_range=None,
permitted_range=None, features_to_vary="all",
stopping_threshold=0.5, posthoc_sparsity_param=0.1,
posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):

if len(self._find_features_having_missing_values(query_instances)) > 0:
raise UserConfigValidationException(
"The query instance(s) should not have any missing values. "
"Please impute the missing values and try again."
)

if total_CFs <= 0:
raise UserConfigValidationException(
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")
Expand Down
24 changes: 24 additions & 0 deletions tests/test_dice_interface/test_explainer_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import re

import numpy as np
import pandas as pd
import pytest
from rai_test_utils.datasets.tabular import create_housing_data
Expand Down Expand Up @@ -501,6 +504,27 @@ def test_generate_counterfactuals_user_config_validations(
method=method)

explainer_function = getattr(exp, explainer_function)

regex_pattern = re.escape(
'The query instance(s) should not have any missing values. '
'Please impute the missing values and try again.')

query_instances_missing_values_numerical = pd.DataFrame({'Categorical': ['a'], 'Numerical': [np.nan]})
with pytest.raises(
UserConfigValidationException,
match=regex_pattern):
explainer_function(
query_instances=query_instances_missing_values_numerical, desired_class='opposite',
total_CFs=10)

query_instances_missing_values_categorical = pd.DataFrame({'Categorical': [np.nan], 'Numerical': [1]})
with pytest.raises(
UserConfigValidationException,
match=regex_pattern):
explainer_function(
query_instances=query_instances_missing_values_categorical, desired_class='opposite',
total_CFs=10)

with pytest.raises(
UserConfigValidationException,
match=r"The number of counterfactuals generated per query instance \(total_CFs\) "
Expand Down

0 comments on commit 8277afe

Please sign in to comment.