Skip to content

Commit

Permalink
Merge pull request #47 from kenki931128/feature/validation-period
Browse files Browse the repository at this point in the history
Validation period
  • Loading branch information
kenki931128 authored Apr 23, 2024
2 parents 77fbdeb + 5835ddb commit cda5475
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ You can estimate the causal effect of the intervention using the `estimate_effec

```python
>>> sc.estimate_effects()
0.8398940970771678
[0.8398940970771678]
```

### Placebo Test
Expand Down
5 changes: 3 additions & 2 deletions examples/0 - sample data.ipynb

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions examples/1 - synthetic control.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "synthx"
version = "0.4.3"
version = "1.0.0"
description = "A Python Library for Advanced Synthetic Control Analysis"
authors = ["kenki931128 <kenki.nkmr@gmail.com>"]
license = "MIT License"
Expand Down
20 changes: 19 additions & 1 deletion synthx/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
covariate_columns: Optional[list[str]],
intervention_units: Union[Any, list[Any]],
intervention_time: Union[int, date],
validation_time: Optional[Union[int, date]] = None,
norm: Optional[str] = None,
) -> None:
"""Initialize the Dataset.
Expand All @@ -42,6 +43,7 @@ def __init__(
covariate_columns (Optional[list[str]]): The columns representing the covariates.
intervention_units (Union[Any, list[Any]]): A unit or a list of units that received the intervention.
intervention_time (Union[int, date]): When the intervention or event happens.
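validation_time (Optional[Union[int, date]]): Start of the validation period. If set, it must be earlier than intervention_time and only data before it is used for training; if None, all data before intervention_time is used.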
norm (Optional[str]): If not None, will normalize the y_column.
it should be 'z_standardize'
"""
Expand All @@ -54,6 +56,7 @@ def __init__(
intervention_units if isinstance(intervention_units, list) else [intervention_units]
)
self.intervention_time = intervention_time
self.validation_time = validation_time
self.norm = norm
self.__validate()
self.__normalization()
Expand All @@ -71,7 +74,12 @@ def plot(self, units: Optional[list[str]] = None, save: Optional[str] = None) ->
for unit in units:
unit_data = self.data.filter(pl.col(self.unit_column) == unit)
plt.plot(unit_data[self.time_column], unit_data[self.y_column], label=f'Unit {unit}')
# Add vertical line for intervention time

# Add vertical line for validation time and intervention time
if self.validation_time is not None:
plt.axvline(
self.validation_time, color='orange', linestyle='--', label='Validation Time' # type: ignore
)
plt.axvline(
self.intervention_time, color='red', linestyle='--', label='Intervention Time' # type: ignore
)
Expand Down Expand Up @@ -117,6 +125,16 @@ def __validate(self) -> None:
)
if self.intervention_time > self.data[self.time_column].max(): # type: ignore
raise InvalidInterventionTimeError(f'no date point at {self.intervention_time} time.')
if self.validation_time is not None and type(self.intervention_time) != type(
self.validation_time
):
raise InvalidColumnTypeError(
f'intervention_time and validation_time should have the same type.'
)
if self.validation_time is not None and self.intervention_time <= self.validation_time: # type: ignore
raise InvalidInterventionTimeError(
f'intervention_time should be later than validation_time.'
)

if self.y_column not in self.data.columns:
raise ColumnNotFoundError(self.y_column)
Expand Down
16 changes: 14 additions & 2 deletions synthx/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,14 @@ def estimate_effects(self) -> list[float]:
Returns:
list[float]: The estimated effect of the intervention.
"""
# dataset before intervention
# dataset in the training period
training_time = (
self.dataset.validation_time
if self.dataset.validation_time is not None
else self.dataset.intervention_time
)
pre_df = self.dataset.data.filter(
self.dataset.data[self.dataset.time_column] < self.dataset.intervention_time
self.dataset.data[self.dataset.time_column] < training_time
)
pre_result = SyntheticControlResult(
dataset=sx.Dataset(
Expand Down Expand Up @@ -159,6 +164,13 @@ def plot(self, save: Optional[str] = None) -> None:
for i, intervention_unit in enumerate(self.dataset.intervention_units):
axs[i].plot(self.x_time, self.y_test(intervention_unit), label='Test')
axs[i].plot(self.x_time, self.y_control(intervention_unit), label='Control')
if self.dataset.validation_time is not None:
axs[i].axvline(
self.dataset.validation_time,
color='orange',
linestyle='--',
label='Validation Time',
)
axs[i].axvline(
self.dataset.intervention_time,
color='red',
Expand Down
14 changes: 10 additions & 4 deletions synthx/method.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,12 @@ def synthetic_control(dataset: sx.Dataset) -> sx.SyntheticControlResult:
df = dataset.data

# condition
# TODO: add validation period
condition_pre_intervention_time = df[dataset.time_column] < dataset.intervention_time
training_time = (
dataset.validation_time
if dataset.validation_time is not None
else dataset.intervention_time
)
condition_training_time = df[dataset.time_column] < training_time
condition_control_units = ~df[dataset.unit_column].is_in(dataset.intervention_units)

# weights for variables and scale each variables
Expand All @@ -64,12 +68,12 @@ def synthetic_control(dataset: sx.Dataset) -> sx.SyntheticControlResult:
)

# dataframe for control
df_control = df.filter(condition_pre_intervention_time & condition_control_units)
df_control = df.filter(condition_training_time & condition_control_units)

control_unit_weights: list[np.ndarray] = []
for intervention_unit in dataset.intervention_units:
condition_test_unit = df[dataset.unit_column] == intervention_unit
df_test = df.filter(condition_pre_intervention_time & condition_test_unit)
df_test = df.filter(condition_training_time & condition_test_unit)

# optimize unit weights
def objective(unit_weights: np.ndarray) -> float:
Expand Down Expand Up @@ -182,6 +186,7 @@ def process_placebo(
covariate_columns=dataset.covariate_columns,
intervention_units=test_unit_placebo,
intervention_time=dataset.intervention_time,
validation_time=dataset.validation_time,
)
try:
sc_placebo = synthetic_control(dataset_placebo)
Expand Down Expand Up @@ -246,6 +251,7 @@ def sensitivity_check(
covariate_columns=dataset.covariate_columns,
intervention_units=dataset.intervention_units,
intervention_time=dataset.intervention_time,
validation_time=dataset.validation_time,
)

try:
Expand Down
28 changes: 28 additions & 0 deletions tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@ def test_init_single_intervention_unit(self, sample_data: pl.DataFrame) -> None:
)
assert dataset.intervention_units == [1]

def test_init_with_validation_time(self, sample_data: pl.DataFrame) -> None:
    """A validation_time earlier than intervention_time is stored on the dataset."""
    # The validation period must start strictly before the intervention.
    validation_time = 2
    intervention_time = 3

    dataset = sx.Dataset(
        data=sample_data,
        unit_column='unit',
        time_column='time',
        y_column='y',
        covariate_columns=['cov1', 'cov2'],
        intervention_units=1,
        intervention_time=intervention_time,
        validation_time=validation_time,
    )

    assert dataset.validation_time == validation_time

def test_plot(self, sample_data: pl.DataFrame, mocker: MockerFixture) -> None:
dataset = sx.Dataset(
data=sample_data,
Expand Down Expand Up @@ -160,6 +173,21 @@ def test_validate_invalid_intervention_time(self, sample_data: pl.DataFrame) ->
intervention_time=4,
)

def test_validate_invalid_intervention_time_vs_validation_time(
    self, sample_data: pl.DataFrame
) -> None:
    """Reject a validation_time that is not strictly earlier than intervention_time."""
    # Use an intervention_time that is valid on its own (3 is accepted by the
    # successful-construction tests in this file), so the only reason to fail
    # is the ordering check against validation_time. A larger value such as 4
    # would already be rejected by the "beyond the last time point" check,
    # which raises the same exception type and would mask a broken ordering
    # check.
    # NOTE(review): assumes the latest time in sample_data is 3 - confirm
    # against the fixture.
    with pytest.raises(InvalidInterventionTimeError, match='later than validation_time'):
        sx.Dataset(
            data=sample_data,
            unit_column='unit',
            time_column='time',
            y_column='y',
            covariate_columns=['cov1', 'cov2'],
            intervention_units=[1],
            intervention_time=3,
            validation_time=4,
        )

def test_validate_missing_y_column(self, sample_data: pl.DataFrame) -> None:
with pytest.raises(ColumnNotFoundError):
sx.Dataset(
Expand Down

0 comments on commit cda5475

Please sign in to comment.