Skip to content

Commit

Permalink
Merge pull request #60 from kenki931128/develop
Browse files Browse the repository at this point in the history
develop into main v1.1.0
  • Loading branch information
kenki931128 authored Apr 30, 2024
2 parents 69e092b + 9cddc35 commit 162e387
Show file tree
Hide file tree
Showing 8 changed files with 406 additions and 68 deletions.
126 changes: 126 additions & 0 deletions examples/0 - sample data.ipynb

Large diffs are not rendered by default.

115 changes: 48 additions & 67 deletions examples/1 - synthetic control.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "synthx"
version = "1.0.2"
version = "1.1.0"
description = "A Python Library for Advanced Synthetic Control Analysis"
authors = ["kenki931128 <kenki.nkmr@gmail.com>"]
license = "MIT License"
Expand Down
92 changes: 92 additions & 0 deletions synthx/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Class for dataset."""

import sys
from datetime import date
from typing import Any, Optional, Union

import matplotlib.pyplot as plt
import polars as pl
from tqdm import tqdm

from synthx.errors import (
ColumnNotFoundError,
Expand All @@ -13,6 +15,7 @@
InvalidInterventionTimeError,
InvalidInterventionUnitError,
InvalidNormalizationError,
StrictFilteringError,
)
from synthx.stats import norm

Expand Down Expand Up @@ -93,6 +96,95 @@ def plot(self, units: Optional[list[str]] = None, save: Optional[str] = None) ->
else:
plt.show()

def filtered_by_lift_and_moe(
    self,
    *,
    lift_threshold: Optional[float] = None,
    moe_threshold: Optional[float] = None,
    write_progress: bool = False,
) -> 'Dataset':
    """Filter the dataset based on lift and margin of error thresholds.

    Each unit is judged against its own history of the target column:

    - lift is ``y / mean(y)``
    - margin of error (moe) is ``(y - mean(y)) / std(y)``

    A unit is excluded when at least one of its observations violates the lift
    range; if the lift check passes (or is disabled) the moe range is checked.

    Args:
        lift_threshold (Optional[float]): The threshold for the lift. Units with lift values
            outside the range [1 / lift_threshold, lift_threshold] will be excluded.
            If None, no lift-based filtering is performed. Note that a value below 1
            inverts the range (1 / lift_threshold > lift_threshold), so no lift value
            can satisfy both bounds.
        moe_threshold (Optional[float]): The threshold for the margin of error. Units with
            margin of error values outside the range [-moe_threshold, moe_threshold]
            will be excluded. If None, no margin of error-based filtering is performed.
        write_progress (bool): Whether to write one line per unit (verdict plus the
            observed lift / moe ranges) to stderr. The tqdm progress bar over units
            is created regardless of this flag.

    Returns:
        Dataset: A new Dataset object containing the filtered data. Intervention units
            that were filtered out are removed from ``intervention_units``; all other
            settings are copied from this dataset.

    Raises:
        ValueError: If lift_threshold or moe_threshold is not positive.
        StrictFilteringError: If all units or all intervention units are filtered out.
            Consider loosening the thresholds in this case.
    """
    if lift_threshold is not None and lift_threshold <= 0:
        raise ValueError('lift_threshold should be positive.')
    if moe_threshold is not None and moe_threshold <= 0:
        raise ValueError('moe_threshold should be positive.')

    df = self.data
    units_excluded = []

    for unit in tqdm(df[self.unit_column].unique()):
        y = df.filter(pl.col(self.unit_column) == unit)[self.y_column]

        mean = y.mean()
        std = y.std()

        # NOTE(review): a zero mean, or a zero / undefined std, produces non-finite
        # lift / moe values; how those compare against the thresholds depends on
        # polars semantics -- confirm the intended treatment of such units.
        lift = y / mean
        moe = (y - mean) / std

        # None means the unit is kept; otherwise the reason it is dropped.
        reason: Optional[str] = None
        if lift_threshold is not None and (
            (lift < 1 / lift_threshold).any() or (lift_threshold < lift).any()
        ):
            reason = 'out of lift threshold'
        elif moe_threshold is not None and (
            (moe < -moe_threshold).any() or (moe_threshold < moe).any()
        ):
            reason = 'out of moe threshold'

        if reason is not None:
            units_excluded.append(unit)

        if write_progress:
            # Only pay for the min/max reductions when the line is actually written.
            str_range = f'lift: {lift.min():.3f} ~ {lift.max():.3f}, moe: {moe.min():.3f} ~ {moe.max():.3f}'  # type: ignore
            status = 'kept' if reason is None else reason
            tqdm.write(
                f'unit {unit} {status}. {str_range}',
                file=sys.stderr,
            )

    df = df.filter(~pl.col(self.unit_column).is_in(units_excluded))
    if len(df) == 0:
        raise StrictFilteringError('all units filtered out. Consider loosening thresholds.')

    # Set lookup keeps the membership test cheap when many units were excluded.
    excluded = set(units_excluded)
    intervention_units = [u for u in self.intervention_units if u not in excluded]
    if len(intervention_units) == 0:
        raise StrictFilteringError(
            'all intervention units filtered out. Consider loosening thresholds.'
        )

    return Dataset(
        df,
        unit_column=self.unit_column,
        time_column=self.time_column,
        y_column=self.y_column,
        covariate_columns=self.covariate_columns,
        intervention_units=intervention_units,
        intervention_time=self.intervention_time,
        validation_time=self.validation_time,
        norm=self.norm,
    )

def __validate(self) -> None:
"""Validate the dataset and raise appropriate errors if any issues are found.
Expand Down
51 changes: 51 additions & 0 deletions synthx/core/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,57 @@ def estimate_effects(self) -> list[float]:
for intervention_unit in self.dataset.intervention_units
]

