
Add plot_cap (Plot Conditional Adjusted Predictions) #517

Merged: 11 commits merged into bambinos:main on Jun 6, 2022

Conversation

tomicapretto
Collaborator

This PR adds a new sub-package called plots. For now it contains a single function, plot_cap(), which is versatile and powerful. It is heavily inspired by the plot_cap() function in the R package {marginaleffects}.

I'll let the examples speak for themselves:

import arviz as az
import bambi as bmb
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from bambi.plots import plot_cap

data = pd.read_csv("mtcars.csv")
data["cyl"] = data["cyl"].replace({4: "low", 6: "medium", 8: "high"})
data["gear"] = data["gear"].replace({3: "A", 4: "B", 5: "C"})
data["cyl"] = pd.Categorical(data["cyl"], categories=["low", "medium", "high"], ordered=True)

model = bmb.Model("mpg ~ 0 + hp * wt + cyl + gear", data)
idata = model.fit(draws=1000, target_accept=0.95, random_seed=1234)
  • One numerical covariate.
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
plot_cap(model, idata, "hp", ax=ax);


  • Two numerical covariates (the second is interpreted as a grouping variable; quantiles are used)
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
plot_cap(model, idata, ["hp", "wt"], ax=ax);


  • Main numerical covariate with a categorical grouping
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
plot_cap(model, idata, ["hp", "cyl"], ax=ax);


  • Main categorical covariate
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
plot_cap(model, idata, ["gear"], ax=ax);


  • Main categorical covariate with a categorical grouping
fig, ax = plt.subplots(figsize=(7, 5), dpi=120)
plot_cap(model, idata, ["gear", "cyl"], ax=ax);


  • Main categorical covariate with a numerical grouping
fig, ax = plt.subplots(figsize=(7, 5), dpi=120, tight_layout=True)
plot_cap(model, idata, ["gear", "wt"], ax=ax);



Now let's see another example, using logistic regression. This one is also borrowed from the {marginaleffects} documentation.

data = pd.read_csv("https://vincentarelbundock.github.io/Rdatasets/csv/ggplot2movies/movies.csv")

data["style"] = "Other"
data.loc[data["Action"] == 1, "style"] = "Action"
data.loc[data["Comedy"] == 1, "style"] = "Comedy"
data.loc[data["Drama"] == 1, "style"] = "Drama"
data["certified_fresh"] = (data["rating"] >= 8) * 1
data = data[data["length"] < 240]

priors = {"style": bmb.Prior("Normal", mu=0, sigma=2)}
model = bmb.Model("certified_fresh ~ 0 + length * style", data=data, priors=priors, family="bernoulli")
model
Formula: certified_fresh ~ 0 + length * style
Family name: Bernoulli
Link: logit
Observations: 58662
Priors:
  Common-level effects
    length ~ Normal(mu: 0.0, sigma: 0.0708)
    style ~ Normal(mu: 0, sigma: 2)
    length:style ~ Normal(mu: [0. 0. 0.], sigma: [0.0702 0.0509 0.0611])
idata = model.fit(random_seed=1234, target_accept=0.9, init="adapt_diag")
fig, ax = plt.subplots(figsize=(7, 5), dpi=120, tight_layout=True)
plot_cap(model, idata, "length", ax=ax)


fig, ax = plt.subplots(figsize=(9, 5), dpi=120, tight_layout=True)
plot_cap(model, idata, ["length", "style"], ax=ax)


Extra point: this model is an excellent example of how the default adapt_diag+jitter initialization sometimes isn't good. If we use adapt_diag+jitter, sampling never finishes, chains don't mix, and all sorts of problems appear. That's why the fit above uses init="adapt_diag".


I honestly think this is a very cool addition. I'd like to know your thoughts @aloctavodia @canyon289


@codecov-commenter

codecov-commenter commented Jun 4, 2022

Codecov Report

Merging #517 (5cf4958) into main (4886452) will decrease coverage by 3.98%.
The diff coverage is 0.00%.

❗ Current head 5cf4958 differs from pull request most recent head c886a16. Consider uploading reports for the commit c886a16 to get more accurate results

@@            Coverage Diff             @@
##             main     #517      +/-   ##
==========================================
- Coverage   90.86%   86.88%   -3.99%     
==========================================
  Files          29       32       +3     
  Lines        2442     2562     +120     
==========================================
+ Hits         2219     2226       +7     
- Misses        223      336     +113     
Impacted Files Coverage Δ
bambi/plots/__init__.py 0.00% <0.00%> (ø)
bambi/plots/plot_cap.py 0.00% <0.00%> (ø)
bambi/plots/utils.py 0.00% <0.00%> (ø)
bambi/backend/terms.py 96.22% <0.00%> (ø)
bambi/tests/test_built_models.py 98.91% <0.00%> (+<0.01%) ⬆️
bambi/backend/pymc.py 80.28% <0.00%> (+0.28%) ⬆️
bambi/backend/utils.py 90.00% <0.00%> (+1.11%) ⬆️
bambi/models.py 88.65% <0.00%> (+3.58%) ⬆️

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@aloctavodia
Collaborator

Definitely a cool addition! I already want to try it! I left a couple of comments. I think this PR is good as is; a future addition could be one or more kwargs so users are able to fine-tune the plots.

@tomicapretto
Collaborator Author

Definitely a cool addition! I already want to try it! I left a couple of comments. I think this PR is good as is; a future addition could be one or more kwargs so users are able to fine-tune the plots.

Thanks for the prompt review. I agree this function should incorporate optional arguments in the future so users can tune more things. For example, it would be great to have another dimension mapped to the axes, so you can create plots with multiple panels. And we could let users choose whether a dimension is mapped to the color or to the axes.

One of the things I don't really like is functions with extremely long signatures... but that seems to be how we are used to working with Matplotlib, and I don't think we can do much to change it.

I think we could merge this as is (maybe after adding some tests? I'm not sure how to test plotting functions, by the way) and then iterate to refine how it works.

@canyon289
Collaborator

I'll review in a couple of hours; at a glance this already looks pretty cool.

@tomicapretto
Collaborator Author

Matplotlib is not in our dependencies, but it's an indirect dependency through ArviZ. Do you think we should change anything in our requirements.txt file?

@aloctavodia
Collaborator

I don't think we need to change our requirements.txt.

Comment on lines +157 to +162
lower_bound = round((1 - hdi_prob) / 2, 4)
upper_bound = 1 - lower_bound

y_hat = idata.posterior[f"{model.response.name}_mean"]
y_hat_mean = y_hat.mean(("chain", "draw"))
y_hat_bounds = y_hat.quantile(q=(lower_bound, upper_bound), dim=("chain", "draw"))
Collaborator

@aloctavodia aloctavodia Jun 5, 2022

Use az.hdi and pass hdi_prob directly to it. Additionally, we could offer the option to use quantiles or HDI, but HDI should probably be the default.
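For context, the difference between the two interval types can be sketched with plain NumPy. This is a minimal illustration, not Bambi's code; compute_hdi is a hypothetical helper written out only to show what az.hdi computes (in practice you would call az.hdi directly):

```python
import numpy as np

def compute_hdi(draws, hdi_prob=0.94):
    """Shortest interval containing `hdi_prob` of the draws.
    Hypothetical helper for illustration; ArviZ's az.hdi does this for real use."""
    sorted_draws = np.sort(draws)
    n = len(sorted_draws)
    window = int(np.floor(hdi_prob * n))
    # Width of every candidate interval that contains `window` consecutive draws
    widths = sorted_draws[window:] - sorted_draws[: n - window]
    i = np.argmin(widths)
    return sorted_draws[i], sorted_draws[i + window]

rng = np.random.default_rng(1234)
draws = rng.gamma(2.0, size=10_000)  # a skewed "posterior"

# Equal-tailed interval (what the quantile-based code in this PR computes)
lower, upper = np.quantile(draws, [0.03, 0.97])
# Highest density interval (shortest interval with the same probability mass)
hdi_lower, hdi_upper = compute_hdi(draws, hdi_prob=0.94)
```

For a skewed unimodal distribution like this one, the HDI shifts toward the mode and is narrower than the equal-tailed interval, which is the usual argument for preferring it as a default.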

Collaborator

Agree with this

Collaborator Author

What are the advantages of HDI? I recall this discussion arviz-devs/arviz#2021 where HDI can produce unexpected results.

Collaborator

Consistency: we use HDI everywhere in ArviZ. Having both is probably the better option.

@canyon289
Collaborator

canyon289 commented Jun 5, 2022

I think we could merge this as it is (maybe after adding some tests? I'm not sure how to test plotting functions btw) and then iterate to refine how it works.
Testing plotting functions is challenging. One level is just testing that it runs with reasonable parameters, which I suggest doing. In ArviZ, another (hacky) way we do it is by having the plotting functions save their output when a keyword is passed, and inspecting it manually.

For extra kwargs, I agree with Tomas: it's hard to anticipate everything without adding a ton of kwargs. Maybe we can just merge as is and add flexibility as needed in future PRs, once we run into cases in actual usage.

@aloctavodia aloctavodia merged commit 62fed83 into bambinos:main Jun 6, 2022
@tomicapretto tomicapretto deleted the plots branch June 7, 2022 11:42
lower_bound = round((1 - hdi_prob) / 2, 4)
upper_bound = 1 - lower_bound

y_hat = idata.posterior[f"{model.response.name}_mean"]

@vishalthatsme

Love plot_cap so far! One potential issue here: after using plot_cap once, I realized the function adds 200 (grid_n) new entries to the idata's posterior data variables. The issue is that running az.summary(idata) after plot_cap will include all of those y_hat_mean values (as would any other ArviZ function, like plot_trace).
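The inplace contract under discussion can be illustrated with a minimal stand-in. This is plain Python, not Bambi's actual predict; the dict plays the role of idata and its "posterior" entry, purely to show the behavior inplace=False is supposed to guarantee:

```python
import copy

def predict(idata, inplace=True):
    """Toy stand-in for Model.predict: adds a fitted-mean variable.
    With inplace=False the caller's object must stay untouched."""
    target = idata if inplace else copy.deepcopy(idata)
    target["posterior"]["y_mean"] = [0.5, 0.6, 0.7]
    return None if inplace else target

idata = {"posterior": {"y": [1, 2, 3]}}

# inplace=False should return a modified copy and leave the original clean
new_idata = predict(idata, inplace=False)
assert "y_mean" not in idata["posterior"]   # original untouched
assert "y_mean" in new_idata["posterior"]   # copy got the new variable
```

The bug report above is consistent with the copy step being skipped, so the new variables leak into the shared idata and show up in az.summary.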

Collaborator Author

Oh! That should be a problem with the .predict() method ignoring the inplace=False argument. Could you please open an issue with a minimal reproducible example that shows the problem? Thanks @vishalthatsme!
