Skip to content

Commit

Permalink
Add ruff RUF rules (#636)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz authored and twiecki committed Sep 10, 2024
1 parent 194f389 commit 354cd4a
Show file tree
Hide file tree
Showing 20 changed files with 139 additions and 85 deletions.
2 changes: 1 addition & 1 deletion docs/source/notebooks/clv/clv_quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@
"source": [
"ids = [841, 1981, 157, 1516]\n",
"ax = az.plot_posterior(num_purchases.sel(customer_id=ids), grid=(2, 2))\n",
"for axi, id in zip(ax.ravel(), ids):\n",
"for axi, id in zip(ax.ravel(), ids, strict=False):\n",
" axi.set_title(f\"Customer: {id}\", size=20)\n",
"plt.suptitle(\"Expected number purchases in the next period\", fontsize=28, y=1.05);"
]
Expand Down
2 changes: 1 addition & 1 deletion docs/source/notebooks/general/other_nuts_samplers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
" for j, (idata, label) in enumerate(\n",
" zip(\n",
" [idata_blackjax, idata_nutpie, idata_numpyro],\n",
" [\"blackjax\", \"nutpie\", \"numpyro\"],\n",
" [\"blackjax\", \"nutpie\", \"numpyro\"], strict=False,\n",
" )\n",
" ):\n",
" az.plot_posterior(\n",
Expand Down
76 changes: 49 additions & 27 deletions docs/source/notebooks/mmm/mmm_lift_test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
"outputs": [],
"source": [
"n_dates = 52\n",
"dates = pd.date_range(start='2024-01-01', periods=n_dates, freq=\"W-MON\")\n",
"dates = pd.date_range(start=\"2024-01-01\", periods=n_dates, freq=\"W-MON\")\n",
"\n",
"spend_rv = pm.Uniform.dist(lower=0, upper=1, size=n_dates)\n",
"\n",
Expand All @@ -127,11 +127,13 @@
}
],
"source": [
"df = pd.DataFrame({\n",
" \"date\": dates,\n",
" \"channel 1\": spend,\n",
" \"channel 2\": spend,\n",
"})\n",
"df = pd.DataFrame(\n",
" {\n",
" \"date\": dates,\n",
" \"channel 1\": spend,\n",
" \"channel 2\": spend,\n",
" }\n",
")\n",
"\n",
"ax = df.set_index(\"date\").plot(ylabel=\"channel spend\");"
]
Expand Down Expand Up @@ -528,7 +530,7 @@
"c2_curve = c2_curve_fn(xx)\n",
"\n",
"\n",
"def plot_actual_curves(ax: plt.Axes, linestyle: str = None) -> plt.Axes:\n",
"def plot_actual_curves(ax: plt.Axes, linestyle: str | None = None) -> plt.Axes:\n",
" ax.plot(xx, c1_curve, label=\"channel 1\", color=\"C0\", linestyle=linestyle)\n",
" ax.plot(xx, c2_curve, label=\"channel 2\", color=\"C1\", linestyle=linestyle)\n",
"\n",
Expand Down Expand Up @@ -579,7 +581,9 @@
"metadata": {},
"outputs": [],
"source": [
"def create_lift_test_from_actual_curve(channel: str, x: float, delta_x: float, sigma: float) -> dict[str, float]:\n",
"def create_lift_test_from_actual_curve(\n",
" channel: str, x: float, delta_x: float, sigma: float\n",
") -> dict[str, float]:\n",
" curve_fn = c1_curve_fn if channel == \"channel 1\" else c2_curve_fn\n",
"\n",
" delta_y = curve_fn(x + delta_x) - curve_fn(x)\n",
Expand Down Expand Up @@ -678,14 +682,16 @@
}
],
"source": [
"df_lift_test = pd.DataFrame([\n",
" # Channel x1\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.0, 0.05, 0.05),\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.15, 0.05, 0.05),\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.3, 0.05, 0.05),\n",
" # Channel x2\n",
" create_lift_test_from_actual_curve(\"channel 2\", 0.5, 0.05, 0.10),\n",
"])\n",
"df_lift_test = pd.DataFrame(\n",
" [\n",
" # Channel x1\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.0, 0.05, 0.05),\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.15, 0.05, 0.05),\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.3, 0.05, 0.05),\n",
" # Channel x2\n",
" create_lift_test_from_actual_curve(\"channel 2\", 0.5, 0.05, 0.10),\n",
" ]\n",
")\n",
"\n",
"df_lift_test"
]
Expand All @@ -701,7 +707,15 @@
},
"outputs": [],
"source": [
"def plot_triangle(x, delta_x, delta_y, color: str, ax: plt.Axes, offset: float = 0, label: str = None) -> plt.Axes:\n",
"def plot_triangle(\n",
" x,\n",
" delta_x,\n",
" delta_y,\n",
" color: str,\n",
" ax: plt.Axes,\n",
" offset: float = 0,\n",
" label: str | None = None,\n",
") -> plt.Axes:\n",
" x_after = x + delta_x\n",
"\n",
" y = offset\n",
Expand All @@ -714,10 +728,14 @@
" return ax\n",
"\n",
"\n",
"def plot_channel_triangles(df: pd.DataFrame, color: str, ax: plt.Axes, label: str) -> plt.Axes:\n",
"def plot_channel_triangles(\n",
" df: pd.DataFrame, color: str, ax: plt.Axes, label: str\n",
") -> plt.Axes:\n",
" kwargs = {\"label\": label}\n",
" for _, row in df.iterrows():\n",
" plot_triangle(row[\"x\"], row[\"delta_x\"], row[\"delta_y\"], ax=ax, color=color, **kwargs)\n",
" plot_triangle(\n",
" row[\"x\"], row[\"delta_x\"], row[\"delta_y\"], ax=ax, color=color, **kwargs\n",
" )\n",
" if \"label\" in kwargs:\n",
" kwargs.pop(\"label\")\n",
" return ax\n",
Expand Down Expand Up @@ -1120,7 +1138,9 @@
},
"outputs": [],
"source": [
"def plot_channel_rug(df: pd.DataFrame, color: str, ax: plt.Axes, height: float) -> plt.Axes:\n",
"def plot_channel_rug(\n",
" df: pd.DataFrame, color: str, ax: plt.Axes, height: float\n",
") -> plt.Axes:\n",
" for x in df[\"x\"].to_numpy():\n",
" ax.axvline(x, ymin=0, ymax=height, color=color)\n",
" return ax\n",
Expand Down Expand Up @@ -1323,13 +1343,15 @@
}
],
"source": [
"df_additional_lift_test = pd.DataFrame([\n",
" # More for Channel x1\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.1, 0.05, sigma=0.01),\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.5, 0.05, sigma=0.01),\n",
" # More for channel x2\n",
" create_lift_test_from_actual_curve(\"channel 2\", 0.3, 0.05, sigma=0.01),\n",
"])\n",
"df_additional_lift_test = pd.DataFrame(\n",
" [\n",
" # More for Channel x1\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.1, 0.05, sigma=0.01),\n",
" create_lift_test_from_actual_curve(\"channel 1\", 0.5, 0.05, sigma=0.01),\n",
" # More for channel x2\n",
" create_lift_test_from_actual_curve(\"channel 2\", 0.3, 0.05, sigma=0.01),\n",
" ]\n",
")\n",
"\n",
"df_additional_lift_test"
]
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def rng_fn(cls, rng, r, alpha, s, beta, T, size):
beta = np.broadcast_to(beta, size)
T = np.broadcast_to(T, size)

