From 8f17cfd2d374635543a5dda141ee18fa83833b74 Mon Sep 17 00:00:00 2001 From: An <19732879+azz147@users.noreply.github.com> Date: Wed, 16 Nov 2022 02:32:17 -0500 Subject: [PATCH] fix int(final_cfs_sparse.at[cf_ix, feature]) in do_posthoc_sparsity_enhancement, do_linear_search and do_binary_search (#343) Signed-off-by: An <19732879+azz147@users.noreply.github.com> Signed-off-by: An <19732879+azz147@users.noreply.github.com> Co-authored-by: An <> --- dice_ml/explainer_interfaces/explainer_base.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 50ae8e3f..eab7e875 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -532,7 +532,7 @@ def do_posthoc_sparsity_enhancement(self, final_cfs_sparse, query_instance, post for feature in features_sorted: # current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.iat[[cf_ix]][self.data_interface.feature_names]) # feat_ix = self.data_interface.continuous_feature_names.index(feature) - diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) + diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature] if(abs(diff) <= quantiles[feature]): if posthoc_sparsity_algorithm == "linear": final_cfs_sparse = self.do_linear_search(diff, decimal_prec, query_instance, cf_ix, @@ -561,17 +561,16 @@ def do_linear_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f while((abs(diff) > 10e-4) and (np.sign(diff*old_diff) > 0) and self.is_cf_valid(current_pred)) and (count_steps < limit_steps_ls): - old_val = int(final_cfs_sparse.at[cf_ix, feature]) + old_val = final_cfs_sparse.at[cf_ix, feature] final_cfs_sparse.at[cf_ix, feature] += np.sign(diff)*change current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) old_diff = diff if not self.is_cf_valid(current_pred): final_cfs_sparse.at[cf_ix, feature] = old_val - diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) return final_cfs_sparse - diff = query_instance[feature].iat[0] - int(final_cfs_sparse.at[cf_ix, feature]) + diff = query_instance[feature].iat[0] - final_cfs_sparse.at[cf_ix, feature] count_steps += 1 @@ -581,7 +580,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f """Performs a binary search between continuous features of a CF and corresponding values in query_instance until the prediction class changes.""" - old_val = int(final_cfs_sparse.at[cf_ix, feature]) + old_val = final_cfs_sparse.at[cf_ix, feature] final_cfs_sparse.at[cf_ix, feature] = query_instance[feature].iat[0] # Prediction of the query instance current_pred = self.predict_fn_for_sparsity(final_cfs_sparse.loc[[cf_ix]][self.data_interface.feature_names]) @@ -594,7 +593,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f # move the CF values towards the query_instance if diff > 0: - left = int(final_cfs_sparse.at[cf_ix, feature]) + left = final_cfs_sparse.at[cf_ix, feature] right = query_instance[feature].iat[0] while left <= right: @@ -614,7 +613,7 @@ def do_binary_search(self, diff, decimal_prec, query_instance, cf_ix, feature, f else: left = query_instance[feature].iat[0] - right = int(final_cfs_sparse.at[cf_ix, feature]) + right = final_cfs_sparse.at[cf_ix, feature] while right >= left: current_val = right - ((right - left)/2)