diff --git a/econml/cate_interpreter.py b/econml/cate_interpreter.py index e22e112a4..9ae497a81 100644 --- a/econml/cate_interpreter.py +++ b/econml/cate_interpreter.py @@ -555,6 +555,7 @@ def interpret(self, cate_estimator, X, sample_treatment_costs=None, treatment_na else: _, y_pred = cate_estimator.const_marginal_effect_interval(X, alpha=self.risk_level) + # TODO: generalize to multiple treatment case? assert all(d == 1 for d in y_pred.shape[1:]), ("Interpretation is only available for " "single-dimensional treatments and outcomes") @@ -564,16 +565,22 @@ def interpret(self, cate_estimator, X, sample_treatment_costs=None, treatment_na assert np.ndim(sample_treatment_costs) < 2, "Sample treatment costs should be a vector or scalar" y_pred -= sample_treatment_costs - if np.all(y_pred > 0): - raise AttributeError("All samples should be treated with the given treatment costs. " + - "Consider increasing the cost!") - - if np.all(y_pred < 0): - raise AttributeError("All samples should not be treated with the given treatment costs. " + - "Consider decreasing the cost!") - - self.tree_model.fit(X, np.sign(y_pred).flatten(), sample_weight=np.abs(y_pred)) - self.policy_value = np.mean(y_pred * (self.tree_model.predict(X) == 1)) + # get index of best treatment + all_y = np.hstack([np.zeros((y_pred.shape[0], 1)), y_pred.reshape(-1, 1)]) + best_y = np.argmax(all_y, axis=-1) + + used_t = np.unique(best_y) + if len(used_t) == 1: + best_y, = used_t + if best_y > 0: + raise AttributeError("All samples should be treated with the given treatment costs. " + + "Consider increasing the cost!") + else: + raise AttributeError("No samples should be treated with the given treatment costs. " + + "Consider decreasing the cost!") + + self.tree_model.fit(X, best_y, sample_weight=np.abs(y_pred)) + self.policy_value = np.mean(all_y[:, self.tree_model.predict(X)]) self.always_treat_value = np.mean(y_pred) self.treatment_names = treatment_names return self diff --git a/econml/tests/test_cate_interpreter.py b/econml/tests/test_cate_interpreter.py index ff9c218c0..2c39044f0 100644 --- a/econml/tests/test_cate_interpreter.py +++ b/econml/tests/test_cate_interpreter.py @@ -187,4 +187,4 @@ def test_random_cate_settings(self): intrp.render('outfile', **render_kwargs) intrp.export_graphviz(**export_kwargs) except AttributeError as e: - assert str(e).find("All samples should") >= 0 + assert str(e).find("samples should") >= 0