output = np.zeros(shape=size + (2,))
output = np.zeros(shape=size + (2,)) # noqa:RUF005

lam = rng.gamma(shape=r, scale=1 / alpha, size=size)
mu = rng.gamma(shape=s, scale=1 / beta, size=size)
Expand Down
8 changes: 4 additions & 4 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _add_fit_data_group(self, data: pd.DataFrame) -> None:
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
assert self.idata is not None
assert self.idata is not None # noqa: S101
self.idata.add_groups(fit_data=data.to_xarray())

def fit( # type: ignore
Expand Down Expand Up @@ -206,8 +206,8 @@ def thin_fit_result(self, keep_every: int):
)
"""
self.fit_result # Raise Error if fit didn't happen yet
assert self.idata is not None
self.fit_result # noqa: B018 (Raise Error if fit didn't happen yet)
assert self.idata is not None # noqa: S101
new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy()
return type(self)._build_with_idata(new_idata)

Expand Down Expand Up @@ -237,7 +237,7 @@ def fit_result(self, res: az.InferenceData) -> None:
if self.idata is None:
self.idata = res
elif "posterior" in self.idata:
warnings.warn("Overriding pre-existing fit_result")
warnings.warn("Overriding pre-existing fit_result", stacklevel=1)
self.idata.posterior = res
else:
self.idata.posterior = res
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class BetaGeoModel(CLVModel):
.. [2] Fader, P. S., Hardie, B. G., & Lee, K. L. (2008). Computing
P (alive) using the BG/NBD model. Research Note available via
http://www.brucehardie.com/notes/021/palive_for_BGNBD.pdf.
.. [3] Fader, P. S. & Hardie, B. G. (2013) Overcoming the BG/NBD Models #NUM!
.. [3] Fader, P. S. & Hardie, B. G. (2013) Overcoming the BG/NBD Model's #NUM!
Error Problem. Research Note available via
http://brucehardie.com/notes/027/bgnbd_num_error.pdf.
"""
Expand Down
19 changes: 10 additions & 9 deletions pymc_marketing/clv/models/gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,16 @@ def expected_customer_spend(
Eq 5 from [1], p.3
Adapted from: https://github.com/CamDavidsonPilon/lifetimes/blob/aae339c5437ec31717309ba0ec394427e19753c4/lifetimes/fitters/gamma_gamma_fitter.py#L117
""" # noqa: E501
"""

mean_transaction_value, frequency = to_xarray(
customer_id, mean_transaction_value, frequency
)
assert self.idata is not None, "Model must be fitted first"
p = self.idata.posterior["p"]
q = self.idata.posterior["q"]
v = self.idata.posterior["v"]
posterior = self.fit_result

p = posterior["p"]
q = posterior["q"]
v = posterior["v"]

individual_weight = p * frequency / (p * frequency + q - 1)
population_mean = v * p / (q - 1)
Expand Down Expand Up @@ -89,10 +90,10 @@ def distribution_new_customer_spend(
def expected_new_customer_spend(self) -> xarray.DataArray:
"""Expected transaction value for a new customer"""

assert self.idata is not None, "Model must be fitted first"
p_mean = self.idata.posterior["p"]
q_mean = self.idata.posterior["q"]
v_mean = self.idata.posterior["v"]
posterior = self.fit_result
p_mean = posterior["p"]
q_mean = posterior["q"]
v_mean = posterior["v"]

# Closed form solution to the posterior of nu
# Eq 3 from [1], p.3
Expand Down
9 changes: 7 additions & 2 deletions pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,13 @@ def __init__(
self.covariate_cols = self.purchase_covariate_cols + self.dropout_covariate_cols
self._validate_cols(
data,
required_cols=["customer_id", "frequency", "recency", "T"]
+ self.covariate_cols,
required_cols=[
"customer_id",
"frequency",
"recency",
"T",
*self.covariate_cols,
],
must_be_unique=["customer_id"],
)

Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _find_first_transactions(


def clv_summary(*args, **kwargs):
warnings.warn("clv_summary was renamed to rfm_summary", UserWarning)
warnings.warn("clv_summary was renamed to rfm_summary", UserWarning, stacklevel=1)
return rfm_summary(*args, **kwargs)


Expand Down
16 changes: 11 additions & 5 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
for arg, var_contribution in zip(
["control_columns", "yearly_seasonality"],
["control_contributions", "fourier_contributions"],
strict=True,
):
if getattr(self, arg, None):
contributions = self._format_model_contributions(
Expand All @@ -413,6 +414,7 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
"control_contribution",
"fourier_contribution",
],
strict=False,
)
):
if self.X is not None:
Expand Down Expand Up @@ -665,7 +667,7 @@ def plot_budget_scenearios(
standardize_scenarios_dict_keys(base_data, ["contribution", "budget"])

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
scenarios = [base_data] + list(scenarios_data)
scenarios = [base_data, *list(scenarios_data)]
num_scenarios = len(scenarios)
bar_width = (
0.8 / num_scenarios
Expand All @@ -688,7 +690,7 @@ def plot_budget_scenearios(

# Plot all scenarios
for i, (scenario, upper_bound, lower_bound) in enumerate(
zip(scenarios, upper_bounds, lower_bounds)
zip(scenarios, upper_bounds, lower_bounds, strict=False)
):
color = f"C{i}"
offset = i * bar_width - 0.4 + bar_width / 2
Expand Down Expand Up @@ -914,7 +916,9 @@ def optimize_channel_budget_for_maximum_contribution(
"The 'parameters' argument (keyword-only) must be provided and non-empty."
)

warnings.warn("This budget allocator method is experimental", UserWarning)
warnings.warn(
"This budget allocator method is experimental", UserWarning, stacklevel=1
)

return budget_allocator(
method=method,
Expand Down Expand Up @@ -947,7 +951,9 @@ def compute_channel_curve_optimization_parameters_original_scale(
parameters for each channel based on the method used.
"""
warnings.warn(
"The curve optimization parameters method is experimental", UserWarning
"The curve optimization parameters method is experimental",
UserWarning,
stacklevel=1,
)

channel_contributions = self.compute_channel_contribution_original_scale().mean(
Expand Down Expand Up @@ -1044,7 +1050,7 @@ def legend_title_func(channel):
axes_channels = (
zip(repeat(axes), channels_to_plot)
if same_axes
else zip(np.ravel(axes), channels_to_plot)
else zip(np.ravel(axes), channels_to_plot, strict=False)
)

for i, (ax, channel) in enumerate(axes_channels):
Expand Down
6 changes: 4 additions & 2 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def objective_distribution(

sum_contributions = 0.0

for channel, budget in zip(channels, x):
for channel, budget in zip(channels, x, strict=False):
if method == "michaelis-menten":
L, k = parameters[channel]
sum_contributions += michaelis_menten(budget, L, k)
Expand Down Expand Up @@ -182,7 +182,9 @@ def optimize_budget_distribution(
constraints=constraints,
)

return {channel: budget for channel, budget in zip(channels, result.x)}
return {
channel: budget for channel, budget in zip(channels, result.x, strict=False)
}


def budget_allocator(
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def _data_setter(
try:
new_channel_data = X[self.channel_columns].to_numpy()
except KeyError as e:
raise RuntimeError("New data must contain channel_data!", e)
raise RuntimeError("New data must contain channel_data!") from e

def identity(x):
return x
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def batched_convolution(
if l_max is None:
try:
l_max = w.shape[-1].eval()
except Exception:
except Exception: # noqa: S110
pass
# Get the broadcast shapes of x and w but ignoring their last dimension.
# The last dimension of x is the "time" axis, which doesn't get broadcast
Expand Down
3 changes: 2 additions & 1 deletion pymc_marketing/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,8 @@ def fit(

X_df = pd.DataFrame(X, columns=X.columns)
combined_data = pd.concat([X_df, y_df], axis=1)
assert all(combined_data.columns), "All columns must have non-empty names"
if not all(combined_data.columns):
raise ValueError("All columns must have non-empty names")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
Expand Down
Loading

0 comments on commit 354cd4a

Please sign in to comment.