MMM data_df param renamed to data (pymc-labs#186)
michaelraczycki committed Apr 6, 2023
1 parent 84ec7a2 commit 227e6a9
Showing 8 changed files with 77 additions and 77 deletions.
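The change is purely a keyword rename in the public API; model behavior is unchanged. A minimal before/after sketch using the column names from the example notebook (the toy frame and the import path are assumptions for illustration, not part of this commit):

import numpy as np
import pandas as pd
from pymc_marketing.mmm import DelayedSaturatedMMM  # import path assumed

# Hypothetical toy dataset with the columns used in mmm_example.ipynb.
df = pd.DataFrame({
    "date_week": pd.date_range("2020-01-06", periods=10, freq="W-MON"),
    "x1": np.random.rand(10),
    "x2": np.random.rand(10),
    "y": np.random.rand(10),
})

# Before this commit:
#   mmm = DelayedSaturatedMMM(data_df=df, target_column="y", ...)
# After this commit:
mmm = DelayedSaturatedMMM(
    data=df,
    target_column="y",
    date_column="date_week",
    channel_columns=["x1", "x2"],
)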
16 changes: 8 additions & 8 deletions docs/source/notebooks/mmm/mmm_example.ipynb
@@ -773,9 +773,9 @@
" \"dayofyear\",\n",
"]\n",
"\n",
"data_df = df[columns_to_keep].copy()\n",
"data = df[columns_to_keep].copy()\n",
"\n",
"data_df.head()"
"data.head()"
]
},
{
@@ -941,17 +941,17 @@
"source": [
"# Fourier modes\n",
"fourier_modes = generate_fourier_modes(\n",
" periods=data_df[\"dayofyear\"] / 365.25,\n",
" periods=data[\"dayofyear\"] / 365.25,\n",
" n_order=2\n",
")\n",
"\n",
"# trend feature\n",
"\n",
"data_df[\"t\"] = range(n)\n",
"data[\"t\"] = range(n)\n",
"\n",
"data_df = pd.concat([data_df, fourier_modes], axis=1)\n",
"data = pd.concat([data, fourier_modes], axis=1)\n",
"\n",
"data_df.head()"
"data.head()"
]
},
{
@@ -977,7 +977,7 @@
"outputs": [],
"source": [
"mmm = DelayedSaturatedMMM(\n",
" data_df=data_df,\n",
" data=data,\n",
" target_column=\"y\",\n",
" date_column=\"date_week\",\n",
" channel_columns=[\"x1\", \"x2\"],\n",
@@ -4666,7 +4666,7 @@
"\n",
"roas_samples = (\n",
" channel_contribution_original_scale.stack(sample=(\"chain\", \"draw\")).sum(\"date\")\n",
" / data_df[[\"x1\", \"x2\"]].sum().to_numpy()[..., None]\n",
" / data[[\"x1\", \"x2\"]].sum().to_numpy()[..., None]\n",
")\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 6))\n",
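For context, the last notebook hunk computes return on ad spend (ROAS) by dividing each channel's total posterior contribution by its total spend. A numpy-only sketch of the same arithmetic with made-up shapes (the notebook's xarray version stacks chains and draws into one sample dimension instead):

import numpy as np

# Assumed shapes: posterior contribution samples per (sample, date, channel),
# and raw spend per (date, channel).
n_samples, n_dates, n_channels = 1000, 52, 2
channel_contribution = np.random.rand(n_samples, n_dates, n_channels)
spend = np.random.rand(n_dates, n_channels)

# ROAS per posterior sample and channel: total contribution / total spend.
roas_samples = channel_contribution.sum(axis=1) / spend.sum(axis=0)
print(roas_samples.shape)  # (1000, 2)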
48 changes: 24 additions & 24 deletions pymc_marketing/mmm/base.py
@@ -33,28 +33,28 @@
class BaseMMM(ModelBuilder):
def __init__(
self,
- data_df: pd.DataFrame,
+ data: pd.DataFrame,
target_column: str,
date_column: str,
channel_columns: Union[List[str], Tuple[str]],
validate_data: bool = True,
**kwargs,
) -> None:
- self.data_df: pd.DataFrame = data_df
+ self.data: pd.DataFrame = data
self.target_column: str = target_column
self.date_column: str = date_column
self.channel_columns: Union[List[str], Tuple[str]] = channel_columns
- self.n_obs: int = data_df.shape[0]
+ self.n_obs: int = data.shape[0]
self.n_channel: int = len(channel_columns)
self._fit_result: Optional[az.InferenceData] = None
self._posterior_predictive: Optional[az.InferenceData] = None

if validate_data:
- self.validate(self.data_df)
- self.preprocessed_data = self.preprocess(self.data_df.copy())
+ self.validate(self.data)
+ self.preprocessed_data = self.preprocess(self.data.copy())

self.build_model(
- data_df=self.preprocessed_data,
+ data=self.preprocessed_data,
**kwargs,
)
super().__init__(self.data, self.model_config)
@@ -97,14 +97,14 @@ def get_target_transformer(self) -> Pipeline:
identity_transformer = FunctionTransformer()
return Pipeline(steps=[("scaler", identity_transformer)])

- def validate(self, data_df: pd.DataFrame):
+ def validate(self, data: pd.DataFrame):
for method in self.validation_methods:
- method(self, data_df)
+ method(self, data)

- def preprocess(self, data_df: pd.DataFrame) -> pd.DataFrame:
+ def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
for method in self.preprocessing_methods:
- data_df = method(self, data_df)
- return data_df
+ data = method(self, data)
+ return data

@abstractmethod
def build_model(*args, **kwargs):
@@ -167,7 +167,7 @@ def plot_prior_predictive(
fig, ax = plt.subplots(**plt_kwargs)

ax.fill_between(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y1=likelihood_hdi_94[:, 0],
y2=likelihood_hdi_94[:, 1],
color="C0",
@@ -176,7 +176,7 @@ def plot_prior_predictive(
)

ax.fill_between(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y1=likelihood_hdi_50[:, 0],
y2=likelihood_hdi_50[:, 1],
color="C0",
@@ -185,7 +185,7 @@ def plot_prior_predictive(
)

ax.plot(
- self.data_df[self.date_column],
+ self.data[self.date_column],
self.preprocessed_data[self.target_column],
color="black",
)
@@ -215,7 +215,7 @@ def plot_posterior_predictive(
fig, ax = plt.subplots(**plt_kwargs)

ax.fill_between(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y1=likelihood_hdi_94[:, 0],
y2=likelihood_hdi_94[:, 1],
color="C0",
@@ -224,7 +224,7 @@ def plot_posterior_predictive(
)

ax.fill_between(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y1=likelihood_hdi_50[:, 0],
y2=likelihood_hdi_50[:, 1],
color="C0",
@@ -233,11 +233,11 @@ def plot_posterior_predictive(
)

target_to_plot: pd.Series = (
- self.data_df[self.target_column]
+ self.data[self.target_column]
if original_scale
else self.preprocessed_data[self.target_column]
)
- ax.plot(self.data_df[self.date_column], target_to_plot, color="black")
+ ax.plot(self.data[self.date_column], target_to_plot, color="black")
ax.set(
title="Posterior Predictive Check",
xlabel="date",
@@ -295,15 +295,15 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
)
):
ax.fill_between(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y1=hdi.isel(hdi=0),
y2=hdi.isel(hdi=1),
color=f"C{i}",
alpha=0.25,
label=f"$94 %$ HDI ({var_contribution})",
)
sns.lineplot(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y=mean,
color=f"C{i}",
ax=ax,
@@ -316,21 +316,21 @@
axis=0,
)
sns.lineplot(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y=intercept.mean().data,
color=f"C{i + 1}",
ax=ax,
)
ax.fill_between(
- x=self.data_df[self.date_column],
+ x=self.data[self.date_column],
y1=intercept_hdi[:, 0],
y2=intercept_hdi[:, 1],
color=f"C{i + 1}",
alpha=0.25,
label="$94 %$ HDI (intercept)",
)
ax.plot(
- self.data_df[self.date_column],
+ self.data[self.date_column],
self.preprocessed_data[self.target_column],
color="black",
)
@@ -397,7 +397,7 @@ def plot_contribution_curves(self) -> plt.Figure:
for i, channel in enumerate(self.channel_columns):
ax = axes[i]
sns.regplot(
- x=self.data_df[self.channel_columns].to_numpy()[:, i],
+ x=self.data[self.channel_columns].to_numpy()[:, i],
y=channel_contributions.sel(channel=channel),
color=f"C{i}",
order=2,
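BaseMMM.validate and BaseMMM.preprocess iterate over methods contributed by mixin classes and tagged with the validation_method / preprocessing_method decorators; after this commit every such method receives the frame as data. A self-contained sketch of one plausible tag-and-collect implementation (the library's actual registry may differ):

from typing import Callable, List
import pandas as pd

def validation_method(method: Callable) -> Callable:
    # Tag the method so the base class can discover it later.
    method._is_validation_method = True
    return method

class ValidatingBase:
    @property
    def validation_methods(self) -> List[Callable]:
        # Collect every tagged method defined anywhere in the class hierarchy.
        return [
            getattr(type(self), name)
            for name in dir(type(self))
            if getattr(getattr(type(self), name, None), "_is_validation_method", False)
        ]

    def validate(self, data: pd.DataFrame) -> None:
        for method in self.validation_methods:
            method(self, data)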
14 changes: 7 additions & 7 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -19,7 +19,7 @@ class DelayedSaturatedMMM(
):
def __init__(
self,
- data_df: pd.DataFrame,
+ data: pd.DataFrame,
target_column: str,
date_column: str,
channel_columns: List[str],
@@ -31,7 +31,7 @@ def __init__(
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
super().__init__(
- data_df=data_df,
+ data=data,
target_column=target_column,
date_column=date_column,
channel_columns=channel_columns,
Expand All @@ -41,14 +41,14 @@ def __init__(

def build_model(
self,
- data_df: pd.DataFrame,
+ data: pd.DataFrame,
adstock_max_lag: int = 4,
) -> None:
- date_data = data_df[self.date_column]
- target_data = data_df[self.target_column]
- channel_data = data_df[self.channel_columns]
+ date_data = data[self.date_column]
+ target_data = data[self.target_column]
+ channel_data = data[self.channel_columns]
if self.control_columns is not None:
- control_data: Optional[pd.DataFrame] = data_df[self.control_columns]
+ control_data: Optional[pd.DataFrame] = data[self.control_columns]
else:
control_data = None
coords: Dict[str, Any] = {
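build_model now pulls the date, target, channel, and optional control columns out of the renamed data argument before assembling the PyMC model. A hypothetical free-function form of that extraction step (the library does this inside the class):

from typing import List, Optional, Tuple
import pandas as pd

def split_columns(
    data: pd.DataFrame,
    date_column: str,
    target_column: str,
    channel_columns: List[str],
    control_columns: Optional[List[str]] = None,
) -> Tuple[pd.Series, pd.Series, pd.DataFrame, Optional[pd.DataFrame]]:
    date_data = data[date_column]
    target_data = data[target_column]
    channel_data = data[channel_columns]
    # Controls are optional; keep None when no control columns are configured.
    control_data = data[control_columns] if control_columns is not None else None
    return date_data, target_data, channel_data, control_data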
24 changes: 12 additions & 12 deletions pymc_marketing/mmm/preprocessing.py
@@ -21,38 +21,38 @@ def preprocessing_method(method: Callable) -> Callable:

class MixMaxScaleTarget:
@preprocessing_method
- def min_max_scale_target_data(self, data_df: pd.DataFrame) -> pd.DataFrame:
- target_vector = data_df[self.target_column].to_numpy().reshape(-1, 1)
+ def min_max_scale_target_data(self, data: pd.DataFrame) -> pd.DataFrame:
+ target_vector = data[self.target_column].to_numpy().reshape(-1, 1)
transformers = [("scaler", MinMaxScaler())]
pipeline = Pipeline(steps=transformers)
self.target_transformer: Pipeline = pipeline.fit(X=target_vector)
- data_df[self.target_column] = self.target_transformer.transform(
+ data[self.target_column] = self.target_transformer.transform(
X=target_vector
).flatten()
- return data_df
+ return data


class MaxAbsScaleChannels:
@preprocessing_method
- def max_abs_scale_channel_data(self, data_df: pd.DataFrame) -> pd.DataFrame:
- channel_data: pd.DataFrame = data_df[self.channel_columns]
+ def max_abs_scale_channel_data(self, data: pd.DataFrame) -> pd.DataFrame:
+ channel_data: pd.DataFrame = data[self.channel_columns]
transformers = [("scaler", MaxAbsScaler())]
pipeline: Pipeline = Pipeline(steps=transformers)
self.channel_transformer: Pipeline = pipeline.fit(X=channel_data.to_numpy())
- data_df[self.channel_columns] = self.channel_transformer.transform(
+ data[self.channel_columns] = self.channel_transformer.transform(
channel_data.to_numpy()
)
- return data_df
+ return data


class StandardizeControls:
@preprocessing_method
- def standardize_control_data(self, data_df: pd.DataFrame) -> pd.DataFrame:
- control_data: pd.DataFrame = data_df[self.control_columns]
+ def standardize_control_data(self, data: pd.DataFrame) -> pd.DataFrame:
+ control_data: pd.DataFrame = data[self.control_columns]
transformers = [("scaler", StandardScaler())]
pipeline: Pipeline = Pipeline(steps=transformers)
self.control_transformer: Pipeline = pipeline.fit(X=control_data.to_numpy())
- data_df[self.control_columns] = self.control_transformer.transform(
+ data[self.control_columns] = self.control_transformer.transform(
control_data.to_numpy()
)
- return data_df
+ return data
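Each preprocessing mixin fits an sklearn pipeline on the raw columns (MinMaxScaler for the target, MaxAbsScaler for channels, StandardScaler for controls) and keeps the fitted transformer so results can later be mapped back to the original scale. A standalone sketch of the target-scaling step on toy numbers:

import pandas as pd
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler

data = pd.DataFrame({"y": [10.0, 25.0, 40.0]})  # hypothetical target column

target_vector = data["y"].to_numpy().reshape(-1, 1)
target_transformer = Pipeline(steps=[("scaler", MinMaxScaler())]).fit(target_vector)
data["y"] = target_transformer.transform(target_vector).flatten()

print(data["y"].tolist())  # [0.0, 0.5, 1.0]
# The stored pipeline lets predictions be mapped back to the original scale:
print(target_transformer.inverse_transform([[0.5]]))  # [[25.]]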
28 changes: 14 additions & 14 deletions pymc_marketing/mmm/validating.py
@@ -20,42 +20,42 @@ def validation_method(method: Callable) -> Callable:

class ValidateTargetColumn:
@validation_method
- def validate_target(self, data_df: pd.DataFrame) -> None:
- if self.target_column not in data_df.columns:
- raise ValueError(f"target {self.target_column} not in data_df")
+ def validate_target(self, data: pd.DataFrame) -> None:
+ if self.target_column not in data.columns:
+ raise ValueError(f"target {self.target_column} not in data")


class ValidateDateColumn:
@validation_method
- def validate_date_col(self, data_df: pd.DataFrame) -> None:
- if self.date_column not in data_df.columns:
- raise ValueError(f"date_col {self.date_column} not in data_df")
- if not data_df[self.date_column].is_unique:
+ def validate_date_col(self, data: pd.DataFrame) -> None:
+ if self.date_column not in data.columns:
+ raise ValueError(f"date_col {self.date_column} not in data")
+ if not data[self.date_column].is_unique:
raise ValueError(f"date_col {self.date_column} has repeated values")


class ValidateChannelColumns:
@validation_method
- def validate_channel_columns(self, data_df: pd.DataFrame) -> None:
+ def validate_channel_columns(self, data: pd.DataFrame) -> None:
if not isinstance(self.channel_columns, (list, tuple)):
raise ValueError("channel_columns must be a list or tuple")
if len(self.channel_columns) == 0:
raise ValueError("channel_columns must not be empty")
- if not set(self.channel_columns).issubset(data_df.columns):
- raise ValueError(f"channel_columns {self.channel_columns} not in data_df")
+ if not set(self.channel_columns).issubset(data.columns):
+ raise ValueError(f"channel_columns {self.channel_columns} not in data")
if len(set(self.channel_columns)) != len(self.channel_columns):
raise ValueError(
f"channel_columns {self.channel_columns} contains duplicates"
)
- if (data_df[self.channel_columns] < 0).any().any():
+ if (data[self.channel_columns] < 0).any().any():
raise ValueError(
f"channel_columns {self.channel_columns} contains negative values"
)


class ValidateControlColumns:
@validation_method
- def validate_control_columns(self, data_df: pd.DataFrame) -> None:
+ def validate_control_columns(self, data: pd.DataFrame) -> None:
if self.control_columns is None:
return None
if not isinstance(self.control_columns, (list, tuple)):
@@ -64,8 +64,8 @@ def validate_control_columns(self, data: pd.DataFrame) -> None:
raise ValueError(
"If control_columns is not None, then it must not be empty"
)
- if not set(self.control_columns).issubset(data_df.columns):
- raise ValueError(f"control_columns {self.control_columns} not in data_df")
+ if not set(self.control_columns).issubset(data.columns):
+ raise ValueError(f"control_columns {self.control_columns} not in data")
if len(set(self.control_columns)) != len(self.control_columns):
raise ValueError(
f"control_columns {self.control_columns} contains duplicates"
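The validators fail fast with specific messages, and after the rename those messages refer to data rather than data_df. A quick usage sketch against the target-column check shown above (the minimal stand-in class is for illustration only):

import pandas as pd

class Stub:
    # Minimal stand-in carrying the configured column name.
    target_column = "y"

    def validate_target(self, data: pd.DataFrame) -> None:
        if self.target_column not in data.columns:
            raise ValueError(f"target {self.target_column} not in data")

try:
    Stub().validate_target(pd.DataFrame({"x1": [1, 2]}))
except ValueError as err:
    print(err)  # target y not in data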
