Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve tests mmm utils #738

Merged
merged 7 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 0 additions & 27 deletions pymc_marketing/mmm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
"""Utility functions for the Marketing Mix Modeling module."""

import re
from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -227,32 +226,6 @@ def find_sigmoid_inflection_point(
return x_inflection, y_inflection


def standardize_scenarios_dict_keys(d: dict, keywords: list[str]):
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
"""
Standardize the keys in a dictionary based on a list of keywords.

This function iterates over the keys in the dictionary and the keywords.
If a keyword is found in a key (case-insensitive), the key is replaced with the keyword.

Parameters
----------
d : dict
The dictionary whose keys are to be standardized.
keywords : list
The list of keywords to standardize the keys to.

Returns
-------
None
The function modifies the given dictionary in-place and doesn't return any object.
"""
for keyword in keywords:
for key in list(d.keys()):
if re.search(keyword, key, re.IGNORECASE):
d[keyword] = d.pop(key)
break


def apply_sklearn_transformer_across_dim(
data: xr.DataArray,
func: Callable[[np.ndarray], np.ndarray],
Expand Down
32 changes: 32 additions & 0 deletions tests/mmm/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.
import numpy as np
import pandas as pd
import pymc as pm
import pytest
import xarray as xr
from sklearn.preprocessing import MaxAbsScaler

from pymc_marketing.mmm.utils import (
_get_distribution_from_dict,
apply_sklearn_transformer_across_dim,
compute_sigmoid_second_derivative,
create_new_spend_data,
Expand Down Expand Up @@ -344,3 +346,33 @@ def test_create_new_spend_data(
new_spend_data,
np.array(expected_result),
)


def test_create_new_spend_data_value_errors() -> None:
with pytest.raises(
ValueError, match="spend_leading_up must be the same length as the spend"
):
create_new_spend_data(
spend=np.array([1, 2]),
adstock_max_lag=2,
one_time=True,
spend_leading_up=np.array([3, 4, 5]),
)


@pytest.mark.parametrize(
argnames="distribution_dict, expected",
argvalues=[
({"dist": "Normal"}, pm.Normal),
({"dist": "Gamma"}, pm.Gamma),
({"dist": "StudentT"}, pm.StudentT),
],
ids=["Normal", "Gamma", "StudentT"],
)
def test_get_distribution_from_dict(distribution_dict, expected):
assert _get_distribution_from_dict(distribution_dict) == expected


def test_get_distribution_from_dict_value_error():
with pytest.raises(ValueError):
_get_distribution_from_dict({"dist": "InvalidDistribution"})
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
Loading