-
-
Notifications
You must be signed in to change notification settings - Fork 132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Fix HSGP predictions #780
[WIP] Fix HSGP predictions #780
Conversation
Thanks a lot @tomicapretto 👍🏼 |
* Update code of conduct * update changelog
…nto fix_hsgp_prediction
@GStechschulte could you try this? import bambi as bmb
import numpy as np
import pandas as pd
df = pd.read_csv("tests/data/gam_data.csv")
rng = np.random.default_rng(1234)
df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0])
formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)"
model = bmb.Model(formula, df, categorical=["fac"])
idata = model.fit(tune=500, draws=500, target_accept=0.9) Plot 1 bmb.interpret.plot_predictions(
model,
idata,
conditional="x1",
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
); Plot 2 bmb.interpret.plot_predictions(
model,
idata,
conditional={
"x1": np.linspace(0, 1, num=100),
"fac2": ["a", "b", "c"]
},
legend=False,
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
); I was expecting to get the second plot with the code for the first plot. I think we got the result we got because we first generate the data, and only then, we use the |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #780 +/- ##
==========================================
+ Coverage 89.86% 90.16% +0.29%
==========================================
Files 46 46
Lines 3810 3814 +4
==========================================
+ Hits 3424 3439 +15
+ Misses 386 375 -11 ☔ View full report in Codecov by Sentry. |
@tomicapretto thanks! Plot 1 is displaying correctly. It is because you are not explicitly passing This is the behavior both bmb.interpret.plot_predictions(
model,
idata,
conditional=["x1", "fac2"],
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
legend=False
); |
Thanks @GStechschulte! I think this is done. I know the test is actually testing many things at the same time, not just the fix. But I think it's not possible to write a test for the fix in particular, and if possible, it would be so complicated. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! However, my knowledge on the implementation of HSGP in Bambi is a bit lacking.
Yup, I agree. It is also nice to have the text for |
* Delete all HSGP slices at the same time * Make interpret consider kwargs in function calls * Update code of conduct (bambinos#783) * Update code of conduct * update changelog * Update formulae to >=0.5.3 * start a test for the hsgp and 'by' * update changelog
* use bayeux to access a wide range of samplers * use bayeux to access a wide range of samplers * add notebook links to family table (#774) * access methods programatically * clean bayeux idata to be consistent with pymc model coords * rename alternative sampler args in tests * change docstring to reflect bayeux sampler names * bayeux dependencies are numpyro/jax/jaxlib/blackjax * rename idata coords and dims to PyMC model * add JAX based sampler dependencies * Update code of conduct (#783) * Update code of conduct * update changelog * [WIP] Fix HSGP predictions (#780) * Delete all HSGP slices at the same time * Make interpret consider kwargs in function calls * Update code of conduct (#783) * Update code of conduct * update changelog * Update formulae to >=0.5.3 * start a test for the hsgp and 'by' * update changelog * bayeux 0.1.9 updates * bump bayeux version * remove TFP methods, optimizers, and resolve pylint errors * alternative backends docs * tests for JAX based samplers except TFP * add TFP backend example * add TFP MCMC methods * don't use flowmc, chees, meads for categorical model * call model.backend.inference_methods to show list of samplers * docstring changes * inference_methods attribute and change JAX random seed * Add FutureWarning to inference_method parameter * black formatting and resolve pylint errors * fix package name * drop 3.9 and add 3.12 to testing matrix * change Python versions in requires-python and target-version * remove python 3.11 black target-version * pin requires-python to <3.13 * pip upgrade setuptools * Bump PyMC to 5.12 * Upgrade black and pylint * remove upgrading of setup tools --------- Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
Fixes predictions when HSGP contains a
by
variable.get_model_covariates
so it also looks at the named arguments of function calls.TODO: implement tests?Edit : it closes #776