Skip to content

Commit

Permalink
Update docstrings for ReshapedDistribution.
Browse files Browse the repository at this point in the history
  • Loading branch information
ingmarschuster committed Aug 24, 2023
1 parent 60f9a8c commit 054ed98
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions gpjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,33 +224,33 @@ def __init__(self, distribution: tfd.Distribution, output_shape: Tuple[int, ...]
self._output_shape = output_shape

def mean(self) -> Float[Array, " N ..."]:
r"""Calculates the mean."""
r"""Mean of the base distribution, reshaped to the output shape."""
return jnp.reshape(self._distribution.mean(), self._output_shape)

def median(self) -> Float[Array, " N ..."]:
r"""Calculates the median."""
r"""Median of the base distribution, reshaped to the output shape"""
return jnp.reshape(self._distribution.median(), self._output_shape)

def mode(self) -> Float[Array, " N ..."]:
r"""Calculates the mode."""
r"""Mode of the base distribution, reshaped to the output shape"""
return jnp.reshape(self._distribution.mode(), self._output_shape)

def covariance(self) -> Float[Array, " N ..."]:
r"""Calculates the covariance matrix."""
r"""Covariance of the base distribution, reshaped to the squared output shape"""
return jnp.reshape(
self._distribution.covariance(), self._output_shape + self._output_shape
)

def variance(self) -> Float[Array, " N ..."]:
r"""Calculates the variance."""
r"""Variances of the base distribution, reshaped to the output shape"""
return jnp.reshape(self._distribution.variance(), self._output_shape)

def stddev(self) -> Float[Array, " N ..."]:
r"""Calculates the standard deviation."""
r"""Standard deviations of the base distribution, reshaped to the output shape"""
return jnp.reshape(self._distribution.stddev(), self._output_shape)

def entropy(self) -> ScalarFloat:
r"""Calculates the entropy."""
r"""Entropy of the base distribution."""
return self._distribution.entropy()

def log_prob(
Expand All @@ -264,7 +264,7 @@ def log_prob(
def sample(
self, seed: Any, sample_shape: Tuple[int, ...] = ()
) -> Float[Array, " n N ..."]:
r"""Draws samples from the distribution."""
r"""Draws samples from the distribution and reshapes them to the output shape."""
sample = self._distribution.sample(seed, sample_shape)
return jnp.reshape(sample, sample_shape + self._output_shape)

Expand Down

0 comments on commit 054ed98

Please sign in to comment.