Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add faster HalfNormal distribution #346

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ This reference provides detailed documentation for user functions in the current
:mod:`preliz.distributions.continuous`
======================================

.. automodule:: preliz.distributions.halfnormal
:members:

.. automodule:: preliz.distributions.normal
:members:

Expand Down
94 changes: 2 additions & 92 deletions preliz/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from ..internal.distribution_helper import garcia_approximation, all_not_none, any_not_none
from .distributions import Continuous
from .normal import Normal # pylint: disable=unused-import
from .halfnormal import HalfNormal # pylint: disable=unused-import


eps = np.finfo(float).eps

Expand Down Expand Up @@ -963,98 +965,6 @@ def _fit_mle(self, sample, **kwargs):
self._update(beta)


class HalfNormal(Continuous):
    r"""
    HalfNormal Distribution

    The pdf of this distribution is

    .. math::

       f(x \mid \sigma) =
           \sqrt{\frac{2}{\pi\sigma^2}}
           \exp\left(\frac{-x^2}{2\sigma^2}\right)

    .. plot::
        :context: close-figs

        import arviz as az
        from preliz import HalfNormal
        az.style.use('arviz-white')
        for sigma in [0.4, 2.]:
            HalfNormal(sigma).plot_pdf(support=(0,5))

    ========  ==========================================
    Support   :math:`x \in [0, \infty)`
    Mean      :math:`\dfrac{\sigma \sqrt{2}}{\sqrt{\pi}}`
    Variance  :math:`\sigma^2\left(1 - \dfrac{2}{\pi}\right)`
    ========  ==========================================

    HalfNormal distribution has 2 alternative parameterizations. In terms of sigma (standard
    deviation) or tau (precision).

    The link between the 2 alternatives is given by

    .. math::

        \tau = \frac{1}{\sigma^2}

    Parameters
    ----------
    sigma : float
        Scale parameter :math:`\sigma` (``sigma`` > 0).
    tau : float
        Precision :math:`\tau` (``tau`` > 0).
    """

    def __init__(self, sigma=None, tau=None):
        super().__init__()
        # Work on a copy of the scipy distribution object rather than on scipy's own instance.
        self.dist = copy(stats.halfnorm)
        self.support = (0, np.inf)
        self._parametrization(sigma, tau)

    def _parametrization(self, sigma=None, tau=None):
        """
        Record which parametrization is in use and store the parameter values.

        Raises
        ------
        ValueError
            If both ``sigma`` and ``tau`` are provided.
        """
        if all_not_none(sigma, tau):
            raise ValueError("Incompatible parametrization. Either use sigma or tau.")

        # Default to the sigma parametrization; switched below when tau is given.
        self.param_names = ("sigma",)
        self.params_support = ((eps, np.inf),)

        if tau is not None:
            sigma = from_precision(tau)
            self.param_names = ("tau",)

        self.sigma = sigma
        self.tau = tau
        if self.sigma is not None:
            self._update(self.sigma)

    def _get_frozen(self):
        """Return the scipy distribution frozen at ``scale=sigma``, or None if not available."""
        if not all_not_none(self.params):
            return None
        return self.dist(scale=self.sigma)

    def _update(self, sigma):
        """Store ``sigma`` and its precision, refresh ``params`` and the frozen distribution."""
        self.sigma = np.float64(sigma)
        self.tau = to_precision(sigma)

        # ``params`` exposes the value of whichever parameter the user chose to work with.
        active = self.param_names[0]
        if active == "sigma":
            self.params = (self.sigma,)
        elif active == "tau":
            self.params = (self.tau,)

        self._update_rv_frozen()

    def _fit_moments(self, mean, sigma):  # pylint: disable=unused-argument
        """Set the scale from a target standard deviation ``sigma``; ``mean`` is not used."""
        # Inverts std = scale * sqrt(1 - 2 / pi).
        scale = sigma / (1 - 2 / np.pi) ** 0.5
        self._update(scale)

    def _fit_mle(self, sample, **kwargs):
        """Set the scale to the root mean square of ``sample``; ``kwargs`` are not used."""
        root_mean_square = np.mean(sample**2) ** 0.5
        self._update(root_mean_square)


