Skip to content

Commit

Permalink
remove unnecesary calls to rv_frozen method (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Dec 13, 2022
1 parent c49e1a6 commit 3f8c0c8
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions preliz/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ def _finite_endpoints(self, support):
lower_ep, upper_ep = self.support

if not np.isfinite(lower_ep) or support == "restricted":
lower_ep = self.rv_frozen.ppf(0.0001)
lower_ep = self.ppf(0.0001)
if not np.isfinite(upper_ep) or support == "restricted":
upper_ep = self.rv_frozen.ppf(0.9999)
upper_ep = self.ppf(0.9999)

return lower_ep, upper_ep

Expand Down
15 changes: 8 additions & 7 deletions preliz/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@


@pytest.fixture(scope="function")
def a_dist():
return pz.Beta(2, 6)
def two_dist():
return pz.Beta(2, 6), pz.Poisson(4.5)


@pytest.mark.parametrize(
Expand All @@ -24,8 +24,9 @@ def a_dist():
{"ax": plt.subplots()[1]},
],
)
def test_plot_pdf_cdf_ppf(a_dist, kwargs):
a_dist.plot_pdf(**kwargs)
a_dist.plot_cdf(**kwargs)
kwargs.pop("support", None)
a_dist.plot_ppf(**kwargs)
def test_continuous_plot_pdf_cdf_ppf(two_dist, kwargs):
for a_dist in two_dist:
a_dist.plot_pdf(**kwargs)
a_dist.plot_cdf(**kwargs)
kwargs.pop("support", None)
a_dist.plot_ppf(**kwargs)
11 changes: 5 additions & 6 deletions preliz/utils/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
def optimize_max_ent(dist, lower, upper, mass):
def prob_bound(params, dist, lower, upper, mass):
dist._update(*params)
rv_frozen = dist.rv_frozen
if dist.kind == "discrete":
lower -= 1
cdf0 = rv_frozen.cdf(lower)
cdf1 = rv_frozen.cdf(upper)
cdf0 = dist.cdf(lower)
cdf1 = dist.cdf(upper)
loss = (cdf1 - cdf0) - mass
return loss

Expand Down Expand Up @@ -47,7 +46,7 @@ def entropy_loss(params, dist):
def optimize_quartile(dist, x_vals):
def func(params, dist, x_vals):
dist._update(*params)
loss = dist.rv_frozen.cdf(x_vals) - [0.25, 0.5, 0.75]
loss = dist.cdf(x_vals) - [0.25, 0.5, 0.75]
return loss

init_vals = dist.params
Expand All @@ -72,7 +71,7 @@ def func(params, dist, x_vals):
def optimize_cdf(dist, x_vals, ecdf, **kwargs):
def func(params, dist, x_vals, ecdf, **kwargs):
dist._update(*params, **kwargs)
loss = dist.rv_frozen.cdf(x_vals) - ecdf
loss = dist.cdf(x_vals) - ecdf
return loss

init_vals = dist.params[:2]
Expand Down Expand Up @@ -115,7 +114,7 @@ def negll(params, dist, sample):
def relative_error(dist, lower, upper, required_mass):
if dist.kind == "discrete":
lower -= 1
computed_mass = dist.rv_frozen.cdf(upper) - dist.rv_frozen.cdf(lower)
computed_mass = dist.cdf(upper) - dist.cdf(lower)
return abs((computed_mass - required_mass) / required_mass * 100), computed_mass


Expand Down
8 changes: 4 additions & 4 deletions preliz/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def plot_pdfpmf(dist, moments, pointinterval, quantiles, support, legend, figsiz

x = dist.xvals(support)
if dist.kind == "continuous":
density = dist.rv_frozen.pdf(x)
density = dist.pdf(x)
ax.plot(x, density, label=label, color=color)
ax.get_ylim()
ax.set_yticks([])
else:
mass = dist.rv_frozen.pmf(x)
mass = dist.pdf(x)
eps = np.clip(dist._finite_endpoints(support), *dist.support)
x_c = np.linspace(*eps, 1000)

Expand Down Expand Up @@ -111,7 +111,7 @@ def plot_cdf(dist, moments, pointinterval, quantiles, support, legend, figsize,

eps = dist._finite_endpoints(support)
x = np.linspace(*eps, 1000)
cdf = dist.rv_frozen.cdf(x)
cdf = dist.cdf(x)
ax.plot(x, cdf, label=label, color=color)

if pointinterval:
Expand Down Expand Up @@ -139,7 +139,7 @@ def plot_ppf(dist, moments, pointinterval, quantiles, legend, figsize, ax):
label = None

x = np.linspace(0, 1, 1000)
ax.plot(x, dist.rv_frozen.ppf(x), label=label, color=color)
ax.plot(x, dist.ppf(x), label=label, color=color)

if pointinterval:
plot_pointinterval(dist.rv_frozen, quantiles=quantiles, rotated=True, ax=ax)
Expand Down

0 comments on commit 3f8c0c8

Please sign in to comment.