diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 7186e4a46..4424188b4 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -2232,7 +2232,7 @@ def allocate_budget_to_maximize_response( def plot_budget_allocation( self, - samples: az.InferenceData, + samples: Dataset, figsize: tuple[float, float] = (12, 6), ax: plt.Axes | None = None, original_scale: bool = True, @@ -2242,8 +2242,8 @@ def plot_budget_allocation( Parameters ---------- - samples : az.InferenceData - The inference data containing the channel contributions. + samples : Dataset + The dataset containing the channel contributions. figsize : tuple[float, float], optional The size of the figure to be created, by default (12, 6). ax : plt.Axes, optional @@ -2331,7 +2331,7 @@ def plot_budget_allocation( def plot_allocated_contribution_by_channel( self, - samples: az.InferenceData, + samples: Dataset, lower_quantile: float = 0.025, upper_quantile: float = 0.975, original_scale: bool = True, @@ -2345,8 +2345,8 @@ def plot_allocated_contribution_by_channel( Parameters ---------- - samples : az.InferenceData - The inference data containing the samples of channel contributions. + samples : Dataset + The dataset containing the samples of channel contributions. lower_quantile : float, optional The lower quantile for the uncertainty interval. Default is 0.025. upper_quantile : float, optional