Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
NKNaN committed Oct 8, 2023
1 parent bffd1ee commit 91f9e01
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
3 changes: 2 additions & 1 deletion python/paddle/distribution/binomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def entropy(self):
numpy.ndarray: the entropy for the binomial r.v.
"""
values = self._enumerate_support()
log_prob = paddle.nan_to_num(self.log_prob(values), neginf=0)
eps = paddle.finfo(self.probability.dtype).eps
log_prob = paddle.nan_to_num(self.log_prob(values), neginf=eps)
return -(paddle.exp(log_prob) * log_prob).sum(0)

def _enumerate_support(self):
Expand Down
17 changes: 4 additions & 13 deletions python/paddle/distribution/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,8 @@ def entropy(self):
values = self._enumerate_bounded_support(self.rate).reshape(
(-1,) + (1,) * len(self.batch_shape)
)
log_prob = paddle.nan_to_num(
(
-self.rate
+ values * paddle.log(self.rate)
- paddle.lgamma(values + 1)
),
neginf=0,
)
return paddle.nan_to_num(
-(paddle.exp(log_prob) * log_prob), posinf=0
).sum(0)
log_prob = self.log_prob(values)
return -(paddle.exp(log_prob) * log_prob).sum(0)

def _enumerate_bounded_support(self, rate):
"""Generate a bounded approximation of the support. Approximately view Poisson r.v. as a Normal r.v. with mu = rate and sigma = sqrt(rate).
Expand Down Expand Up @@ -218,14 +209,14 @@ def log_prob(self, value):
raise ValueError(
'Every element of input parameter `value` should be nonnegative.'
)

eps = paddle.finfo(self.rate.dtype).eps
return paddle.nan_to_num(
(
-self.rate
+ value * paddle.log(self.rate)
- paddle.lgamma(value + 1)
),
neginf=0,
neginf=eps,
)

def prob(self, value):
Expand Down
18 changes: 9 additions & 9 deletions test/distribution/test_distribution_binomial_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
[
(
'one-dim',
np.array([100]),
np.array([20]),
parameterize.xrand((1,), dtype='float32', min=0, max=1),
),
(
'multi-dim',
np.array([100]),
parameterize.xrand((3, 2), dtype='float32', min=0, max=1),
np.array([20]),
parameterize.xrand((1, 3), dtype='float32', min=0, max=1),
),
],
)
Expand All @@ -58,7 +58,7 @@ def setUp(self):
var = dist.variance
entropy = dist.entropy()
mini_samples = dist.sample(shape=())
large_samples = dist.sample(shape=(5000,))
large_samples = dist.sample(shape=(500,))
fetch_list = [mean, var, entropy, mini_samples, large_samples]
feed = {
'probability': self.probability,
Expand Down Expand Up @@ -191,17 +191,17 @@ def test_prob(self):
[
(
'one-dim-probability',
np.array([75]),
np.array([16]),
parameterize.xrand((1,), dtype='float32', min=0, max=1),
np.array([75]),
parameterize.xrand((1,), dtype='float32', min=0, max=1),
),
(
'multi-dim-probability',
np.array([189]),
parameterize.xrand((5, 3), dtype='float32', min=0, max=1),
np.array([189]),
parameterize.xrand((5, 3), dtype='float32', min=0, max=1),
np.array([32]),
parameterize.xrand((1, 2), dtype='float32', min=0, max=1),
np.array([32]),
parameterize.xrand((1, 2), dtype='float32', min=0, max=1),
),
],
)
Expand Down

0 comments on commit 91f9e01

Please sign in to comment.