Commit

update signature
NKNaN committed Dec 20, 2023
1 parent 3fc4ad6 commit 6f5e3df
Showing 3 changed files with 80 additions and 98 deletions.
64 changes: 28 additions & 36 deletions python/paddle/distribution/binomial.py
@@ -20,7 +20,7 @@

class Binomial(distribution.Distribution):
r"""
The Binomial distribution with size `total_count` and `probability` parameters.
The Binomial distribution with size `total_count` and `probs` parameters.
In probability theory and statistics, the binomial distribution is the most basic discrete probability distribution defined on :math:`[0, n] \cap \mathbb{N}`,
which can be viewed as the number of times a potentially unfair coin is tossed to get heads, and the result
@@ -34,14 +34,14 @@ class Binomial(distribution.Distribution):
In the above equation:
* :math:`total_count = n`: is the size, meaning the total number of Bernoulli experiments.
* :math:`probability = p`: is the probability of the event happening in one Bernoulli experiment.
* :math:`total\_count = n`: is the size, meaning the total number of Bernoulli experiments.
* :math:`probs = p`: is the probability of the event happening in one Bernoulli experiment.
Args:
total_count(int|Tensor): The size of the Binomial distribution, which should be greater than 0, meaning the number of independent Bernoulli
trials with probability parameter :math:`p`. The data type will be converted to 1-D Tensor with paddle global default dtype if the input
:attr:`probability` is not Tensor, otherwise will be converted to the same as :attr:`probability`.
probability(float|Tensor): The probability of Binomial distribution which should reside in [0, 1], meaning the probability of success
:attr:`probs` is not Tensor, otherwise will be converted to the same as :attr:`probs`.
probs(float|Tensor): The probability of Binomial distribution which should reside in [0, 1], meaning the probability of success
for each individual Bernoulli trial. If the input data type is float, it will be converted to a 1-D Tensor with paddle global default dtype.
Examples:
@@ -67,53 +67,51 @@ class Binomial(distribution.Distribution):
[2.94053698, 3.00781751, 2.51124287])
"""

def __init__(self, total_count, probability):
def __init__(self, total_count, probs):
self.dtype = paddle.get_default_dtype()
self.total_count, self.probability = self._to_tensor(
total_count, probability
)
self.total_count, self.probs = self._to_tensor(total_count, probs)

if not self._check_constraint(self.total_count, self.probability):
if not self._check_constraint(self.total_count, self.probs):
raise ValueError(
'Every element of input parameter `total_count` should be greater than or equal to one, and `probability` should be greater than or equal to zero and less than or equal to one.'
'Every element of input parameter `total_count` should be greater than or equal to one, and `probs` should be greater than or equal to zero and less than or equal to one.'
)
if self.total_count.shape == []:
batch_shape = (1,)
else:
batch_shape = self.total_count.shape
super().__init__(batch_shape)

def _to_tensor(self, total_count, probability):
def _to_tensor(self, total_count, probs):
"""Convert the input parameters into Tensors if they were not and broadcast them
Returns:
Tuple[Tensor, Tensor]: converted total_count and probability.
Tuple[Tensor, Tensor]: converted total_count and probs.
"""
# convert type
if isinstance(probability, float):
probability = paddle.to_tensor(probability, dtype=self.dtype)
if isinstance(probs, float):
probs = paddle.to_tensor(probs, dtype=self.dtype)
else:
self.dtype = probability.dtype
self.dtype = probs.dtype
if isinstance(total_count, int):
total_count = paddle.to_tensor(total_count, dtype=self.dtype)
else:
total_count = paddle.cast(total_count, dtype=self.dtype)

# broadcast tensor
return paddle.broadcast_tensors([total_count, probability])
return paddle.broadcast_tensors([total_count, probs])

def _check_constraint(self, total_count, probability):
def _check_constraint(self, total_count, probs):
"""Check the constraints for input parameters
Args:
total_count (Tensor)
probability (Tensor)
probs (Tensor)
Returns:
bool: pass or not.
"""
total_count_check = (total_count >= 1).all()
probability_check = (probability >= 0).all() * (probability <= 1).all()
probability_check = (probs >= 0).all() * (probs <= 1).all()
return total_count_check and probability_check

@property
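To make the renamed signature concrete, a minimal usage sketch (not part of the commit; values are hypothetical, and `mean`/`variance` are the properties updated in the hunks below):

import paddle
from paddle.distribution import Binomial

# `total_count` and `probs` are broadcast to a common shape by `_to_tensor`.
dist = Binomial(total_count=10, probs=paddle.to_tensor([0.3, 0.7]))
print(dist.mean)      # total_count * probs -> [3.0, 7.0]
print(dist.variance)  # total_count * probs * (1 - probs) -> [2.1, 2.1]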
@@ -123,7 +121,7 @@ def mean(self):
Returns:
Tensor: mean value.
"""
return self.total_count * self.probability
return self.total_count * self.probs

@property
def variance(self):
@@ -132,7 +130,7 @@ def variance(self):
Returns:
Tensor: variance value.
"""
return self.total_count * self.probability * (1 - self.probability)
return self.total_count * self.probs * (1 - self.probs)

def sample(self, shape=()):
"""Generate binomial samples of the specified shape. The final shape would be ``shape+batch_shape`` .
@@ -141,7 +139,7 @@ def sample(self, shape=()):
shape (Sequence[int], optional): Prepended shape of the generated samples.
Returns:
Tensor: Sampled data with shape `sample_shape` + `batch_shape`. The returned data type is the same as `probability`.
Tensor: Sampled data with shape `sample_shape` + `batch_shape`. The returned data type is the same as `probs`.
"""
if not isinstance(shape, Sequence):
raise TypeError('sample shape must be Sequence object.')
@@ -153,9 +151,7 @@
output_size = paddle.broadcast_to(
self.total_count, shape=output_shape
)
output_prob = paddle.broadcast_to(
self.probability, shape=output_shape
)
output_prob = paddle.broadcast_to(self.probs, shape=output_shape)
sample = paddle.binomial(
paddle.cast(output_size, dtype="int32"), output_prob
)
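A hypothetical call illustrating the documented shape semantics (values chosen arbitrarily; not part of the commit):

import paddle
from paddle.distribution import Binomial

dist = Binomial(total_count=10, probs=paddle.to_tensor([0.3, 0.7]))
# sample shape (100,) is prepended to batch shape (2,)
print(dist.sample((100,)).shape)  # [100, 2]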
@@ -174,12 +170,8 @@ def entropy(self):
* :math:`\Omega`: is the support of the distribution.
Args:
n (float): size of the binomial r.v.
p (float): probability of the binomial r.v.
Returns:
Tensor: Shannon entropy of binomial distribution. The data type is the same as `probability`.
Tensor: Shannon entropy of binomial distribution. The data type is the same as `probs`.
"""
values = self._enumerate_support()
log_prob = self.log_prob(values)
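`entropy()` enumerates the support and evaluates H = -sum_k p(k) * log p(k); a hypothetical cross-check of that identity against scipy (not part of the commit):

import numpy as np
import scipy.stats

n, p = 10, 0.3  # arbitrary parameters; the support of Binomial(n, p) is {0, ..., n}
pmf = scipy.stats.binom.pmf(np.arange(n + 1), n, p)
assert np.isclose(-(pmf * np.log(pmf)).sum(), scipy.stats.binom.entropy(n, p))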
@@ -204,7 +196,7 @@ def log_prob(self, value):
value (Tensor): The input tensor.
Returns:
Tensor: log probability. The data type is the same as `probability`.
Tensor: log probability. The data type is the same as `probs`.
"""
value = paddle.cast(value, dtype=self.dtype)

@@ -214,8 +206,8 @@
- paddle.lgamma(self.total_count - value + 1.0)
- paddle.lgamma(value + 1.0)
)
eps = paddle.finfo(self.probability.dtype).eps
probs = paddle.clip(self.probability, min=eps, max=1 - eps)
eps = paddle.finfo(self.probs.dtype).eps
probs = paddle.clip(self.probs, min=eps, max=1 - eps)
# log_p
return paddle.nan_to_num(
(
@@ -233,7 +225,7 @@ def prob(self, value):
value (Tensor): The input tensor.
Returns:
Tensor: probability. The data type is the same as `probability`.
Tensor: probability. The data type is the same as `probs`.
"""
return paddle.exp(self.log_prob(value))
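The eps clipping in `log_prob` above guards against log(0) when `probs` is exactly 0 or 1. A standalone sketch of the same log-pmf computation (hypothetical values; a sketch, not the library code):

import paddle

# log C(n, k) + k * log(p) + (n - k) * log(1 - p), with p clipped away from {0, 1}
n = paddle.to_tensor(10.0)
k = paddle.to_tensor(4.0)
p = paddle.to_tensor(0.3)
log_comb = (
    paddle.lgamma(n + 1.0)
    - paddle.lgamma(n - k + 1.0)
    - paddle.lgamma(k + 1.0)
)
eps = paddle.finfo(p.dtype).eps
p_safe = paddle.clip(p, min=eps, max=1 - eps)
log_pmf = log_comb + k * paddle.log(p_safe) + (n - k) * paddle.log(1 - p_safe)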

@@ -258,7 +250,7 @@ def kl_divergence(self, other):
other (Binomial): instance of ``Binomial``.
Returns:
Tensor: kl-divergence between two binomial distributions. The data type is the same as `probability`.
Tensor: kl-divergence between two binomial distributions. The data type is the same as `probs`.
"""
if not (paddle.equal(self.total_count, other.total_count)).all():
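A hypothetical end-to-end call of `kl_divergence` with the new keyword (both distributions must share `total_count`, which the method checks):

import paddle
from paddle.distribution import Binomial

d1 = Binomial(total_count=10, probs=paddle.to_tensor(0.3))
d2 = Binomial(total_count=10, probs=paddle.to_tensor(0.6))
print(d1.kl_divergence(d2))  # non-negative Tensor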
56 changes: 26 additions & 30 deletions test/distribution/test_distribution_binomial.py
@@ -25,7 +25,7 @@

@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'total_count', 'probability'),
(parameterize.TEST_CASE_NAME, 'total_count', 'probs'),
[
(
'one-dim',
@@ -43,37 +43,37 @@ class TestBinomial(unittest.TestCase):
def setUp(self):
self._dist = Binomial(
total_count=paddle.to_tensor(self.total_count),
probability=paddle.to_tensor(self.probability),
probs=paddle.to_tensor(self.probs),
)

def test_mean(self):
mean = self._dist.mean
self.assertEqual(mean.numpy().dtype, self.probability.dtype)
self.assertEqual(mean.numpy().dtype, self.probs.dtype)
np.testing.assert_allclose(
mean,
self._np_mean(),
rtol=config.RTOL.get(str(self.probability.dtype)),
atol=config.ATOL.get(str(self.probability.dtype)),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)),
)

def test_variance(self):
var = self._dist.variance
self.assertEqual(var.numpy().dtype, self.probability.dtype)
self.assertEqual(var.numpy().dtype, self.probs.dtype)
np.testing.assert_allclose(
var,
self._np_variance(),
rtol=config.RTOL.get(str(self.probability.dtype)),
atol=config.ATOL.get(str(self.probability.dtype)),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)),
)

def test_entropy(self):
entropy = self._dist.entropy()
self.assertEqual(entropy.numpy().dtype, self.probability.dtype)
self.assertEqual(entropy.numpy().dtype, self.probs.dtype)
np.testing.assert_allclose(
entropy,
self._np_entropy(),
rtol=config.RTOL.get(str(self.probability.dtype)),
atol=config.ATOL.get(str(self.probability.dtype)),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)),
)

def test_sample(self):
@@ -97,18 +97,18 @@ def test_sample(self):
)

def _np_variance(self):
return scipy.stats.binom.var(self.total_count, self.probability)
return scipy.stats.binom.var(self.total_count, self.probs)

def _np_mean(self):
return scipy.stats.binom.mean(self.total_count, self.probability)
return scipy.stats.binom.mean(self.total_count, self.probs)

def _np_entropy(self):
return scipy.stats.binom.entropy(self.total_count, self.probability)
return scipy.stats.binom.entropy(self.total_count, self.probs)


@parameterize.place(config.DEVICES)
@parameterize.parameterize_cls(
(parameterize.TEST_CASE_NAME, 'total_count', 'probability', 'value'),
(parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
[
(
'value-same-shape',
@@ -128,27 +128,23 @@ class TestBinomialProbs(unittest.TestCase):
def setUp(self):
self._dist = Binomial(
total_count=self.total_count,
probability=paddle.to_tensor(self.probability),
probs=paddle.to_tensor(self.probs),
)

def test_prob(self):
np.testing.assert_allclose(
self._dist.prob(paddle.to_tensor(self.value)),
scipy.stats.binom.pmf(
self.value, self.total_count, self.probability
),
rtol=config.RTOL.get(str(self.probability.dtype)),
atol=config.ATOL.get(str(self.probability.dtype)),
scipy.stats.binom.pmf(self.value, self.total_count, self.probs),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)),
)

def test_log_prob(self):
np.testing.assert_allclose(
self._dist.log_prob(paddle.to_tensor(self.value)),
scipy.stats.binom.logpmf(
self.value, self.total_count, self.probability
),
rtol=config.RTOL.get(str(self.probability.dtype)),
atol=config.ATOL.get(str(self.probability.dtype)),
scipy.stats.binom.logpmf(self.value, self.total_count, self.probs),
rtol=config.RTOL.get(str(self.probs.dtype)),
atol=config.ATOL.get(str(self.probs.dtype)),
)


@@ -176,11 +172,11 @@ class TestBinomialKL(unittest.TestCase):
def setUp(self):
self._dist1 = Binomial(
total_count=paddle.to_tensor(self.n_1),
probability=paddle.to_tensor(self.p_1),
probs=paddle.to_tensor(self.p_1),
)
self._dist2 = Binomial(
total_count=paddle.to_tensor(self.n_2),
probability=paddle.to_tensor(self.p_2),
probs=paddle.to_tensor(self.p_2),
)

def test_kl_divergence(self):
@@ -200,10 +196,10 @@ def kl_divergence(self, dist1, dist2):
support = np.arange(1 + self.n_1.max(), dtype=self.p_1.dtype)
support = support.reshape((-1,) + (1,) * len(self.p_1.shape))
log_prob_1 = scipy.stats.binom.logpmf(
support, dist1.total_count, dist1.probability
support, dist1.total_count, dist1.probs
)
log_prob_2 = scipy.stats.binom.logpmf(
support, dist2.total_count, dist2.probability
support, dist2.total_count, dist2.probs
)
return (np.exp(log_prob_1) * (log_prob_1 - log_prob_2)).sum(0)
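For reference, the numpy helper above evaluates the discrete KL divergence over the shared support:

KL(P1 || P2) = sum_{k=0}^{n} p1(k) * (log p1(k) - log p2(k))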
