From b52b2b71aba48a9e28a0c8cc10924edf537bb1d9 Mon Sep 17 00:00:00 2001 From: jroessler Date: Mon, 30 Aug 2021 11:01:38 +0200 Subject: [PATCH] Fixed bug in simulate_randomized_trial; added randomized trial in tests --- causalml/dataset/regression.py | 2 +- tests/test_datasets.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/causalml/dataset/regression.py b/causalml/dataset/regression.py index 9b347046..b574008c 100644 --- a/causalml/dataset/regression.py +++ b/causalml/dataset/regression.py @@ -100,7 +100,7 @@ def simulate_randomized_trial(n=1000, p=5, sigma=1.0, adj=0.): ''' X = np.random.normal(size=n*p).reshape((n, -1)) - b = np.maximum(np.repeat(0.0, n), X[:, 0] + X[:, 1] + X[:, 2]) + np.maximum(np.repeat(0.0, n), X[:, 3] + X[:, 4]) + b = np.maximum(np.repeat(0.0, n), X[:, 0] + X[:, 1], X[:, 2]) + np.maximum(np.repeat(0.0, n), X[:, 3] + X[:, 4]) e = np.repeat(0.5, n) tau = X[:, 0] + np.log1p(np.exp(X[:, 1])) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index eb477d92..16195e9d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,13 +1,14 @@ import pytest -from causalml.dataset import simulate_nuisance_and_easy_treatment, simulate_hidden_confounder +from causalml.dataset import simulate_nuisance_and_easy_treatment, simulate_hidden_confounder, simulate_randomized_trial from causalml.dataset import get_synthetic_preds, get_synthetic_summary, get_synthetic_auuc from causalml.dataset import get_synthetic_preds_holdout, get_synthetic_summary_holdout from causalml.inference.meta import LRSRegressor, XGBTRegressor @pytest.mark.parametrize('synthetic_data_func', [simulate_nuisance_and_easy_treatment, - simulate_hidden_confounder]) + simulate_hidden_confounder, + simulate_randomized_trial]) def test_get_synthetic_preds(synthetic_data_func): preds_dict = get_synthetic_preds(synthetic_data_func=synthetic_data_func, n=1000,