Skip to content

Commit

Permalink
Merge pull request #385 from jroessler/jroessler/fix_synthetic_data_2
Browse files Browse the repository at this point in the history
Fixed bug in simulate_randomized_trial
  • Loading branch information
paullo0106 authored Sep 1, 2021
2 parents 0be443a + b52b2b7 commit 27a55f5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion causalml/dataset/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def simulate_randomized_trial(n=1000, p=5, sigma=1.0, adj=0.):
'''

X = np.random.normal(size=n*p).reshape((n, -1))
b = np.maximum(np.repeat(0.0, n), X[:, 0] + X[:, 1] + X[:, 2]) + np.maximum(np.repeat(0.0, n), X[:, 3] + X[:, 4])
b = np.maximum(np.repeat(0.0, n), X[:, 0] + X[:, 1], X[:, 2]) + np.maximum(np.repeat(0.0, n), X[:, 3] + X[:, 4])
e = np.repeat(0.5, n)
tau = X[:, 0] + np.log1p(np.exp(X[:, 1]))

Expand Down
5 changes: 3 additions & 2 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import pytest

from causalml.dataset import simulate_nuisance_and_easy_treatment, simulate_hidden_confounder
from causalml.dataset import simulate_nuisance_and_easy_treatment, simulate_hidden_confounder, simulate_randomized_trial
from causalml.dataset import get_synthetic_preds, get_synthetic_summary, get_synthetic_auuc
from causalml.dataset import get_synthetic_preds_holdout, get_synthetic_summary_holdout
from causalml.inference.meta import LRSRegressor, XGBTRegressor


@pytest.mark.parametrize('synthetic_data_func', [simulate_nuisance_and_easy_treatment,
simulate_hidden_confounder])
simulate_hidden_confounder,
simulate_randomized_trial])
def test_get_synthetic_preds(synthetic_data_func):
preds_dict = get_synthetic_preds(synthetic_data_func=synthetic_data_func,
n=1000,
Expand Down

0 comments on commit 27a55f5

Please sign in to comment.