Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for engine="auto" and bambi models in predictive explorer #455

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions preliz/internal/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from sys import modules

import numpy as np
import pymc as pm
import bambi as bmb
rohanbabbar04 marked this conversation as resolved.
Show resolved Hide resolved

from preliz import distributions
from .distribution_helper import init_vals
Expand All @@ -13,8 +15,18 @@ def inspect_source(fmodel):
source = inspect.getsource(fmodel)
signature = inspect.signature(fmodel)
source = re.sub(r"#.*$|^#.*$", "", source, flags=re.MULTILINE)

return source, signature
default_params = {
name: (param.default if param.default is not inspect.Parameter.empty else np.nan)
for name, param in signature.parameters.items()
}
model = fmodel(**default_params)
if isinstance(model, pm.Model):
engine = "pymc"
elif isinstance(model, bmb.Model):
rohanbabbar04 marked this conversation as resolved.
Show resolved Hide resolved
engine = "bambi"
else:
engine = "preliz"
return source, signature, engine


def parse_function_for_pred_textboxes(source, signature, engine="preliz"):
Expand Down
2 changes: 1 addition & 1 deletion preliz/internal/plot_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def looper(*args, **kwargs):
model = func(*args, **kwargs)
model.build()
with disable_pymc_sampling_logs():
idata = model.prior_predictive(iterations, random_seed=iterations)
idata = model.prior_predictive(iterations)
results = extract(idata, group="prior_predictive")[model.response_name].values.T

_, ax = plt.subplots()
Expand Down
10 changes: 0 additions & 10 deletions preliz/internal/predictive_helper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from sys import modules

import re
import numpy as np


Expand Down Expand Up @@ -69,12 +68,3 @@ def select_prior_samples(selected, prior_samples, model):
subsample = {rv: prior_samples[rv][selected] for rv in model.keys()}

return subsample


def get_engine(source):
source = re.sub(r"#.*$|^#.*$", "", source, flags=re.MULTILINE)
if re.search(r"\s*(?:bmb|bambi)\.Model\s*", source):
return "bambi"
if re.search(r"\s*(?:pm|pymc)\.Model\s*", source):
return "pymc"
return "preliz"
4 changes: 1 addition & 3 deletions preliz/predictive/predictive_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
pymc_plot_decorator,
bambi_plot_decorator,
)
from preliz.internal.predictive_helper import get_engine


def predictive_explorer(
Expand Down Expand Up @@ -44,8 +43,7 @@ def predictive_explorer(
The function will automatically select the appropriate library to use based on the fmodel
provided.
"""
source, signature = inspect_source(fmodel)
engine = get_engine(source) if engine == "auto" else engine
source, signature, engine = inspect_source(fmodel)
model = parse_function_for_pred_textboxes(source, signature, engine)
textboxes = get_textboxes(signature, model)
if engine == "pymc":
Expand Down
Loading