Skip to content

Commit

Permalink
Merge pull request #59 from kenki931128/feature/filter-lift-and-moe
Browse files Browse the repository at this point in the history
Filtering the series based on lift and margin of error threshold
  • Loading branch information
kenki931128 authored Apr 30, 2024
2 parents 3725b28 + 1afdcb3 commit 9cddc35
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 1 deletion.
126 changes: 126 additions & 0 deletions examples/0 - sample data.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "synthx"
version = "1.0.2"
version = "1.1.0"
description = "A Python Library for Advanced Synthetic Control Analysis"
authors = ["kenki931128 <kenki.nkmr@gmail.com>"]
license = "MIT License"
Expand Down
92 changes: 92 additions & 0 deletions synthx/core/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Class for dataset."""

import sys
from datetime import date
from typing import Any, Optional, Union

import matplotlib.pyplot as plt
import polars as pl
from tqdm import tqdm

from synthx.errors import (
ColumnNotFoundError,
Expand All @@ -13,6 +15,7 @@
InvalidInterventionTimeError,
InvalidInterventionUnitError,
InvalidNormalizationError,
StrictFilteringError,
)
from synthx.stats import norm

Expand Down Expand Up @@ -93,6 +96,95 @@ def plot(self, units: Optional[list[str]] = None, save: Optional[str] = None) ->
else:
plt.show()

def filtered_by_lift_and_moe(
self,
*,
lift_threshold: Optional[float] = None,
moe_threshold: Optional[float] = None,
write_progress: bool = False,
) -> 'Dataset':
"""Filter the dataset based on lift and margin of error thresholds.
Args:
lift_threshold (Optional[float]): The threshold for the lift. Units with lift values
outside the range [1 / lift_threshold, lift_threshold] will be excluded.
If None, no lift-based filtering is performed.
moe_threshold (Optional[float]): The threshold for the margin of error. Units with
margin of error values outside the range [-moe_threshold, moe_threshold] will be excluded.
If None, no margin of error-based filtering is performed.
write_progress (bool): Whether to write progress information to stderr.
Returns:
Dataset: A new Dataset object containing the filtered data.
Raises:
ValueError: If lift_threshold or moe_threshold is not positive.
StrictFilteringError: If all units or all intervention units are filtered out.
Consider loosening the thresholds in this case.
"""
if lift_threshold is not None and lift_threshold <= 0:
raise ValueError('lift_threshold should be positive.')
if moe_threshold is not None and moe_threshold <= 0:
raise ValueError('moe_threshold should be positive.')

df = self.data
units_excluded = []

for unit in tqdm(df[self.unit_column].unique()):
df_unit = df.filter(pl.col(self.unit_column) == unit)

mean = df_unit[self.y_column].mean()
std = df_unit[self.y_column].std()

lift = df_unit[self.y_column] / mean
moe = (df_unit[self.y_column] - mean) / std

str_range = f'lift: {lift.min():.3f} ~ {lift.max():.3f}, moe: {moe.min():.3f} ~ {moe.max():.3f}' # type: ignore
if lift_threshold is not None and (
(lift < 1 / lift_threshold).any() or (lift_threshold < lift).any()
):
if write_progress:
tqdm.write(
f'unit {unit} out of lift threshold. {str_range}',
file=sys.stderr,
)
units_excluded.append(unit)
elif moe_threshold is not None and (
(moe < -moe_threshold).any() or (moe_threshold < moe).any()
):
if write_progress:
tqdm.write(
f'unit {unit} out of moe threshold. {str_range}',
file=sys.stderr,
)
units_excluded.append(unit)
elif write_progress:
tqdm.write(
f'unit {unit} kept. {str_range}',
file=sys.stderr,
)

df = df.filter(~pl.col(self.unit_column).is_in(units_excluded))
if len(df) == 0:
raise StrictFilteringError('all units filterred out. Consider loosing thresholds.')
intervention_units = [u for u in self.intervention_units if u not in units_excluded]
if len(intervention_units) == 0:
raise StrictFilteringError(
'all intervention units filterred out. Consider loosing thresholds.'
)

return Dataset(
df,
unit_column=self.unit_column,
time_column=self.time_column,
y_column=self.y_column,
covariate_columns=self.covariate_columns,
intervention_units=intervention_units,
intervention_time=self.intervention_time,
validation_time=self.validation_time,
norm=self.norm,
)

def __validate(self) -> None:
"""Validate the dataset and raise appropriate errors if any issues are found.
Expand Down
4 changes: 4 additions & 0 deletions synthx/errors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ class InconsistentTimestampsError(Exception):

class InvalidNormalizationError(Exception):
"""Exception raised when undefined normalization is selected."""


class StrictFilteringError(Exception):
"""Exception raised when no data exists after filtering."""
43 changes: 43 additions & 0 deletions tests/core/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,49 @@ def test_plot_with_save(self, sample_data: pl.DataFrame, mocker: MockerFixture)
dataset.plot(save='test.png')
assert True # Assert that the method runs without errors

def test_filtered_by_lift_and_moe(self, sample_data: pl.DataFrame) -> None:
dataset = sx.Dataset(
data=sample_data,
unit_column='unit',
time_column='time',
y_column='y',
covariate_columns=['cov1', 'cov2'],
intervention_units=[1],
intervention_time=2,
)
filtered_dataset = dataset.filtered_by_lift_and_moe(lift_threshold=2.0, moe_threshold=1.0)
assert filtered_dataset.data.shape == (6, 5) # Assuming no units are filtered out

def test_filtered_by_lift_and_moe_invalid_lift_threshold(
self, sample_data: pl.DataFrame
) -> None:
dataset = sx.Dataset(
data=sample_data,
unit_column='unit',
time_column='time',
y_column='y',
covariate_columns=['cov1', 'cov2'],
intervention_units=[1],
intervention_time=2,
)
with pytest.raises(ValueError):
dataset.filtered_by_lift_and_moe(lift_threshold=-1.0)

def test_filtered_by_lift_and_moe_invalid_moe_threshold(
self, sample_data: pl.DataFrame
) -> None:
dataset = sx.Dataset(
data=sample_data,
unit_column='unit',
time_column='time',
y_column='y',
covariate_columns=['cov1', 'cov2'],
intervention_units=[1],
intervention_time=2,
)
with pytest.raises(ValueError):
dataset.filtered_by_lift_and_moe(moe_threshold=0.0)

def test_validate_missing_unit_column(self, sample_data: pl.DataFrame) -> None:
with pytest.raises(ColumnNotFoundError):
sx.Dataset(
Expand Down

0 comments on commit 9cddc35

Please sign in to comment.