Skip to content

Commit

Permalink
Merge pull request #258 from pymc-labs/fix-plot-error
Browse files Browse the repository at this point in the history
Fix error in plot method
  • Loading branch information
drbenvincent authored Oct 17, 2023
2 parents adea6f7 + 37515db commit f1a522d
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
14 changes: 1 addition & 13 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,18 +463,6 @@ class DifferenceInDifferences(ExperimentalDesign):
... }
... )
... )
>>> result.summary() # doctest: +NUMBER
===========================Difference in Differences============================
Formula: y ~ 1 + group*post_treatment
<BLANKLINE>
Results:
Causal impact = 0.5, $CI_{94%}$[0.4, 0.6]
Model coefficients:
Intercept 1.0, 94% HDI [1.0, 1.1]
post_treatment[T.True] 0.9, 94% HDI [0.9, 1.0]
group 0.1, 94% HDI [0.0, 0.2]
group:post_treatment[T.True] 0.5, 94% HDI [0.4, 0.6]
sigma 0.0, 94% HDI [0.0, 0.1]
"""

def __init__(
Expand Down Expand Up @@ -726,7 +714,7 @@ def _plot_causal_impact_arrow(self, ax):
def _causal_impact_summary_stat(self) -> str:
"""Computes the mean and 94% credible interval bounds for the causal impact."""
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
ci = "$CI_{94%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
ci = "$CI_{94\\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
causal_impact = f"{self.causal_impact.mean():.2f}, "
return f"Causal impact = {causal_impact + ci}"

Expand Down
21 changes: 21 additions & 0 deletions causalpy/tests/test_pymc_experiments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Unit tests for pymc_experiments.py
"""

import causalpy as cp

sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2}


def test_did_summary():
"""Test that the summary stat function returns a string."""
df = cp.load_data("did")
result = cp.pymc_experiments.DifferenceInDifferences(
df,
formula="y ~ 1 + group*post_treatment",
time_variable_name="t",
group_variable_name="group",
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
)
print(type(result._causal_impact_summary_stat()))
assert isinstance(result._causal_impact_summary_stat(), str)
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit f1a522d

Please sign in to comment.