-
-
Notifications
You must be signed in to change notification settings - Fork 132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add plot_cap (Plot Conditional Adjusted Predictions) #517
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Codecov Report
@@ Coverage Diff @@
## main #517 +/- ##
==========================================
- Coverage 90.86% 86.88% -3.99%
==========================================
Files 29 32 +3
Lines 2442 2562 +120
==========================================
+ Hits 2219 2226 +7
- Misses 223 336 +113
Continue to review full report at Codecov.
|
Definitive, a cool addition! I already want to try it! I left a couple of comments. I think this PR is good as is, a future addition could be one or more kwargs so users are able to fine tune the plots. |
Thanks for the prompt review. I agree this function should incorporate optional arguments in the future so users can tune more things. For example, it would be great to have another dimension that is mapped to the axes, so you can create plots with multiple axes. And we could let them choose if they want to map a dimension to the color or to the axes. One of the things I don't really like is having functions with extremely long signatures... but it seems it is how we are used to working with Matplotlib and I don't think we could do much to change it. I think we could merge this as it is (maybe after adding some tests? I'm not sure how to test plotting functions btw) and then iterate to refine how it works. |
I'll review in a couple of hours, at a glance this is already look pretty cool |
Matplotlib is not in our dependencies, but it's an indirect dependency because of ArviZ. Do you think we should change anything in our |
I don't think we don't need to change our |
lower_bound = round((1 - hdi_prob) / 2, 4) | ||
upper_bound = 1 - lower_bound | ||
|
||
y_hat = idata.posterior[f"{model.response.name}_mean"] | ||
y_hat_mean = y_hat.mean(("chain", "draw")) | ||
y_hat_bounds = y_hat.quantile(q=(lower_bound, upper_bound), dim=("chain", "draw")) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use az.hdi
and pass hdi_prob
directly to it. Additionally, we could have the option to use quantiles or HDI, but still HDI should probably be the default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree with this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the advantages of HDI? I recall this discussion arviz-devs/arviz#2021 where HDI can result in an unexpected result.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consistency, we use HDI everywhere is ArviZ. Having both is probably a better option.
For extra kwargs I agree with Tomas its hard to anticipate everything without adding a ton of kwargs. Maybe we can just merge as is and add flexibility as needed in future PRs once we run into cases in actual usage |
lower_bound = round((1 - hdi_prob) / 2, 4) | ||
upper_bound = 1 - lower_bound | ||
|
||
y_hat = idata.posterior[f"{model.response.name}_mean"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Love plot_cap so far! One potential issue here: after using plot_cap once, I realized the function adds 200 (grid_n) new entries into the idata's posterior data variables. The issue is, when running az.summary(idata) after plot_cap, it will include all of those y_hat_mean values (as would any other az plot like plot_trace).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh! That should be a problem with the .predict()
method ignoring the inplace=False
argument. Could you please open an issue with a minimum reproducible example that shows the problem? Thanks @vishalthatsme !
This PR adds a new sub-package called
plots
. Right now it only contains one function,plot_cap()
, which is very versatile and very powerful. This function is highly inspired by theplot_cap
function in the R package {marginaleffects}.I'll let some examples talk for themselves
Now let's see another example, using logistic regression. This is also borrowed from {marginaleffects} documentation.
Extra point: This model is an excellent example of how sometimes
adapt_diag+jitter
isn't good. If we useadapt_diag+jitter
sampling never finishes. Chain don't mix. All types of problems.I honestly think this is a very cool addition. Would like to know your thoughts @aloctavodia @canyon289