Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validation period #47

Merged
merged 5 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ You can estimate the causal effect of the intervention using the `estimate_effec

```python
>>> sc.estimate_effects()
0.8398940970771678
[0.8398940970771678]
```

### Placebo Test
Expand Down
5 changes: 3 additions & 2 deletions examples/0 - sample data.ipynb

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions examples/1 - synthetic control.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "synthx"
version = "0.4.3"
version = "1.0.0"
description = "A Python Library for Advanced Synthetic Control Analysis"
authors = ["kenki931128 <kenki.nkmr@gmail.com>"]
license = "MIT License"
Expand Down
20 changes: 19 additions & 1 deletion synthx/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
covariate_columns: Optional[list[str]],
intervention_units: Union[Any, list[Any]],
intervention_time: Union[int, date],
validation_time: Optional[Union[int, date]] = None,
norm: Optional[str] = None,
) -> None:
"""Initialize the Dataset.
Expand All @@ -42,6 +43,7 @@ def __init__(
covariate_columns (Optional[list[str]]): The columns representing the covariates.
intervention_units (Union[Any, list[Any]]): A list of intervented units
intervention_time (Union[int, date]): When the intervention or event happens.
validation_time (Optional[Union[int, date]]): validation time if needed.
norm (Optional[str]): If not None, will normalize the y_column.
it should be 'z_standardize'
"""
Expand All @@ -54,6 +56,7 @@ def __init__(
intervention_units if isinstance(intervention_units, list) else [intervention_units]
)
self.intervention_time = intervention_time
self.validation_time = validation_time
self.norm = norm
self.__validate()
self.__normalization()
Expand All @@ -71,7 +74,12 @@ def plot(self, units: Optional[list[str]] = None, save: Optional[str] = None) ->
for unit in units:
unit_data = self.data.filter(pl.col(self.unit_column) == unit)
plt.plot(unit_data[self.time_column], unit_data[self.y_column], label=f'Unit {unit}')
# Add vertical line for intervention time

# Add vertical line for validation time and intervention time
if self.validation_time is not None:
plt.axvline(
self.validation_time, color='orange', linestyle='--', label='Validation Time' # type: ignore
)
plt.axvline(
self.intervention_time, color='red', linestyle='--', label='Intervention Time' # type: ignore
)
Expand Down Expand Up @@ -117,6 +125,16 @@ def __validate(self) -> None:
)
if self.intervention_time > self.data[self.time_column].max(): # type: ignore
raise InvalidInterventionTimeError(f'no date point at {self.intervention_time} time.')
if self.validation_time is not None and type(self.intervention_time) != type(
self.validation_time
):
raise InvalidColumnTypeError(
f'intervention_time and validation_time should have the same type.'
)
if self.validation_time is not None and self.intervention_time <= self.validation_time: # type: ignore
raise InvalidInterventionTimeError(
f'intervention_time should be later than validation_time.'
)

if self.y_column not in self.data.columns:
raise ColumnNotFoundError(self.y_column)
Expand Down
16 changes: 14 additions & 2 deletions synthx/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,14 @@ def estimate_effects(self) -> list[float]:
Returns:
list[float]: The estimated effect of the intervention.
"""
# dataset before intervention
# dataset in the training period
training_time = (
self.dataset.validation_time
if self.dataset.validation_time is not None
else self.dataset.intervention_time
)
pre_df = self.dataset.data.filter(
self.dataset.data[self.dataset.time_column] < self.dataset.intervention_time
self.dataset.data[self.dataset.time_column] < training_time
)
pre_result = SyntheticControlResult(
dataset=sx.Dataset(
Expand Down Expand Up @@ -159,6 +164,13 @@ def plot(self, save: Optional[str] = None) -> None:
for i, intervention_unit in enumerate(self.dataset.intervention_units):
axs[i].plot(self.x_time, self.y_test(intervention_unit), label='Test')
axs[i].plot(self.x_time, self.y_control(intervention_unit), label='Control')
if self.dataset.validation_time is not None:
axs[i].axvline(
self.dataset.validation_time,
color='orange',
linestyle='--',
label='Validation Time',
)
axs[i].axvline(
self.dataset.intervention_time,
color='red',
Expand Down
14 changes: 10 additions & 4 deletions synthx/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@ def synthetic_control(dataset: sx.Dataset) -> sx.SyntheticControlResult:
df = dataset.data

# condition
# TODO: add validation period
condition_pre_intervention_time = df[dataset.time_column] < dataset.intervention_time
training_time = (
dataset.validation_time
if dataset.validation_time is not None
else dataset.intervention_time
)
condition_training_time = df[dataset.time_column] < training_time
condition_control_units = ~df[dataset.unit_column].is_in(dataset.intervention_units)

# weights for variables and scale each variables
Expand All @@ -64,12 +68,12 @@ def synthetic_control(dataset: sx.Dataset) -> sx.SyntheticControlResult:
)

# dataframe for control
df_control = df.filter(condition_pre_intervention_time & condition_control_units)
df_control = df.filter(condition_training_time & condition_control_units)

control_unit_weights: list[np.ndarray] = []
for intervention_unit in dataset.intervention_units:
condition_test_unit = df[dataset.unit_column] == intervention_unit
df_test = df.filter(condition_pre_intervention_time & condition_test_unit)
df_test = df.filter(condition_training_time & condition_test_unit)

# optimize unit weights
def objective(unit_weights: np.ndarray) -> float:
Expand Down Expand Up @@ -182,6 +186,7 @@ def process_placebo(
covariate_columns=dataset.covariate_columns,
intervention_units=test_unit_placebo,
intervention_time=dataset.intervention_time,
validation_time=dataset.validation_time,
)
try:
sc_placebo = synthetic_control(dataset_placebo)
Expand Down Expand Up @@ -246,6 +251,7 @@ def sensitivity_check(
covariate_columns=dataset.covariate_columns,
intervention_units=dataset.intervention_units,
intervention_time=dataset.intervention_time,
validation_time=dataset.validation_time,
)

try:
Expand Down
28 changes: 28 additions & 0 deletions tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def test_init_single_intervention_unit(self, sample_data: pl.DataFrame) -> None:
)
assert dataset.intervention_units == [1]

def test_init_with_validation_time(self, sample_data: pl.DataFrame) -> None:
dataset = sx.Dataset(
data=sample_data,
unit_column='unit',
time_column='time',
y_column='y',
covariate_columns=['cov1', 'cov2'],
intervention_units=1,
intervention_time=3,
validation_time=2,
)
assert dataset.validation_time == 2

def test_plot(self, sample_data: pl.DataFrame, mocker: MockerFixture) -> None:
dataset = sx.Dataset(
data=sample_data,
Expand Down Expand Up @@ -160,6 +173,21 @@ def test_validate_invalid_intervention_time(self, sample_data: pl.DataFrame) ->
intervention_time=4,
)

def test_validate_invalid_intervention_time_vs_validation_time(
self, sample_data: pl.DataFrame
) -> None:
with pytest.raises(InvalidInterventionTimeError):
sx.Dataset(
data=sample_data,
unit_column='unit',
time_column='time',
y_column='y',
covariate_columns=['cov1', 'cov2'],
intervention_units=[1],
intervention_time=4,
validation_time=5,
)

def test_validate_missing_y_column(self, sample_data: pl.DataFrame) -> None:
with pytest.raises(ColumnNotFoundError):
sx.Dataset(
Expand Down