Skip to content

Commit

Permalink
Fixed inference results when cate_feature_names is not defined
Browse files Browse the repository at this point in the history
  • Loading branch information
Miruna Oprescu committed Feb 24, 2020
1 parent c106485 commit a7b2688
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions econml/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def coef__inference(self):
if coef.size == 0: # X is None
raise AttributeError("X is None, please call intercept_inference to learn the constant!")

if callable(self._est.cate_feature_names):
if hasattr(self._est, 'cate_feature_names') and callable(self._est.cate_feature_names):
def fname_transformer(x):
return self._est.cate_feature_names(x)
else:
Expand Down Expand Up @@ -356,13 +356,13 @@ def fit(self, estimator, *args, **kwargs):

def const_marginal_effect_interval(self, X, *, alpha=0.1):
if (X is not None) and (self.featurizer is not None):
X = self.featurizer.transform(X)
X = self.featurizer.fit_transform(X)
preds = np.array([mdl.predict_interval(X, alpha=alpha) for mdl in self.fitted_models_final])
return tuple(np.moveaxis(preds, [0, 1], [-1, 0])) # send treatment to the end, pull bounds to the front

def const_marginal_effect_inference(self, X):
if (X is not None) and (self.featurizer is not None):
X = self.featurizer.transform(X)
X = self.featurizer.fit_transform(X)
pred = np.array([mdl.predict(X) for mdl in self.fitted_models_final])
if not hasattr(self.fitted_models_final[0], 'prediction_stderr'):
raise AttributeError("Final model doesn't support prediction standard eror, "
Expand Down Expand Up @@ -426,7 +426,7 @@ def coef__inference(self, T):
coef_stderr = self.fitted_models_final[ind].coef_stderr_
if coef.size == 0: # X is None
raise AttributeError("X is None, please call intercept_inference to learn the constant!")
if callable(self._est.cate_feature_names):
if hasattr(self._est, 'cate_feature_names') and callable(self._est.cate_feature_names):
def fname_transformer(x):
return self._est.cate_feature_names(x)
else:
Expand Down Expand Up @@ -692,7 +692,7 @@ def summary_frame(self, alpha=0.1, value=0, decimals=3, feat_name=None):
if self.d_y == 1:
res.index = res.index.droplevel(1)
if self.inf_type == 'coefficient':
if feat_name and self.fname_transformer:
if feat_name is not None and self.fname_transformer:
ind = self.fname_transformer(feat_name)
else:
ct = res.shape[0] // self.d_y
Expand Down

0 comments on commit a7b2688

Please sign in to comment.