diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index 5ddf29a33..bdab7e4f4 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -663,34 +663,6 @@ def compute_channel_contribution_original_scale(self) -> DataArray: coords=channel_contribution.coords, ) - def _get_distribution_from_dict(self, dist: dict) -> Callable: - """ - Retrieve a PyMC distribution callable based on the provided dictionary. - - Parameters - ---------- - dist : Dict - A dictionary containing the key 'dist' which should correspond to the - name of a PyMC distribution. - - Returns - ------- - Callable - A PyMC distribution callable that can be used to instantiate a random - variable. - - Raises - ------ - ValueError - If the specified distribution name in the dictionary does not correspond - to any distribution in PyMC. - """ - try: - prior_distribution = getattr(pm, dist["dist"]) - except AttributeError: - raise ValueError(f"Distribution {dist['dist']} does not exist in PyMC") - return prior_distribution - def compute_mean_contributions_over_time( self, original_scale: bool = False ) -> pd.DataFrame: diff --git a/pymc_marketing/mmm/utils.py b/pymc_marketing/mmm/utils.py index 49978f404..9d79dea8d 100644 --- a/pymc_marketing/mmm/utils.py +++ b/pymc_marketing/mmm/utils.py @@ -13,7 +13,6 @@ # limitations under the License. """Utility functions for the Marketing Mix Modeling module.""" -import re from collections.abc import Callable from typing import Any @@ -227,32 +226,6 @@ def find_sigmoid_inflection_point( return x_inflection, y_inflection -def standardize_scenarios_dict_keys(d: dict, keywords: list[str]): - """ - Standardize the keys in a dictionary based on a list of keywords. - - This function iterates over the keys in the dictionary and the keywords. - If a keyword is found in a key (case-insensitive), the key is replaced with the keyword. - - Parameters - ---------- - d : dict - The dictionary whose keys are to be standardized. - keywords : list - The list of keywords to standardize the keys to. - - Returns - ------- - None - The function modifies the given dictionary in-place and doesn't return any object. - """ - for keyword in keywords: - for key in list(d.keys()): - if re.search(keyword, key, re.IGNORECASE): - d[keyword] = d.pop(key) - break - - def apply_sklearn_transformer_across_dim( data: xr.DataArray, func: Callable[[np.ndarray], np.ndarray], diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index 2a9091323..46ad98c96 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -863,16 +863,6 @@ def test_new_data_predict_method( # assert lower < toy_y.mean() < upper -def test_get_valid_distribution(mmm): - normal_dist = mmm._get_distribution_from_dict({"dist": "Normal"}) - assert normal_dist is pm.Normal - - -def test_get_invalid_distribution(mmm): - with pytest.raises(ValueError, match="does not exist in PyMC"): - mmm._get_distribution_from_dict({"dist": "NonExistentDist"}) - - def test_invalid_likelihood_type(mmm): with pytest.raises( ValueError, diff --git a/tests/mmm/test_utils.py b/tests/mmm/test_utils.py index d6334302a..0829933bf 100644 --- a/tests/mmm/test_utils.py +++ b/tests/mmm/test_utils.py @@ -13,11 +13,13 @@ # limitations under the License. import numpy as np import pandas as pd +import pymc as pm import pytest import xarray as xr from sklearn.preprocessing import MaxAbsScaler from pymc_marketing.mmm.utils import ( + _get_distribution_from_dict, apply_sklearn_transformer_across_dim, compute_sigmoid_second_derivative, create_new_spend_data, @@ -344,3 +346,33 @@ def test_create_new_spend_data( new_spend_data, np.array(expected_result), ) + + +def test_create_new_spend_data_value_errors() -> None: + with pytest.raises( + ValueError, match="spend_leading_up must be the same length as the spend" + ): + create_new_spend_data( + spend=np.array([1, 2]), + adstock_max_lag=2, + one_time=True, + spend_leading_up=np.array([3, 4, 5]), + ) + + +@pytest.mark.parametrize( + argnames="distribution_dict, expected", + argvalues=[ + ({"dist": "Normal"}, pm.Normal), + ({"dist": "Gamma"}, pm.Gamma), + ({"dist": "StudentT"}, pm.StudentT), + ], + ids=["Normal", "Gamma", "StudentT"], +) +def test_get_distribution_from_dict(distribution_dict, expected): + assert _get_distribution_from_dict(distribution_dict) == expected + + +def test_get_distribution_from_dict_value_error(): + with pytest.raises(ValueError): + _get_distribution_from_dict({"dist": "InvalidDistribution"})