diff --git a/econml/dml.py b/econml/dml.py index e5153e6fa..05672355c 100644 --- a/econml/dml.py +++ b/econml/dml.py @@ -633,7 +633,7 @@ def __init__(self, n_splits=n_splits, random_state=random_state) - def fit(self, Y, T, X=None, W=None, sample_weight=None, inference=None): + def fit(self, Y, T, X=None, W=None, sample_weight=None, sample_var=None, inference=None): """ Estimate the counterfactual model from data, i.e. estimates functions τ(·,·,·), ∂τ(·,·). @@ -649,6 +649,9 @@ def fit(self, Y, T, X=None, W=None, sample_weight=None, inference=None): Controls for each sample sample_weight: optional (n,) vector Weights for each row + sample_var: optional (n, n_y) vector + Variance of sample, in case it corresponds to summary of many samples. Currently + not in use by this method but will be supported in a future release. inference: string, `Inference` instance, or None Method for performing inference. This estimator supports 'bootstrap' (or an instance of :class:`.BootstrapInference`) and 'debiasedlasso' @@ -659,7 +662,7 @@ def fit(self, Y, T, X=None, W=None, sample_weight=None, inference=None): self """ # TODO: support sample_var - if sample_weight is not None and inference is not None: + if sample_var is not None and inference is not None: warn("This estimator does not yet support sample variances and inference does not take " "sample variances into account. This feature will be supported in a future release.") check_high_dimensional(X, T, threshold=5, featurizer=self.featurizer, diff --git a/econml/tests/test_dml.py b/econml/tests/test_dml.py index 1420d5f84..4bc9ad53b 100644 --- a/econml/tests/test_dml.py +++ b/econml/tests/test_dml.py @@ -117,7 +117,8 @@ def make_random(is_discrete, d): fit_cate_intercept=fit_cate_intercept, discrete_treatment=is_discrete), True, - [None, 'debiasedlasso']), + [None, 'debiasedlasso'] + + ([BootstrapInference(n_bootstrap_samples=20)] if not is_discrete else [])), (KernelDMLCateEstimator(model_y=WeightedLasso(), model_t=model_t, fit_cate_intercept=fit_cate_intercept,