def validation_differences(self) -> Optional[list[float]]:
    """Calculate the difference between training and validation.

    The data before ``intervention_time`` is split at ``validation_time`` into a
    training window (time < validation_time) and a validation window
    (validation_time <= time < intervention_time). For every intervention unit the
    mean of ``y_test - y_control`` is computed in each window using the current
    control unit weights, and the training mean is subtracted from the validation mean.

    Returns:
        Optional[list[float]]: One difference per intervention unit, in the order of
            ``dataset.intervention_units``; None when the dataset has no validation_time.
    """
    dataset = self.dataset
    if dataset.validation_time is None:
        return None

    timestamps = dataset.data[dataset.time_column]

    def _result_for(frame) -> 'SyntheticControlResult':
        """Wrap a time-sliced frame in a result that reuses the current weights."""
        return SyntheticControlResult(
            dataset=sx.Dataset(
                frame,
                unit_column=dataset.unit_column,
                time_column=dataset.time_column,
                y_column=dataset.y_column,
                covariate_columns=dataset.covariate_columns,
                intervention_units=dataset.intervention_units,
                # presumably a placeholder so the sliced dataset can be built -- TODO confirm
                intervention_time=0,
            ),
            control_unit_weights=self.control_unit_weights,
        )

    # training window: everything strictly before the validation time
    in_training = timestamps < dataset.validation_time
    pre_result = _result_for(dataset.data.filter(in_training))

    # validation window: from the validation time up to (not including) the intervention
    in_validation = (timestamps >= dataset.validation_time) & (
        timestamps < dataset.intervention_time
    )
    val_result = _result_for(dataset.data.filter(in_validation))

    differences = []
    for unit in dataset.intervention_units:
        val_gap = np.mean(val_result.y_test(unit) - val_result.y_control(unit))
        pre_gap = np.mean(pre_result.y_test(unit) - pre_result.y_control(unit))
        differences.append(val_gap - pre_gap)
    return differences

