Adding simple, multiple and hierarchical regression plots #512

Open
GWeindel opened this issue Jan 8, 2019 · 25 comments
Comments

@GWeindel
Contributor

GWeindel commented Jan 8, 2019

I have written some functions to draw regression plots from mixed models fitted in pure Stan. I wonder whether creating a branch in arviz for such plots would be interesting (guess so by seeing request #313). The basic idea would be to have the possibility of plotting linear (at least) regression, simple effects, interaction effects, with or without random effects (like sjPlot in R ).

If such a project would fit in the arviz package I could begin to code it, but I would definitely benefit from people with stronger skills.

@ahartikainen
Contributor

Sounds great. I think we need to think the API through carefully.

@GWeindel
Contributor Author

Great. If the other devs agree, I can start thinking about it, but it will surely take some time.

@ahartikainen
Contributor

Sure, no problem.

Do you have some model(s) you can share that could be used as a reference?

Also, do you think observed_data is a suitable location for the data, or do we need some other structure?

@aloctavodia
Contributor

This will be a great addition to ArviZ. @GWeindel please be sure to check plot_hpd function in case you find it useful for this project.

@GWeindel
Contributor Author

I am starting to have some doubts about the feasibility. It appears to me that either one constructs a post-fit structure, which then needs a lot of information about the fitted object (increasing users' time and effort), or one has to master what goes in (e.g. Stan or PyMC3 code) and what comes out in order to draw these plots.
Hence I would suggest that this project be built on top of a specialized package, perhaps bambi (https://github.com/bambinos/bambi).

@ahartikainen
Contributor

What if we start with regression plot done with InferenceData?

So the user would need to provide the following information at creation time:

Scatter information:

  • x-variable_name --> observed_data
  • y-variable_name --> posterior_predictive
  • y-data --> observed_data

Model information (Line plot)

  • x-model --> observed_data
  • y-model --> posterior or a function that reads InferenceData (user defined)

Then we need to define the same options as in ppcplot (do we take a subsample, etc.).

@ahartikainen
Contributor

After that we could implement multiple regression (where each dimension is either a new axis, or something similar)

And later add hierarchical structure.

Let's assume user can provide data.

@ahartikainen
Contributor

I was doing something simple today: linear regression...

It does get complicated fast.

We need a better interface to describe our models

Getting the following to work is not hard:

x
y_data
y_ppc

What is harder:

y_model

It would be great to give a function or something similar

y_model = "m*x+b"
y_model = "y = m*x+b"
y_model = "y ~ x"

and then m, x, and b are found from posterior.

Also I'm not sure, but there could still be better interface:

plot_lm("y ~ x", param=["m","b"], data=data)
plot_lm("y ~ m*x+b", data=data)

Could this work with glm also? If we assume InferenceData has all the needed data, we just need to parse the function and also accept numpy functions inside the formula:

plot_lm("exp(y) ~ m*x + log(x) + sqrt(b)", data=data)

How hard would it be to do that parsing with re?

  • ~ splits x and y
  • functions have ()
  • other names are parameters in InferenceData

Then after that we have y (and possibly an added pair for the ppc, e.g. {"y_hat": "y"}).
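A minimal sketch of that re-based split. The helper name parse_formula and the exact tokenization rules here are my own assumptions, not an agreed API:

```python
import re


def parse_formula(formula):
    """Split a formula like "y ~ m*x + b" into the response name and the
    sorted candidate parameter/variable names on the right-hand side.

    Tokens immediately followed by "(" (e.g. log, sqrt) are treated as
    functions and dropped; every remaining name would then be looked up
    in the InferenceData groups (posterior, observed_data, ...).
    """
    lhs, rhs = (part.strip() for part in formula.split("~"))
    funcs = set(re.findall(r"(\w+)\s*\(", rhs))
    names = set(re.findall(r"[A-Za-z_]\w*", rhs)) - funcs
    return lhs, sorted(names)
```

So parse_formula("y ~ m*x + b") would give ("y", ["b", "m", "x"]), and those names would then be searched for in the posterior.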

import matplotlib.pyplot as plt
import numpy as np


def plot_lm(x, y_ppc, y_data, y_model, data,
            x_group=None, y_ppc_group=None, y_data_group=None,
            num_ppc_samples=100):
    """Plot lm

    Parameters
    ----------
    x : str or Sequence
    y_ppc : str
    y_data : str or Sequence
    y_model : str or Sequence
    data : obj or list[obj]
        Any object that can be converted to an az.InferenceData object.
        Refer to the documentation of az.convert_to_dataset for details.
    x_group : str, optional
    y_ppc_group : str, optional
    y_data_group : str, optional
    num_ppc_samples : int

    Returns
    -------
    axes
    """
    # resolve variable names against the InferenceData groups,
    # preferring the most natural group for each variable
    if isinstance(x, str):
        if x_group is None:
            groups = data._groups
            if hasattr(data, "observed_data"):
                groups = ["observed_data"] + [group for group in groups if group != "observed_data"]
            for group in groups:
                item = getattr(data, group)
                if x in item and x_group is None:
                    x_group = group
                elif x in item:
                    print("Warning, duplicate variable names for x, using variable from group {}".format(x_group))
        x_values = getattr(data, x_group)[x]

    if isinstance(y_ppc, str):
        if y_ppc_group is None:
            groups = data._groups
            if hasattr(data, "posterior_predictive"):
                groups = ["posterior_predictive"] + [group for group in groups if group != "posterior_predictive"]
            for group in groups:
                item = getattr(data, group)
                if y_ppc in item and y_ppc_group is None:
                    y_ppc_group = group
                elif y_ppc in item:
                    print("Warning, duplicate variable names for y_ppc, using variable from group {}".format(y_ppc_group))
        y_ppc_values = getattr(data, y_ppc_group)[y_ppc]

    if isinstance(y_data, str):
        if y_data_group is None:
            groups = data._groups
            if hasattr(data, "observed_data"):
                groups = ["observed_data"] + [group for group in groups if group != "observed_data"]
            for group in groups:
                item = getattr(data, group)
                if y_data in item and y_data_group is None:
                    y_data_group = group
                elif y_data in item:
                    print("Warning, duplicate variable names for y_data, using variable from group {}".format(y_data_group))
        y_data_values = getattr(data, y_data_group)[y_data]

    fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=100)

    # plot data
    ax.plot(x_values, y_data_values, marker='.', color='C3', lw=0, zorder=10)

    # plot uncertainty in y: a random subsample of posterior predictive draws
    y_ppc_values_ = y_ppc_values.stack(sample=("chain", "draw"))
    slicer = np.random.choice(y_ppc_values_.sizes["sample"], size=num_ppc_samples, replace=False)
    y_ppc_values_ = y_ppc_values_[..., slicer]
    for i in range(num_ppc_samples):
        ax.plot(x_values, y_ppc_values_[..., i], marker='.', lw=0, alpha=0.1, color='C1')

    # plot uncertainty in the regression line
    y_model_values = y_model.stack(sample=("chain", "draw"))[..., slicer]
    for i in range(num_ppc_samples):
        ax.plot(x_values, y_model_values[..., i], lw=0.5, alpha=0.2, c='k')

    for spine in ax.spines.values():
        spine.set_visible(False)
    ax.grid(True)
    return ax

[screenshot of the resulting plot]

@jankaWIS

Speaking of which, I was wondering: is there currently anything in arviz like regplot in seaborn? That would be great, and it could also serve as a starting point for what has been asked here.

@utkarsh-maheshwari
Contributor

Just a thought.
Instead of asking for y_model, can't we calculate m and b inside the plot_lm function?
It would increase the complexity of the function, but it would reduce the complexity at the input end and make it more user-friendly.

@OriolAbril
Member

Instead of asking for y_model, can't we calculate m and b inside the plot_lm function?

The problem is that there is no way to know what y_model is in ArviZ (it could be possible at a higher level like in bambi, but not in ArviZ). It depends on the model: it can be y ~ b1*x+b0, but it could also have multiple covariates, higher-order terms, splines...

@utkarsh-maheshwari
Contributor

utkarsh-maheshwari commented Jun 2, 2021

@ahartikainen
What assumptions do we make about the data groups that should be present in the InferenceData passed as input?

This is an example, kidiq, that I am taking from posteriordb, but there is no posterior predictive group here, just the posterior. Can it be used as an example?
[screenshot of the kidiq InferenceData]

@ahartikainen
Contributor

Good question. I think we need to calculate the posterior predictive with Python.
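For a simple linear model like kidiq this can be sketched directly in numpy. A minimal version, assuming a Normal likelihood and flattened posterior draws of intercept, slope and sigma (the function name is made up for illustration, not an ArviZ API):

```python
import numpy as np


def posterior_predictive_lm(intercept, slope, sigma, x, rng=None):
    """Draw posterior predictive samples for y ~ Normal(intercept + slope * x, sigma).

    intercept, slope, sigma : flattened posterior draws, shape (n_draws,)
    x : observed predictor, shape (n_obs,)

    Returns an array of shape (n_draws, n_obs).
    """
    rng = np.random.default_rng() if rng is None else rng
    # broadcast parameter draws against the predictor: (n_draws, 1) * (1, n_obs)
    mu = intercept[:, None] + slope[:, None] * x[None, :]
    return rng.normal(mu, sigma[:, None])
```

The resulting array could then be packed back into the posterior_predictive group, e.g. via az.from_dict.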

@utkarsh-maheshwari
Contributor

https://gist.github.com/utkarsh-maheshwari/8d4cd2fd84c763bf85291c3f0881d588

Here is my initial try at visualizing linear regression models, inspired by plot_posterior_predictive_glm from pymc3. There are still lots of things that need to be considered, though.

@OriolAbril
Member

Use

with pm.Model() as model:
    mom_iq = pm.Data("mom_iq", data["mom_iq"])
    
    sigma = pm.HalfNormal('sigma', sd=10)
    intercept = pm.Normal('Intercept', 0, sd=10)
    x_coeff = pm.Normal('slope', 0, sd=10)
    
    mean = intercept + x_coeff * mom_iq
    likelihood = pm.Normal('kid_score', mu=mean, 
                        sd=sigma, observed=data["kid_score"])
    
    idata = pm.sample(1000, return_inferencedata=True)

so mom_iq gets automatically stored as constant data. Moreover, we should definitely not convert to a dataframe:

idata.posterior["Intercept"] + idata.posterior["slope"] * idata.constant_data["mom_iq"]

will work with xarray out of the box and avoid the need to loop for the computation. Ari's function above has an example with stacking to get a random subsample.

Also, a bit of a side note: eval is a built-in function in Python, so it's not a good idea to use it as a variable name.
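To illustrate the out-of-the-box broadcasting and the stacking subsample, a small self-contained sketch (the toy posterior values and the dimension name mom_iq_dim are made up for the example):

```python
import numpy as np
import xarray as xr

# Toy posterior with the usual (chain, draw) dimensions, mimicking idata.posterior
posterior = xr.Dataset(
    {
        "Intercept": (("chain", "draw"), np.full((4, 100), 2.0)),
        "slope": (("chain", "draw"), np.full((4, 100), 0.5)),
    }
)
mom_iq = xr.DataArray(np.linspace(70, 140, 10), dims="mom_iq_dim")

# xarray broadcasts (chain, draw) against mom_iq_dim automatically, no loop needed
y_model = posterior["Intercept"] + posterior["slope"] * mom_iq

# Stack chain/draw into a single sample dimension and take a random subsample
stacked = y_model.stack(sample=("chain", "draw"))
idx = np.random.default_rng(0).choice(stacked.sizes["sample"], size=50, replace=False)
subset = stacked.isel(sample=idx)
```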

@utkarsh-maheshwari
Contributor

utkarsh-maheshwari commented Jun 7, 2021

@OriolAbril Thank you for the suggestions. Made the suggested changes.
I think visualizing uncertainty in the y points is insignificant here because the points are closely packed. (Should we include an option to show it?)

Now there are many points that need to be considered for the plot_lm function:

  • Other input parameters
  • Initial checks on the input data
  • Fill area of uncertainty?

Open to suggestions

@utkarsh-maheshwari
Contributor

I guess using plot_hdi, as suggested by @aloctavodia, would make it look great.

@ahartikainen, about the y_model, I think we can do it like this?

Should I open a new issue to discuss plot_lm and its visualization specifically? Otherwise, this issue will stretch very long.

@utkarsh-maheshwari
Contributor

[Four screenshots of work-in-progress plots]
@utkarsh-maheshwari
Contributor

Also, do you think observed_data is a suitable location for the data or do we need some other structure?

I think data could be in constant_data as well.

@utkarsh-maheshwari
Contributor

utkarsh-maheshwari commented Jun 10, 2021

I have tried to modify Ari's plot_lm function with some added features.

  • Made a parser for y_model. I think it would work for simple linear models (irrespective of the order of terms in y_model ) and can be extended to glm as well.
  • Visualized uncertainty in mean, and uncertainty in data.

Achieved this:

[screenshot of the resulting plot]

input :

plot_lm(
    x="mom_iq", 
    y_ppc="kid_score",
    y_data="kid_score", 
    data = idata, 
    y_model = "kid_score ~ Intercept + slope * mom_iq"
)

@utkarsh-maheshwari
Contributor

I think we need to calculate the posterior predictive with python.

Can we use pm.sample_posterior_predictive() to calculate it?

@ahartikainen
Contributor

I think we need to calculate the posterior predictive with python.

Can we use pm.sample_posterior_predictive() to calculate it?

It depends on what PPL you use for the model.
