Skip to content

Commit

Permalink
Make CATE policy interpreter use treatment index
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Mar 4, 2020
1 parent 967dcf3 commit a701c77
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
27 changes: 17 additions & 10 deletions econml/cate_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ def interpret(self, cate_estimator, X, sample_treatment_costs=None, treatment_na
else:
_, y_pred = cate_estimator.const_marginal_effect_interval(X, alpha=self.risk_level)

# TODO: generalize to multiple treatment case?
assert all(d == 1 for d in y_pred.shape[1:]), ("Interpretation is only available for "
"single-dimensional treatments and outcomes")

Expand All @@ -564,16 +565,22 @@ def interpret(self, cate_estimator, X, sample_treatment_costs=None, treatment_na
assert np.ndim(sample_treatment_costs) < 2, "Sample treatment costs should be a vector or scalar"
y_pred -= sample_treatment_costs

if np.all(y_pred > 0):
raise AttributeError("All samples should be treated with the given treatment costs. " +
"Consider increasing the cost!")

if np.all(y_pred < 0):
raise AttributeError("All samples should not be treated with the given treatment costs. " +
"Consider decreasing the cost!")

self.tree_model.fit(X, np.sign(y_pred).flatten(), sample_weight=np.abs(y_pred))
self.policy_value = np.mean(y_pred * (self.tree_model.predict(X) == 1))
# get index of best treatment
all_y = np.hstack([np.zeros((y_pred.shape[0], 1)), y_pred.reshape(-1, 1)])
best_y = np.argmax(all_y, axis=-1)

used_t = np.unique(best_y)
if len(used_t) == 1:
best_y, = used_t
if best_y > 0:
raise AttributeError("All samples should be treated with the given treatment costs. " +
"Consider increasing the cost!")
else:
raise AttributeError("No samples should be treated with the given treatment costs. " +
"Consider decreasing the cost!")

self.tree_model.fit(X, best_y, sample_weight=np.abs(y_pred))
self.policy_value = np.mean(all_y[:, self.tree_model.predict(X)])
self.always_treat_value = np.mean(y_pred)
self.treatment_names = treatment_names
return self
Expand Down
2 changes: 1 addition & 1 deletion econml/tests/test_cate_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,4 @@ def test_random_cate_settings(self):
intrp.render('outfile', **render_kwargs)
intrp.export_graphviz(**export_kwargs)
except AttributeError as e:
assert str(e).find("All samples should") >= 0
assert str(e).find("samples should") >= 0

0 comments on commit a701c77

Please sign in to comment.