Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fix HSGP predictions #780

Merged
merged 7 commits into from
Feb 29, 2024

Conversation

tomicapretto
Copy link
Collaborator

@tomicapretto tomicapretto commented Feb 18, 2024

Fixes predictions when HSGP contains a by variable.

TODO: implement tests?

Edit : it closes #776

@GStechschulte
Copy link
Collaborator

Thanks a lot @tomicapretto 👍🏼

@tomicapretto
Copy link
Collaborator Author

@GStechschulte could you try this?

import bambi as bmb
import numpy as np
import pandas as pd

df = pd.read_csv("tests/data/gam_data.csv")

rng = np.random.default_rng(1234)
df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0])

formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)"
model = bmb.Model(formula, df, categorical=["fac"])
idata = model.fit(tune=500, draws=500, target_accept=0.9)

Plot 1

bmb.interpret.plot_predictions(
    model, 
    idata, 
    conditional="x1", 
    subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
);

image

Plot 2

bmb.interpret.plot_predictions(
    model, 
    idata, 
    conditional={
        "x1": np.linspace(0, 1, num=100),
        "fac2": ["a", "b", "c"]
    }, 
    legend=False,
    subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
);

image

I was expecting to get the second plot with the code for the first plot. I think we got the result we got because we first generate the data, and only then, we use the subplot_kwargs? At that point, it's just too late, you only have one value of fac2

@codecov-commenter
Copy link

codecov-commenter commented Feb 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 90.16%. Comparing base (b5b9f09) to head (bdb48d8).
Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #780      +/-   ##
==========================================
+ Coverage   89.86%   90.16%   +0.29%     
==========================================
  Files          46       46              
  Lines        3810     3814       +4     
==========================================
+ Hits         3424     3439      +15     
+ Misses        386      375      -11     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@GStechschulte
Copy link
Collaborator

GStechschulte commented Feb 23, 2024

@tomicapretto thanks! Plot 1 is displaying correctly. It is because you are not explicitly passing fac2 to conditional. Which results in, as you stated, a single default value computed for fac2. The single value cannot have any subplots.

This is the behavior both interpret and marginaleffects uses if a covariate was specified in the model, but not passed to conditional.

bmb.interpret.plot_predictions(
    model, 
    idata, 
    conditional=["x1", "fac2"], 
    subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
    legend=False
);

image

@tomicapretto tomicapretto marked this pull request as ready for review February 23, 2024 18:55
@tomicapretto
Copy link
Collaborator Author

Thanks @GStechschulte! I think this is done.

I know the test is actually testing many things at the same time, not just the fix. But I think it's not possible to write a test for the fix in particular, and if possible, it would be so complicated.

Copy link
Collaborator

@GStechschulte GStechschulte left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! However, my knowledge on the implementation of HSGP in Bambi is a bit lacking.

@GStechschulte
Copy link
Collaborator

Thanks @GStechschulte! I think this is done.

I know the test is actually testing many things at the same time, not just the fix. But I think it's not possible to write a test for the fix in particular, and if possible, it would be so complicated.

Yup, I agree. It is also nice to have the text for interpret in there.

@GStechschulte GStechschulte merged commit ff685b7 into bambinos:main Feb 29, 2024
4 checks passed
GStechschulte pushed a commit to GStechschulte/bambi that referenced this pull request Mar 1, 2024
* Delete all HSGP slices at the same time

* Make interpret consider kwargs in function calls

* Update code of conduct (bambinos#783)

* Update code of conduct

* update changelog

* Update formulae to >=0.5.3

* start a test for the hsgp and 'by'

* update changelog
GStechschulte added a commit that referenced this pull request Mar 29, 2024
* use bayeux to access a wide range of samplers

* use bayeux to access a wide range of samplers

* add notebook links to family table (#774)

* access methods programatically

* clean bayeux idata to be consistent with pymc model coords

* rename alternative sampler args in tests

* change docstring to reflect bayeux sampler names

* bayeux dependencies are numpyro/jax/jaxlib/blackjax

* rename idata coords and dims to PyMC model

* add JAX based sampler dependencies

* Update code of conduct (#783)

* Update code of conduct

* update changelog

* [WIP] Fix HSGP predictions (#780)

* Delete all HSGP slices at the same time

* Make interpret consider kwargs in function calls

* Update code of conduct (#783)

* Update code of conduct

* update changelog

* Update formulae to >=0.5.3

* start a test for the hsgp and 'by'

* update changelog

* bayeux 0.1.9 updates

* bump bayeux version

* remove TFP methods, optimizers, and resolve pylint errors

* alternative backends docs

* tests for JAX based samplers except TFP

* add TFP backend example

* add TFP MCMC methods

* don't use flowmc, chees, meads for categorical model

* call model.backend.inference_methods to show list of samplers

* docstring changes

* inference_methods attribute and change JAX random seed

* Add FutureWarning to inference_method parameter

* black formatting and resolve pylint errors

* fix package name

* drop 3.9 and add 3.12 to testing matrix

* change Python versions in requires-python and target-version

* remove python 3.11 black target-version

* pin requires-python to <3.13

* pip upgrade setuptools

* Bump PyMC to 5.12

* Upgrade black and pylint

* remove upgrading of setup tools

---------

Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

plot_predictions breaks with HSGP
3 participants