Skip to content

Commit

Permalink
ppa: add argument to specify reference values (#143)
Browse files Browse the repository at this point in the history
* ppa: add argument to specify reference values

* add doctring entry
  • Loading branch information
aloctavodia authored Dec 16, 2022
1 parent a844bac commit d65ff34
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions preliz/ppa.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_log = logging.getLogger("preliz")


def ppa(idata, model, summary="octiles", init=None):
def ppa(idata, model, summary="octiles", references=0, init=None):
"""
Prior predictive check assistant.
Expand All @@ -31,10 +31,13 @@ def ppa(idata, model, summary="octiles", init=None):
With at least the `prior` and `prior_predictive` groups
model : PyMC model
Model associated to ``idata``.
summary : str:
summary : str
Summary statistics applied to prior samples in order to define (dis)similarity
of distributions. Current options are `octiles`, `hexiles`, `quantiles`,
`sort` (sort data) `octi_sum` (robust estimation of first 4 moments from octiles).
references : int, float, list or tuple
Value(s) used as reference points representing prior knowledge. For example expected
values or values that are considered extreme.
init : tuple or PreliZ distribtuion
Initial distribution. The first shown distributions will be selected to be as close
as possible to `init`. Available options are, a PreliZ distribution or a 2-tuple with
Expand All @@ -54,6 +57,10 @@ def ppa(idata, model, summary="octiles", init=None):
)
except NameError:
pass

if isinstance(references, (float, int)):
references = [references]

global pp_samples_idxs # pylint:disable=invalid-name

shown = []
Expand All @@ -66,7 +73,7 @@ def ppa(idata, model, summary="octiles", init=None):
sample_size = pp_samples.shape[0]
pp_summary, kdt = compute_summaries(pp_samples, summary)
pp_samples_idxs, shown = initialize_subsamples(pp_summary, shown, kdt, init)
fig, axes = plot_samples(pp_samples)
fig, axes = plot_samples(pp_samples, references)

clicked = []
selected = []
Expand Down Expand Up @@ -94,6 +101,7 @@ def carry_on_(_):
axes,
radio_buttons_kind.value,
check_button_sharex.value,
references,
clicked,
pp_samples,
pp_summary,
Expand All @@ -111,7 +119,9 @@ def on_return_prior_(_):
button_return_prior.on_click(on_return_prior_)

def kind_(_):
plot_samples(pp_samples, radio_buttons_kind.value, check_button_sharex.value, fig)
plot_samples(
pp_samples, references, radio_buttons_kind.value, check_button_sharex.value, fig
)

radio_buttons_kind.observe(kind_, names=["value"])

Expand Down Expand Up @@ -161,7 +171,18 @@ def on_return_prior(fig, selected, model, sample_size):


def carry_on(
fig, axes, kind, sharex, clicked, pp_samples, pp_summary, choices, selected, shown, kdt
fig,
axes,
kind,
sharex,
references,
clicked,
pp_samples,
pp_summary,
choices,
selected,
shown,
kdt,
):
global pp_samples_idxs # pylint:disable=invalid-name

Expand All @@ -179,7 +200,7 @@ def carry_on(
pp_samples_idxs, shown = keep_sampling(pp_summary, choices, shown, kdt)
if not pp_samples_idxs:
pp_samples_idxs, shown = initialize_subsamples(pp_summary, shown, kdt, None)
fig, _ = plot_samples(pp_samples, kind, sharex, fig)
fig, _ = plot_samples(pp_samples, references, kind, sharex, fig)


def compute_summaries(pp_samples, summary):
Expand Down Expand Up @@ -280,7 +301,7 @@ def keep_sampling(pp_summary, choices, shown, kdt):
return [], shown


def plot_samples(pp_samples, kind="pdf", sharex=True, fig=None):
def plot_samples(pp_samples, references, kind="pdf", sharex=True, fig=None):
row_colum = int(np.ceil(len(pp_samples_idxs) ** 0.5))

if fig is None:
Expand All @@ -297,7 +318,8 @@ def plot_samples(pp_samples, kind="pdf", sharex=True, fig=None):

for ax, idx in zip(axes, pp_samples_idxs):
ax.clear()
ax.axvline(0, ls="--", color="0.5")
for ref in references:
ax.axvline(ref, ls="--", color="0.5")
ax.relim()

sample = pp_samples[idx]
Expand Down

0 comments on commit d65ff34

Please sign in to comment.