Surprising predict() behaviour #533

Closed
markgoodhead opened this issue Jun 11, 2022 · 8 comments
Labels
Discussion (Issue open for discussion, still not ready for a PR on it.)

Comments

@markgoodhead
Contributor

I'm not sure if this is an issue as such, more a comment / point of discussion. As a fairly new user of Bambi, I found the behaviour of predict() very unintuitive, because in the Python data science ecosystem the fit()/predict() interface carries very strong connotations of the standard sklearn API.

I appreciate that the Bayesian approach differs from the usual ML approach in that you're generally interested not just in point estimates but in the full posterior predictive distribution. However, I think it would be very beginner friendly to have the option to treat the model like a frequentist/ML sklearn-style model, for which predict() produces point estimates. Even if you don't care too much about the full posterior and uncertainty estimation, there are still reasons to prefer the Bayesian approach if you're only after point estimates.

One proposal would be to add a "point-estimate-prediction" option alongside "mean" and "pps" which instead returns a numpy array of per-data-point predictions. Whilst this would just be equivalent to doing pps and then taking the mean, I think it would help beginners tremendously, as otherwise they have to understand the usual InferenceData/xarray structure to do it themselves, which has its own learning curve.
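To make the contrast concrete, here is a minimal sketch of the workflow a newcomer currently has to discover. This assumes Bambi's Model.predict(idata, kind=..., data=...) signature; the data frames and the response name "y" are placeholders, and the exact variable name inside the posterior_predictive group can differ.

import bambi as bmb

# fit a model as usual; df_train and df_new are placeholder data frames
model = bmb.Model("y ~ x", df_train)
idata = model.fit()

# kind="pps" samples from the posterior predictive and stores the draws in
# idata.posterior_predictive rather than returning a numpy array of predictions
model.predict(idata, kind="pps", data=df_new)

# collapse the chain and draw dimensions to get one point estimate per row of df_new
point_predictions = idata.posterior_predictive["y"].mean(("chain", "draw")).values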

@aloctavodia
Collaborator

Working with distributions and samples is the bread and butter of computational Bayesians. What about adding or extending an example about this? The more examples we have about how to use InferenceData, the less steep that learning curve will be.

@aloctavodia added the Discussion label on Jun 11, 2022
@tomicapretto
Collaborator

tomicapretto commented Jun 12, 2022

I understand your point, and I agree it's harder to think about samples. What's more, we have chains and draws, and new data structures as well, which I think makes things more complicated.

On the other hand, I don't think we should change anything in the .predict() method itself. It should always return a non-post-processed Bayesian prediction.

I think the way to make Bambi more appealing to newcomers (both beginners and people experienced with frequentist frameworks) is to provide utility functions that do much if not all of the heavy lifting (for example #517).

In this particular case, I think there could be another utility function that post-processes the InferenceData to return point estimates. This way, users will be aware that Bambi returns a whole posterior (as samples) but they are explicitly converting it to point estimates (even if they don't implement that conversion by hand).
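As a rough illustration only (the helper's name and where it lives are exactly what the questions below are about), such a function could be as small as:

def point_estimates(idata, var_name):
    # Average the posterior predictive samples over the chain and draw
    # dimensions, returning one value per observation as a numpy array
    return idata.posterior_predictive[var_name].mean(("chain", "draw")).values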

Some open questions

  • How should this function behave?
  • In which module do we put this function?
  • Do we create a new one? Do we want to have a utility module?

@markgoodhead
Contributor Author

Yes, I think a utility method is a good solution and would address this quite simply. I think it'd be quite small, as the analogous code I use in my PyMC models is just:

y_test = pm.sample_posterior_predictive(results)
return y_test.posterior_predictive.y.mean(axis=(0,1)).values

If it were a brand new library, I'd argue for keeping predict() closer to the sklearn interface, because that's what's 'surprising' to new users; I'd bet a high percentage of people without a Bayesian/PyMC background but with a Python data science/ML background see fit/predict and immediately have a misconception about how the predict() function will work. However, I appreciate that breaking backwards compatibility to address this is probably too annoying to existing Bambi users who expect the current behaviour, so a small utility function adding an sklearn-API-like predict call is a good compromise, I think.

In terms of where it should exist, ideally it'd be another method on the Bambi Model object, named something like predict_sklearn() or predict_point_estimates() (better naming suggestions welcomed!), so that when another new user does what I did (calls predict(), then goes "huh, where's my numpy array of predictions?!") and heads to the API reference documentation, they'll see this function directly below/above predict() and can go "Ah, that's the one I should use!".

I'll produce a small example branch with this feature (as I think it should be fairly straightforward) to better demonstrate what I'm thinking of.
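As a hypothetical sketch only (the method name comes from the discussion above, the response variable "y" is a placeholder, and this is not an agreed API), it would essentially be a thin wrapper around the existing predict() call:

def predict_point_estimates(self, idata, data=None):
    # Run the usual Bayesian prediction, then collapse the posterior
    # predictive samples into one point estimate per data point
    self.predict(idata, kind="pps", data=data)
    return idata.posterior_predictive["y"].mean(("chain", "draw")).values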

@markgoodhead
Contributor Author

#535

Here's an example of what I was thinking of. I appreciate it's incredibly simple (literally a one-liner once you've called predict()), but this one-liner took me a few hours to work out the first time I used Bambi 😅

@aloctavodia
Collaborator

aloctavodia commented Jun 13, 2022

Notice that with InferenceData/xarray you can use labels.

y_test.posterior_predictive["y"].mean(("chain", "draw"))

This generally requires writing more characters, but the result is easier to read. For example, here it is clear that the intention is to average over both chains and draws.

Another comment: in ArviZ we have discussed hiding the chain information from the user. The main reason is to simplify working with InferenceData, because as a general rule users do not care about directly accessing individual chains. The chain information is useful mostly for diagnosing samples, so ArviZ can still access the chains internally.
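As a concrete example of what hiding the chains can look like in practice, plain xarray already lets you flatten the two sampling dimensions into one (variable names follow the snippets above):

# stack chains and draws into a single "sample" dimension, so downstream code
# never has to reason about individual chains
flat = y_test.posterior_predictive["y"].stack(sample=("chain", "draw"))
flat.mean("sample")  # same point estimates as averaging over chain and draw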

@tomicapretto
Collaborator

I'm going to close the issue because we're not changing the way .predict() behaves.

@tomicapretto closed this as not planned on Jan 5, 2023
@gshotwell

I'm pretty new to Bayesian modelling, but I did find this behaviour quite confusing in Bambi. Maybe a good solution would be to have a vignette that walks through how to practically work with the xarray objects?

@tomicapretto
Collaborator

@gshotwell that's definitely a good idea, thanks!