def plot(self, save: Optional[str] = None) -> None:
"""Plot the target variable over time for both test and control units.
Expand Down
4 changes: 4 additions & 0 deletions synthx/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ class InconsistentTimestampsError(Exception):

class InvalidNormalizationError(Exception):
    """Raised when the selected normalization is not a defined one."""


class StrictFilteringError(Exception):
    """Raised when filtering leaves no data to work with."""
43 changes: 43 additions & 0 deletions tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,49 @@ def test_plot_with_save(self, sample_data: pl.DataFrame, mocker: MockerFixture)
dataset.plot(save='test.png')
assert True # Assert that the method runs without errors

def test_filtered_by_lift_and_moe(self, sample_data: pl.DataFrame) -> None:
    """Filtering with lift_threshold=2.0 and moe_threshold=1.0 keeps a (6, 5) frame."""
    dataset_kwargs = {
        'data': sample_data,
        'unit_column': 'unit',
        'time_column': 'time',
        'y_column': 'y',
        'covariate_columns': ['cov1', 'cov2'],
        'intervention_units': [1],
        'intervention_time': 2,
    }
    source = sx.Dataset(**dataset_kwargs)

    result = source.filtered_by_lift_and_moe(lift_threshold=2.0, moe_threshold=1.0)

    # Assuming no units are filtered out
    expected_shape = (6, 5)
    assert result.data.shape == expected_shape

def test_filtered_by_lift_and_moe_invalid_lift_threshold(
    self, sample_data: pl.DataFrame
) -> None:
    """A negative lift_threshold must be rejected with ValueError."""
    dataset_kwargs = {
        'data': sample_data,
        'unit_column': 'unit',
        'time_column': 'time',
        'y_column': 'y',
        'covariate_columns': ['cov1', 'cov2'],
        'intervention_units': [1],
        'intervention_time': 2,
    }
    source = sx.Dataset(**dataset_kwargs)

    negative_threshold = -1.0
    with pytest.raises(ValueError):
        source.filtered_by_lift_and_moe(lift_threshold=negative_threshold)

def test_filtered_by_lift_and_moe_invalid_moe_threshold(
    self, sample_data: pl.DataFrame
) -> None:
    """A moe_threshold of zero must be rejected with ValueError."""
    dataset_kwargs = {
        'data': sample_data,
        'unit_column': 'unit',
        'time_column': 'time',
        'y_column': 'y',
        'covariate_columns': ['cov1', 'cov2'],
        'intervention_units': [1],
        'intervention_time': 2,
    }
    source = sx.Dataset(**dataset_kwargs)

    zero_threshold = 0.0
    with pytest.raises(ValueError):
        source.filtered_by_lift_and_moe(moe_threshold=zero_threshold)

def test_validate_missing_unit_column(self, sample_data: pl.DataFrame) -> None:
with pytest.raises(ColumnNotFoundError):
sx.Dataset(
Expand Down
41 changes: 41 additions & 0 deletions tests/core/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,26 @@ def dummy_dataset(self) -> sx.Dataset:
intervention_time=2,
)

@pytest.fixture
def dummy_dataset_val(self) -> sx.Dataset:
    """Dataset with units 1-3 over times 1-3, intervention at 3 and validation at 2."""
    unit_ids = (1, 2, 3)
    periods = [1, 2, 3]

    frame = pl.DataFrame(
        {
            # each unit repeated once per period: 1, 1, 1, 2, 2, 2, 3, 3, 3
            'unit': [unit for unit in unit_ids for _ in periods],
            # periods cycled once per unit: 1, 2, 3, 1, 2, 3, 1, 2, 3
            'time': periods * len(unit_ids),
            # 1.0 through 9.0
            'y': [float(value) for value in range(1, 10)],
        }
    )

    dataset = sx.Dataset(
        data=frame,
        unit_column='unit',
        time_column='time',
        y_column='y',
        covariate_columns=[],
        intervention_units=[1],
        intervention_time=3,
        validation_time=2,
    )
    return dataset

@pytest.fixture
def dummy_result(self, dummy_dataset: sx.Dataset) -> sx.SyntheticControlResult:
control_unit_weights = np.array([0.5, 0.5])
Expand All @@ -36,6 +56,14 @@ def dummy_result(self, dummy_dataset: sx.Dataset) -> sx.SyntheticControlResult:
control_unit_weights=control_unit_weights,
)

@pytest.fixture
def dummy_result_val(self, dummy_dataset_val: sx.Dataset) -> sx.SyntheticControlResult:
    """Result over the validation-enabled dataset with two equal control weights."""
    equal_weights = np.full(2, 0.5)
    result = sx.SyntheticControlResult(
        dataset=dummy_dataset_val,
        control_unit_weights=equal_weights,
    )
    return result

def test_init(self, dummy_dataset: sx.Dataset) -> None:
control_unit_weights = np.array([0.5, 0.5])
result = sx.SyntheticControlResult(
Expand Down Expand Up @@ -77,6 +105,19 @@ def test_estimate_effects(self, dummy_result: sx.SyntheticControlResult) -> None
expected_effect = 0
assert dummy_result.estimate_effects()[0] == expected_effect

def test_validation_differences_no_validation_time(
    self, dummy_result: sx.SyntheticControlResult
) -> None:
    """validation_differences() returns None for the fixture without validation_time."""
    differences = dummy_result.validation_differences()
    assert differences is None

def test_validation_differences_with_validation_time(
    self, dummy_result_val: sx.SyntheticControlResult
) -> None:
    """The first validation difference is 0 for the validation-enabled fixture."""
    differences = dummy_result_val.validation_differences()
    assert differences is not None

    first_difference = differences[0]
    assert first_difference == 0

def test_plot(self, dummy_result: sx.SyntheticControlResult, mocker: MockerFixture) -> None:
mocker.patch('matplotlib.pyplot.show')
dummy_result.plot()
Expand Down

0 comments on commit 162e387

Please sign in to comment.