diff --git a/jaxampler/_src/rvs/weibull.py b/jaxampler/_src/rvs/weibull.py index df3648a..c86115b 100644 --- a/jaxampler/_src/rvs/weibull.py +++ b/jaxampler/_src/rvs/weibull.py @@ -26,13 +26,19 @@ class Weibull(ContinuousRV): - def __init__(self, lmbda: Numeric | Any, k: Numeric | Any, name: Optional[str] = None) -> None: - shape, self._lmbda, self._k = jx_cast(lmbda, k) + def __init__( + self, + k: Numeric | Any, + loc: Numeric | Any = 0.0, + scale: Numeric | Any = 1.0, + name: Optional[str] = None, + ) -> None: + shape, self._k, self._loc, self._scale = jx_cast(k, loc, scale) self.check_params() super().__init__(name=name, shape=shape) def check_params(self) -> None: - assert jnp.all(self._lmbda > 0.0), "scale must be greater than 0" + assert jnp.all(self._scale > 0.0), "scale must be greater than 0" assert jnp.all(self._k > 0.0), "concentration must be greater than 0" @partial(jit, static_argnums=(0,)) @@ -41,9 +47,9 @@ def logpdf_x(self, x: Numeric) -> Numeric | tuple[Numeric, ...]: x <= 0, jnp.full_like(x, -jnp.inf), jnp.log(self._k) - - (self._k * jnp.log(self._lmbda)) - + (self._k - 1.0) * jnp.log(x) - - jnp.power(x / self._lmbda, self._k), + - (self._k * jnp.log(self._scale)) + + (self._k - 1.0) * jnp.log(x - self._loc) + - jnp.power(x / self._scale, self._k), ) @partial(jit, static_argnums=(0,)) @@ -51,12 +57,12 @@ def cdf_x(self, x: Numeric) -> Numeric: return jnp.where( x <= 0.0, 0.0, - 1.0 - jnp.exp(-jnp.power(x / self._lmbda, self._k)), + 1.0 - jnp.exp(-jnp.power((x - self._loc) / self._scale, self._k)), ) @partial(jit, static_argnums=(0,)) def ppf_x(self, x: Numeric) -> Numeric: - return self._lmbda * jnp.power(-jnp.log(1.0 - x), 1.0 / self._k) + return self._loc + self._scale * jnp.power(-jnp.log(1.0 - x), 1.0 / self._k) def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array: if key is None: @@ -66,7 +72,7 @@ def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array: return self.ppf_x(U) def __repr__(self) -> str: - string = f"Weibull(lambda={self._lmbda}, k={self._k}" + string = f"Weibull(k={self._k}, loc={self._loc}, scale={self._scale}" if self._name is not None: string += f", name={self._name}" string += ")" diff --git a/tests/weibull_test.py b/tests/weibull_test.py index 0f61c98..4225429 100644 --- a/tests/weibull_test.py +++ b/tests/weibull_test.py @@ -25,44 +25,44 @@ class TestWeibull: def test_pdf(self): - W = Weibull(lmbda=1, k=1) + W = Weibull(scale=1, k=1) assert jnp.allclose(W.pdf_x(1.0), 1 / jnp.e) assert jnp.allclose(W.pdf_x(0.0), 0) def test_negative_x(self): - assert jnp.allclose(Weibull(lmbda=1, k=1).pdf_x(-1.0), 0) + assert jnp.allclose(Weibull(scale=1, k=1).pdf_x(-1.0), 0) def test_negative_lambda(self): with pytest.raises(AssertionError): - Weibull(lmbda=-1, k=1) + Weibull(scale=-1, k=1) def test_negative_k(self): with pytest.raises(AssertionError): - Weibull(lmbda=1, k=-1) + Weibull(scale=1, k=-1) def test_negative_lambda_and_k(self): with pytest.raises(AssertionError): - Weibull(lmbda=-1, k=-1) + Weibull(scale=-1, k=-1) def test_shapes(self): - assert jnp.allclose(Weibull(lmbda=[1, 1], k=[1, 1]).pdf_x(1.0), jnp.array([0.3678794412, 0.3678794412])) + assert jnp.allclose(Weibull(scale=[1, 1], k=[1, 1]).pdf_x(1.0), jnp.array([0.3678794412, 0.3678794412])) assert jnp.allclose( - Weibull(lmbda=[1, 1, 1], k=[1, 1, 1]).pdf_x(1.0), + Weibull(scale=[1, 1, 1], k=[1, 1, 1]).pdf_x(1.0), jnp.array([0.3678794412, 0.3678794412, 0.3678794412]), ) def test_cdf_x(self): - W = Weibull(lmbda=[1, 1, 1], k=[1, 1, 1]) + W = Weibull(scale=[1, 1, 1], k=[1, 1, 1]) assert jnp.allclose(W.cdf_x(1), 1 - 1 / jnp.e) assert jnp.allclose(W.cdf_x(0), 0) def test_ppf(self): - W = Weibull(lmbda=1, k=1) + W = Weibull(scale=1, k=1) assert jnp.allclose(W.ppf_x(0.0), 0.0) assert jnp.allclose(W.ppf_x(1 - 1 / jnp.e), 1) def test_rvs(self): - W = Weibull(lmbda=1, k=1) + W = Weibull(scale=1, k=1) rvs = W.rvs((1000, 1000), key=jax.random.PRNGKey(0)) assert jnp.allclose(rvs.mean(), 1.0, atol=0.1) assert jnp.allclose(rvs.std(), 1.0, atol=0.1)