Skip to content

Commit

Permalink
Loc-scale variant of Weibull distribution (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash authored Jan 22, 2024
1 parent 9dd7988 commit bb3af6f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
24 changes: 15 additions & 9 deletions jaxampler/_src/rvs/weibull.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,19 @@


class Weibull(ContinuousRV):
def __init__(self, lmbda: Numeric | Any, k: Numeric | Any, name: Optional[str] = None) -> None:
shape, self._lmbda, self._k = jx_cast(lmbda, k)
def __init__(
self,
k: Numeric | Any,
loc: Numeric | Any = 0.0,
scale: Numeric | Any = 1.0,
name: Optional[str] = None,
) -> None:
shape, self._k, self._loc, self._scale = jx_cast(k, loc, scale)
self.check_params()
super().__init__(name=name, shape=shape)

def check_params(self) -> None:
assert jnp.all(self._lmbda > 0.0), "scale must be greater than 0"
assert jnp.all(self._scale > 0.0), "scale must be greater than 0"
assert jnp.all(self._k > 0.0), "concentration must be greater than 0"

@partial(jit, static_argnums=(0,))
Expand All @@ -41,22 +47,22 @@ def logpdf_x(self, x: Numeric) -> Numeric | tuple[Numeric, ...]:
x <= 0,
jnp.full_like(x, -jnp.inf),
jnp.log(self._k)
- (self._k * jnp.log(self._lmbda))
+ (self._k - 1.0) * jnp.log(x)
- jnp.power(x / self._lmbda, self._k),
- (self._k * jnp.log(self._scale))
+ (self._k - 1.0) * jnp.log(x - self._loc)
- jnp.power(x / self._scale, self._k),
)

@partial(jit, static_argnums=(0,))
def cdf_x(self, x: Numeric) -> Numeric:
return jnp.where(
x <= 0.0,
0.0,
1.0 - jnp.exp(-jnp.power(x / self._lmbda, self._k)),
1.0 - jnp.exp(-jnp.power((x - self._loc) / self._scale, self._k)),
)

@partial(jit, static_argnums=(0,))
def ppf_x(self, x: Numeric) -> Numeric:
return self._lmbda * jnp.power(-jnp.log(1.0 - x), 1.0 / self._k)
return self._loc + self._scale * jnp.power(-jnp.log(1.0 - x), 1.0 / self._k)

def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
if key is None:
Expand All @@ -66,7 +72,7 @@ def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
return self.ppf_x(U)

def __repr__(self) -> str:
string = f"Weibull(lambda={self._lmbda}, k={self._k}"
string = f"Weibull(k={self._k}, loc={self._loc}, scale={self._scale}"
if self._name is not None:
string += f", name={self._name}"
string += ")"
Expand Down
20 changes: 10 additions & 10 deletions tests/weibull_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,44 +25,44 @@

class TestWeibull:
def test_pdf(self):
W = Weibull(lmbda=1, k=1)
W = Weibull(scale=1, k=1)
assert jnp.allclose(W.pdf_x(1.0), 1 / jnp.e)
assert jnp.allclose(W.pdf_x(0.0), 0)

def test_negative_x(self):
assert jnp.allclose(Weibull(lmbda=1, k=1).pdf_x(-1.0), 0)
assert jnp.allclose(Weibull(scale=1, k=1).pdf_x(-1.0), 0)

def test_negative_lambda(self):
with pytest.raises(AssertionError):
Weibull(lmbda=-1, k=1)
Weibull(scale=-1, k=1)

def test_negative_k(self):
with pytest.raises(AssertionError):
Weibull(lmbda=1, k=-1)
Weibull(scale=1, k=-1)

def test_negative_lambda_and_k(self):
with pytest.raises(AssertionError):
Weibull(lmbda=-1, k=-1)
Weibull(scale=-1, k=-1)

def test_shapes(self):
assert jnp.allclose(Weibull(lmbda=[1, 1], k=[1, 1]).pdf_x(1.0), jnp.array([0.3678794412, 0.3678794412]))
assert jnp.allclose(Weibull(scale=[1, 1], k=[1, 1]).pdf_x(1.0), jnp.array([0.3678794412, 0.3678794412]))
assert jnp.allclose(
Weibull(lmbda=[1, 1, 1], k=[1, 1, 1]).pdf_x(1.0),
Weibull(scale=[1, 1, 1], k=[1, 1, 1]).pdf_x(1.0),
jnp.array([0.3678794412, 0.3678794412, 0.3678794412]),
)

def test_cdf_x(self):
W = Weibull(lmbda=[1, 1, 1], k=[1, 1, 1])
W = Weibull(scale=[1, 1, 1], k=[1, 1, 1])
assert jnp.allclose(W.cdf_x(1), 1 - 1 / jnp.e)
assert jnp.allclose(W.cdf_x(0), 0)

def test_ppf(self):
W = Weibull(lmbda=1, k=1)
W = Weibull(scale=1, k=1)
assert jnp.allclose(W.ppf_x(0.0), 0.0)
assert jnp.allclose(W.ppf_x(1 - 1 / jnp.e), 1)

def test_rvs(self):
W = Weibull(lmbda=1, k=1)
W = Weibull(scale=1, k=1)
rvs = W.rvs((1000, 1000), key=jax.random.PRNGKey(0))
assert jnp.allclose(rvs.mean(), 1.0, atol=0.1)
assert jnp.allclose(rvs.std(), 1.0, atol=0.1)

0 comments on commit bb3af6f

Please sign in to comment.