Skip to content

Commit

Permalink
Dirichlet: more stable joint pdf (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Nov 21, 2024
1 parent 16a47fb commit 9c8b408
Showing 1 changed file with 30 additions and 16 deletions.
46 changes: 30 additions & 16 deletions preliz/internal/plot_helper_multivariate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import reduce
from operator import mul
import warnings

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import tri
from scipy.special import gamma
from .plot_helper import repr_to_matplotlib
from preliz.internal.plot_helper import repr_to_matplotlib
from preliz.internal.special import gammaln


def get_cols_rows(n_plots):
Expand All @@ -26,14 +27,29 @@ def __init__(self, alpha):
"""

self._alpha = np.array(alpha)
self._coef = gamma(np.sum(self._alpha)) / reduce(mul, [gamma(a) for a in self._alpha])
self._coef = np.exp(
gammaln(np.sum(self._alpha)) - np.sum([gammaln(a) for a in self._alpha])
)

self._corners = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.75**0.5]])
self._triangle = tri.Triangulation(self._corners[:, 0], self._corners[:, 1])
self._midpoints = [
(self._corners[(i + 1) % 3] + self._corners[(i + 2) % 3]) / 2.0 for i in range(3)
]

refiner = tri.UniformTriRefiner(self._triangle)
self.trimesh = refiner.refine_triangulation(subdiv=8)
self.pvals = np.nan_to_num(
[self.pdf(self.xy2bc(xy)) for xy in zip(self.trimesh.x, self.trimesh.y)]
)
self.ok = True
if not np.any(self.pvals):
self.ok = False
warnings.warn(
"The joint pdf is to concentrated to plot, use `marginals=True` instead",
stacklevel=2,
)

def xy2bc(self, x_y, tol=1.0e-3):
"""
Converts 2D Cartesian coordinates to barycentric coordinates.
Expand Down Expand Up @@ -64,18 +80,14 @@ def plot(self, ax=None):
subdiv: int
Number of recursive mesh subdivisions to create.
"""
refiner = tri.UniformTriRefiner(self._triangle)
trimesh = refiner.refine_triangulation(subdiv=8)
pvals = [self.pdf(self.xy2bc(xy)) for xy in zip(trimesh.x, trimesh.y)]

hdi_probs = [0.1, 0.5, 0.94]
contour_levels = find_hdi_contours(pvals, hdi_probs)
contour_levels = find_hdi_contours(self.pvals, hdi_probs)
if all(contour_levels == contour_levels[0]):
ax.tricontourf(trimesh, pvals)
ax.tricontourf(self.trimesh, self.pvals)
else:
ax.tricontour(
trimesh,
pvals,
self.trimesh,
self.pvals,
levels=contour_levels,
)
ax.triplot(self._triangle, color="0.8", linestyle="--", linewidth=2)
Expand Down Expand Up @@ -168,11 +180,13 @@ def plot_dirichlet(

else:
if dim == 3:
if axes is None:
_, axes = plt.subplots(1, 1)
DirichletOnSimplex(alpha).plot(ax=axes)
if legend == "title":
axes.set_title(repr_to_matplotlib(dist))
dirichlet_ = DirichletOnSimplex(alpha)
if dirichlet_.ok:
if axes is None:
_, axes = plt.subplots(1, 1)
dirichlet_.plot(ax=axes)
if legend == "title":
axes.set_title(repr_to_matplotlib(dist))
else:
raise ValueError("joint only works for Dirichlet of dim=3")

Expand Down

0 comments on commit 9c8b408

Please sign in to comment.