class HalfStudentT(Continuous):
r"""
HalfStudentT Distribution
Expand Down
178 changes: 178 additions & 0 deletions preliz/distributions/halfnormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# pylint: disable=attribute-defined-outside-init
# pylint: disable=arguments-differ
import numba as nb
import numpy as np
from scipy.special import erf, erfinv # pylint: disable=no-name-in-module

from .distributions import Continuous
from ..internal.distribution_helper import eps, to_precision, from_precision, all_not_none


class HalfNormal(Continuous):
    r"""
    HalfNormal Distribution

    The pdf of this distribution is

    .. math::

       f(x \mid \sigma) =
           \sqrt{\frac{2}{\pi\sigma^2}}
           \exp\left(\frac{-x^2}{2\sigma^2}\right)

    .. plot::
        :context: close-figs

        import arviz as az
        from preliz import HalfNormal
        az.style.use('arviz-white')
        for sigma in [0.4, 2.]:
            HalfNormal(sigma).plot_pdf(support=(0,5))

    ========  ==========================================
    Support   :math:`x \in [0, \infty)`
    Mean      :math:`\dfrac{\sigma \sqrt{2}}{\sqrt{\pi}}`
    Variance  :math:`\sigma^2\left(1 - \dfrac{2}{\pi}\right)`
    ========  ==========================================

    HalfNormal distribution has 2 alternative parameterizations. In terms of sigma (standard
    deviation) or tau (precision).

    The link between the 2 alternatives is given by

    .. math::

        \tau = \frac{1}{\sigma^2}

    Parameters
    ----------
    sigma : float
        Scale parameter :math:`\sigma` (``sigma`` > 0).
    tau : float
        Precision :math:`\tau` (``tau`` > 0).
    """

    def __init__(self, sigma=None, tau=None):
        super().__init__()
        self.support = (0, np.inf)
        self._parametrization(sigma, tau)

    def _parametrization(self, sigma=None, tau=None):
        """
        Record which parametrization is in use and store the parameter values.

        Raises
        ------
        ValueError
            If both ``sigma`` and ``tau`` are provided.
        """
        if all_not_none(sigma, tau):
            raise ValueError("Incompatible parametrization. Either use sigma or tau.")

        # Default to the sigma parametrization; switched below when tau is given.
        self.param_names = ("sigma",)
        self.params_support = ((eps, np.inf),)

        if tau is not None:
            sigma = from_precision(tau)
            self.param_names = ("tau",)

        self.sigma = sigma
        self.tau = tau
        if self.sigma is not None:
            self._update(self.sigma)

    def _update(self, sigma):
        """Store ``sigma`` and its precision, refresh ``params`` and mark the object frozen."""
        self.sigma = np.float64(sigma)
        self.tau = to_precision(sigma)

        # ``params`` exposes the value of whichever parameter the user chose to work with.
        active = self.param_names[0]
        if active in ("sigma", "tau"):
            self.params = (getattr(self, active),)

        self.is_frozen = True

    def pdf(self, x):
        """
        Evaluate the probability density function (PDF) at ``x``.
        """
        return nb_pdf(x, self.sigma)

    def cdf(self, x):
        """
        Evaluate the cumulative distribution function (CDF) at ``x``.
        """
        return nb_cdf(x, self.sigma)

    def ppf(self, q):
        """
        Evaluate the percent point function (PPF, inverse of the CDF) at probability ``q``.
        """
        return nb_ppf(q, self.sigma)

    def logpdf(self, x):
        """
        Evaluate the logarithm of the probability density function at ``x``.
        """
        return _logpdf(x, self.sigma)

    def entropy(self):
        """Return the entropy computed by ``nb_entropy`` for the current ``sigma``."""
        return nb_entropy(self.sigma)

    def mean(self):
        """Return the mean, ``sigma * sqrt(2 / pi)``."""
        return self.sigma * 0.7978845608028655

    def median(self):
        """Return the median, ``sigma * sqrt(2) * erfinv(0.5)``."""
        return self.sigma * 0.6744897501960818

    def var(self):
        """Return the variance, ``sigma**2 * (1 - 2 / pi)``."""
        return self.sigma**2 * 0.3633802276324186

    def std(self):
        """Return the standard deviation, ``sigma * sqrt(1 - 2 / pi)``."""
        return self.sigma * 0.6028102749890869

    def skewness(self):
        """Return the skewness, a constant that does not depend on ``sigma``."""
        return 0.9952717464311565

    def kurtosis(self):
        """Return the kurtosis, a constant that does not depend on ``sigma``."""
        return 0.8691773036059736

    def rvs(self, size=1, random_state=None):
        """
        Draw ``size`` random samples.

        Samples are the absolute values of normal draws with mean 0 and scale ``sigma``;
        ``random_state`` is forwarded to ``np.random.default_rng``.
        """
        rng = np.random.default_rng(random_state)
        normal_draws = rng.normal(0, self.sigma, size)
        return np.abs(normal_draws)

    def _fit_moments(self, mean, sigma):  # pylint: disable=unused-argument
        """Set the scale from a target standard deviation ``sigma``; ``mean`` is not used."""
        # Inverts std = scale * sqrt(1 - 2 / pi).
        scale = sigma / (1 - 2 / np.pi) ** 0.5
        self._update(scale)

    def _fit_mle(self, sample):
        """Set the scale to the estimate returned by ``nb_fit_mle`` for ``sample``."""
        estimate = nb_fit_mle(sample)
        self._update(estimate)


# NOTE: intentionally not decorated with @nb.jit, erf is not supported by numba
def nb_cdf(x, sigma):
    """
    Compute the HalfNormal cumulative distribution function.

    Parameters
    ----------
    x : array_like
        Point(s) at which to evaluate the CDF.
    sigma : float
        Scale parameter.

    Returns
    -------
    float or numpy.ndarray
        ``erf(x / (sigma * sqrt(2)))`` for ``x >= 0`` and 0 for ``x < 0``.
    """
    x = np.asarray(x)
    # erf is an odd function, so a negative x would give a negative "probability".
    # Clamping to 0 returns erf(0) = 0 for every point below the support.
    return erf(np.maximum(x, 0) / (sigma * 2**0.5))


# NOTE: intentionally not decorated with @nb.jit, erfinv is not supported by numba
def nb_ppf(q, sigma):
    """
    Compute the HalfNormal percent point function (inverse of the CDF).

    Parameters
    ----------
    q : array_like
        Probability value(s).
    sigma : float
        Scale parameter.

    Returns
    -------
    float or numpy.ndarray
        ``sigma * sqrt(2) * erfinv(q)`` for ``q`` in [0, 1] and nan otherwise.
    """
    q = np.asarray(q)
    # erfinv is odd, so q in (-1, 0) would give a negative "quantile" outside the support.
    # Map negative probabilities to nan; erfinv already returns nan for q > 1.
    q = np.where(q < 0, np.nan, q)
    return sigma * 2**0.5 * erfinv(q)


@nb.njit
def nb_pdf(x, sigma):
    """Return the HalfNormal density at ``x`` for scale ``sigma``; 0 where ``x < 0``."""
    x = np.asarray(x)
    # Normalizing constant sqrt(2 / pi) / sigma, kept in the original multiplication order.
    norm_const = np.sqrt(2 / np.pi) * (1 / sigma)
    density = norm_const * np.exp(-0.5 * (x / sigma) ** 2)
    return np.where(x < 0, 0, density)


@nb.njit
def nb_entropy(sigma):
    """Return ``0.5 * log(pi * sigma**2 / 2) + 0.5`` for scale ``sigma``."""
    half_log_term = 0.5 * np.log(np.pi * sigma**2.0 / 2.0)
    return half_log_term + 0.5


@nb.njit
def nb_fit_mle(sample):
    """Return the root mean square of ``sample``, used as the estimate of sigma."""
    mean_square = np.mean(sample**2)
    return mean_square**0.5


@nb.njit
def _logpdf(x, sigma):
    """Return the HalfNormal log density at ``x`` for scale ``sigma``; -inf where ``x < 0``."""
    x = np.asarray(x)
    # Log of the normalizing constant sqrt(2 / pi) / sigma.
    log_norm = np.log(np.sqrt(2 / np.pi)) + np.log(1 / sigma)
    log_density = log_norm - 0.5 * ((x / sigma) ** 2)
    return np.where(x < 0, -np.inf, log_density)
11 changes: 7 additions & 4 deletions preliz/tests/test_dist_scipy.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import pytest
from numpy.testing import assert_almost_equal
import numpy as np
from scipy import stats


from preliz.distributions import Normal
from scipy import stats
from preliz.distributions import Normal, HalfNormal


@pytest.mark.parametrize(
"p_dist, sp_dist, p_params, sp_params",
[(Normal, stats.norm, {"mu": 0, "sigma": 2}, {"loc": 0, "scale": 2})],
[
(Normal, stats.norm, {"mu": 0, "sigma": 2}, {"loc": 0, "scale": 2}),
(HalfNormal, stats.halfnorm, {"sigma": 2}, {"scale": 2}),
],
)
def test_lala(p_dist, sp_dist, p_params, sp_params):
def test_match_scipy(p_dist, sp_dist, p_params, sp_params):
preliz_dist = p_dist(**p_params)
scipy_dist = sp_dist(**sp_params)

Expand Down
Loading